update
This commit is contained in:
@@ -5,15 +5,17 @@
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 9th January 2022 12:31:03 am
|
||||
# Last Modified: Tuesday, 11th January 2022 3:06:14 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from utilities.utilities import denorm
|
||||
@@ -182,12 +184,10 @@ class Trainer(object):
|
||||
model_freq = self.config["model_save_epoch"]
|
||||
total_epoch = self.config["total_epoch"]
|
||||
batch_size = self.config["batch_size"]
|
||||
style_img = self.config["style_img_path"]
|
||||
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
content_w = self.config["content_weight"]
|
||||
style_w = self.config["style_weight"]
|
||||
crop_size = self.config["imcrop_size"]
|
||||
|
||||
sample_dir = self.config["project_samples"]
|
||||
|
||||
@@ -231,32 +231,30 @@ class Trainer(object):
|
||||
step_epoch = step_epoch // batch_size
|
||||
print("Total step = %d in each epoch"%step_epoch)
|
||||
|
||||
randindex = [i for i in range(batch_size)]
|
||||
|
||||
|
||||
# step_epoch = 2
|
||||
for epoch in range(start, total_epoch):
|
||||
for step in range(step_epoch):
|
||||
|
||||
self.gen.train()
|
||||
image1, image2 = self.train_loader.next()
|
||||
random.shuffle(randindex)
|
||||
|
||||
src_image1, src_image2 = self.train_loader.next()
|
||||
|
||||
|
||||
img_att = src_image1
|
||||
img_att = image1
|
||||
|
||||
if step%2 == 0:
|
||||
img_id = src_image2
|
||||
img_id = image2 # swap with same id, different pose
|
||||
else:
|
||||
img_id = src_image2[randindex]
|
||||
img_id = image2[randindex] # swap with different face
|
||||
|
||||
src_image1_112 = F.interpolate(src_image1,size=(112,112), mode='bicubic')
|
||||
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
|
||||
|
||||
img_id_112_norm = spnorm(img_id_112)
|
||||
|
||||
latent_id = model.netArc(img_id_112_norm)
|
||||
latent_id = self.arcface(img_id_112)
|
||||
|
||||
latent_id = F.normalize(latent_id, p=2, dim=1)
|
||||
|
||||
losses, img_fake= self.gen(src_image1, latent_id)
|
||||
losses, img_fake= self.gen(image1, latent_id)
|
||||
|
||||
# update Generator weights
|
||||
losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
|
||||
@@ -275,18 +273,6 @@ class Trainer(object):
|
||||
loss_D.backward()
|
||||
optimizer_D.step()
|
||||
|
||||
self.gen.train()
|
||||
|
||||
content_images = self.train_loader.next()
|
||||
fake_image = self.gen(content_images)
|
||||
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
|
||||
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
|
||||
|
||||
style_loss = 0.0
|
||||
for key, value in generated_features.items():
|
||||
s_loss = MSE_loss(Gram(value), style_gram[key])
|
||||
style_loss += s_loss
|
||||
|
||||
# backward & optimize
|
||||
g_loss = content_loss* content_w + style_loss* style_w
|
||||
|
||||
Reference in New Issue
Block a user