update
This commit is contained in:
@@ -16,8 +16,8 @@ import time
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from utilities.utilities import denorm, Gram, img2tensor255crop
|
||||
from pretrained_weights.vgg import VGG16
|
||||
from utilities.utilities import denorm
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
@@ -92,10 +92,25 @@ class Trainer(object):
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
|
||||
|
||||
|
||||
# id extractor network
|
||||
arcface_ckpt = self.config["arcface_ckpt"]
|
||||
arcface_ckpt = torch.load(arcface_ckpt, map_location=torch.device("cpu"))
|
||||
self.arcface = arcface_ckpt['model'].module
|
||||
|
||||
|
||||
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
self.gen = self.gen.cuda()
|
||||
self.arcface = self.arcface.cuda()
|
||||
|
||||
self.arcface.eval()
|
||||
self.arcface.requires_grad_(False)
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
@@ -216,24 +231,50 @@ class Trainer(object):
|
||||
step_epoch = step_epoch // batch_size
|
||||
print("Total step = %d in each epoch"%step_epoch)
|
||||
|
||||
VGG = VGG16().cuda()
|
||||
|
||||
MEAN_VAL = 127.5
|
||||
SCALE_VAL= 127.5
|
||||
# Get Style Features
|
||||
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
|
||||
|
||||
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
|
||||
style_tensor = style_tensor.add(imagenet_neg_mean)
|
||||
B, C, H, W = style_tensor.shape
|
||||
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
|
||||
style_gram = {}
|
||||
for key, value in style_features.items():
|
||||
style_gram[key] = Gram(value)
|
||||
# step_epoch = 2
|
||||
for epoch in range(start, total_epoch):
|
||||
for step in range(step_epoch):
|
||||
|
||||
self.gen.train()
|
||||
|
||||
src_image1, src_image2 = self.train_loader.next()
|
||||
|
||||
|
||||
img_att = src_image1
|
||||
|
||||
if step%2 == 0:
|
||||
img_id = src_image2
|
||||
else:
|
||||
img_id = src_image2[randindex]
|
||||
|
||||
src_image1_112 = F.interpolate(src_image1,size=(112,112), mode='bicubic')
|
||||
img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
|
||||
|
||||
img_id_112_norm = spnorm(img_id_112)
|
||||
|
||||
latent_id = model.netArc(img_id_112_norm)
|
||||
|
||||
latent_id = F.normalize(latent_id, p=2, dim=1)
|
||||
|
||||
losses, img_fake= model(None, src_image1, latent_id, None, for_G=True)
|
||||
|
||||
# update Generator weights
|
||||
losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
|
||||
loss_dict = dict(zip(model.loss_names, losses))
|
||||
|
||||
loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict['G_ID'] * opt.lambda_id
|
||||
if step%2 == 0:
|
||||
loss_G += loss_dict['G_Rec']
|
||||
|
||||
optimizer_G.zero_grad()
|
||||
loss_G.backward(retain_graph=True)
|
||||
optimizer_G.step()
|
||||
|
||||
loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 + loss_dict['D_GP']
|
||||
optimizer_D.zero_grad()
|
||||
loss_D.backward()
|
||||
optimizer_D.step()
|
||||
|
||||
self.gen.train()
|
||||
|
||||
content_images = self.train_loader.next()
|
||||
|
||||
Reference in New Issue
Block a user