This commit is contained in:
Xuanhong Chen
2022-01-10 17:04:25 +08:00
parent 3783ef0e75
commit 591c650dd9
5 changed files with 314 additions and 174 deletions
+59 -18
View File
@@ -16,8 +16,8 @@ import time
import torch
from torchvision.utils import save_image
from utilities.utilities import denorm, Gram, img2tensor255crop
from pretrained_weights.vgg import VGG16
from utilities.utilities import denorm
class Trainer(object):
@@ -92,10 +92,25 @@ class Trainer(object):
# print and recorde model structure
self.reporter.writeInfo("Generator structure:")
self.reporter.writeModel(self.gen.__str__())
# id extractor network
arcface_ckpt = self.config["arcface_ckpt"]
arcface_ckpt = torch.load(arcface_ckpt, map_location=torch.device("cpu"))
self.arcface = arcface_ckpt['model'].module
# train in GPU
if self.config["cuda"] >=0:
self.gen = self.gen.cuda()
self.gen = self.gen.cuda()
self.arcface = self.arcface.cuda()
self.arcface.eval()
self.arcface.requires_grad_(False)
# if in finetune phase, load the pretrained checkpoint
if self.config["phase"] == "finetune":
@@ -216,24 +231,50 @@ class Trainer(object):
step_epoch = step_epoch // batch_size
print("Total step = %d in each epoch"%step_epoch)
VGG = VGG16().cuda()
MEAN_VAL = 127.5
SCALE_VAL= 127.5
# Get Style Features
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
style_tensor = style_tensor.add(imagenet_neg_mean)
B, C, H, W = style_tensor.shape
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
style_gram = {}
for key, value in style_features.items():
style_gram[key] = Gram(value)
# step_epoch = 2
for epoch in range(start, total_epoch):
for step in range(step_epoch):
self.gen.train()
src_image1, src_image2 = self.train_loader.next()
img_att = src_image1
if step%2 == 0:
img_id = src_image2
else:
img_id = src_image2[randindex]
src_image1_112 = F.interpolate(src_image1,size=(112,112), mode='bicubic')
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
img_id_112_norm = spnorm(img_id_112)
latent_id = model.netArc(img_id_112_norm)
latent_id = F.normalize(latent_id, p=2, dim=1)
losses, img_fake= model(None, src_image1, latent_id, None, for_G=True)
# update Generator weights
losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
loss_dict = dict(zip(model.loss_names, losses))
loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict['G_ID'] * opt.lambda_id
if step%2 == 0:
loss_G += loss_dict['G_Rec']
optimizer_G.zero_grad()
loss_G.backward(retain_graph=True)
optimizer_G.step()
loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 + loss_dict['D_GP']
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
self.gen.train()
content_images = self.train_loader.next()