update involution

This commit is contained in:
chenxuanhong
2022-02-12 22:13:10 +08:00
parent 5271a9f3c2
commit 9429c6d7be
14 changed files with 906 additions and 20 deletions
+5 -1
View File
@@ -5,7 +5,7 @@
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 8th February 2022 10:48:58 pm
# Last Modified: Friday, 11th February 2022 11:18:47 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -94,6 +94,7 @@ def init_framework(config, reporter, device, rank):
# print and recorde model structure
reporter.writeInfo("Discriminator structure:")
reporter.writeModel(dis.__str__())
arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
arcface = arcface1['model'].module
@@ -428,6 +429,9 @@ def train_loop(
if rank == 0 and (step + 1) % log_freq == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# print("ready to report losses")
# ID_Total= loss_G_ID
# torch.distributed.all_reduce(ID_Total)
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \