This commit is contained in:
chenxuanhong
2022-01-17 13:17:49 +08:00
parent bf2df5c5a6
commit 601d2ee43d
58 changed files with 2748 additions and 5696 deletions
+13 -27
View File
@@ -5,15 +5,17 @@
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 9th January 2022 12:31:03 am
# Last Modified: Tuesday, 11th January 2022 3:06:14 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import time
import random
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from utilities.utilities import denorm
@@ -182,12 +184,10 @@ class Trainer(object):
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
batch_size = self.config["batch_size"]
style_img = self.config["style_img_path"]
# prep_weights= self.config["layersWeight"]
content_w = self.config["content_weight"]
style_w = self.config["style_weight"]
crop_size = self.config["imcrop_size"]
sample_dir = self.config["project_samples"]
@@ -231,32 +231,30 @@ class Trainer(object):
step_epoch = step_epoch // batch_size
print("Total step = %d in each epoch"%step_epoch)
randindex = [i for i in range(batch_size)]
# step_epoch = 2
for epoch in range(start, total_epoch):
for step in range(step_epoch):
self.gen.train()
image1, image2 = self.train_loader.next()
random.shuffle(randindex)
src_image1, src_image2 = self.train_loader.next()
img_att = src_image1
img_att = image1
if step%2 == 0:
img_id = src_image2
img_id = image2 # swap with same id, different pose
else:
img_id = src_image2[randindex]
img_id = image2[randindex] # swap with different face
src_image1_112 = F.interpolate(src_image1,size=(112,112), mode='bicubic')
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
img_id_112_norm = spnorm(img_id_112)
latent_id = model.netArc(img_id_112_norm)
latent_id = self.arcface(img_id_112)
latent_id = F.normalize(latent_id, p=2, dim=1)
losses, img_fake= self.gen(src_image1, latent_id)
losses, img_fake= self.gen(image1, latent_id)
# update Generator weights
losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
@@ -275,18 +273,6 @@ class Trainer(object):
loss_D.backward()
optimizer_D.step()
self.gen.train()
content_images = self.train_loader.next()
fake_image = self.gen(content_images)
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
style_loss = 0.0
for key, value in generated_features.items():
s_loss = MSE_loss(Gram(value), style_gram[key])
style_loss += s_loss
# backward & optimize
g_loss = content_loss* content_w + style_loss* style_w