From 94534e2e3023d46c10cd74e3c87afa0aa29bd8bf Mon Sep 17 00:00:00 2001 From: chenxuanhong Date: Mon, 24 Jan 2022 19:01:00 +0800 Subject: [PATCH] update --- GUI/file_sync/filestate_machine0.json | 28 +- components/Generator_reduce.py | 81 ++---- losses/PatchNCE.py | 218 ++++++++++++++- start_train.sh | 3 + train.py | 10 +- train_scripts/trainer_FM.py | 25 +- train_scripts/trainer_GramFM.py | 351 ++++++++++++++++++++++++ train_yamls/train_512FM.yaml | 2 +- train_yamls/train_512FM_Modulation.yaml | 63 +++++ train_yamls/train_GramFM.yaml | 63 +++++ 10 files changed, 765 insertions(+), 79 deletions(-) create mode 100644 start_train.sh create mode 100644 train_scripts/trainer_GramFM.py create mode 100644 train_yamls/train_512FM_Modulation.yaml create mode 100644 train_yamls/train_GramFM.yaml diff --git a/GUI/file_sync/filestate_machine0.json b/GUI/file_sync/filestate_machine0.json index e65df0a..37d712d 100644 --- a/GUI/file_sync/filestate_machine0.json +++ b/GUI/file_sync/filestate_machine0.json @@ -1,7 +1,7 @@ { "GUI.py": 1642351532.4558506, - "test.py": 1634039043.4872007, - "train.py": 1642408973.489742, + "test.py": 1642733759.8015468, + "train.py": 1643009568.7253726, "components\\Generator.py": 1642347735.351465, "components\\projected_discriminator.py": 1642348101.4661522, "components\\pg_modules\\blocks.py": 1640773190.0, @@ -23,7 +23,7 @@ "test_scripts\\tester_common.py": 1625369535.199175, "test_scripts\\tester_FastNST.py": 1634041357.607633, "train_scripts\\trainer_base.py": 1642396105.3868554, - "train_scripts\\trainer_FM.py": 1642396334.407562, + "train_scripts\\trainer_FM.py": 1642826710.3532298, "train_scripts\\trainer_naiv512.py": 1642315674.9740853, "utilities\\checkpoint_manager.py": 1611123530.6624403, "utilities\\figure.py": 1611123530.6634378, @@ -37,7 +37,23 @@ "utilities\\transfer_checkpoint.py": 1642397157.0163105, "utilities\\utilities.py": 1634019485.0783668, "utilities\\yaml_config.py": 1611123530.6614666, - "train_yamls\\train_512FM.yaml": 1642408586.2747152, - "train_scripts\\trainer_2layer_FM.py": 1642408727.3950334, - "train_yamls\\train_2layer_FM.yaml": 1642408813.9488862 + "train_yamls\\train_512FM.yaml": 1642412254.0831068, + "train_scripts\\trainer_2layer_FM.py": 1642826548.2530458, + "train_yamls\\train_2layer_FM.yaml": 1642411635.5534878, + "components\\Generator_reduce.py": 1642690262.572149, + "insightface_func\\face_detect_crop_multi.py": 1638370471.789609, + "insightface_func\\face_detect_crop_single.py": 1638370471.7967434, + "insightface_func\\__init__.py": 1624197300.011183, + "insightface_func\\utils\\face_align_ffhqandnewarc.py": 1638370471.850638, + "losses\\PatchNCE.py": 1642989384.9713614, + "parsing_model\\model.py": 1626745709.554252, + "parsing_model\\resnet.py": 1626745709.554252, + "test_scripts\\tester_common copy.py": 1625369535.199175, + "test_scripts\\tester_video.py": 1642734397.3307388, + "train_scripts\\trainer_cycleloss.py": 1642580463.495596, + "train_scripts\\trainer_GramFM.py": 1643010471.1077821, + "utilities\\ImagenetNorm.py": 1642732910.5280058, + "utilities\\reverse2original.py": 1642733688.7976837, + "train_yamls\\train_cycleloss.yaml": 1642577741.345273, + "train_yamls\\train_GramFM.yaml": 1643011210.82505 } \ No newline at end of file diff --git a/components/Generator_reduce.py b/components/Generator_reduce.py index 67b68e3..881a0cb 100644 --- a/components/Generator_reduce.py +++ b/components/Generator_reduce.py @@ -5,7 +5,7 @@ # Created Date: Sunday January 16th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 20th January 2022 10:51:02 pm +# Last Modified: Monday, 24th January 2022 6:47:22 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -15,17 +15,16 @@ from torch import nn from torch.nn import init from torch.nn import functional as F -class InstanceNorm(nn.Module): +class Demodule(nn.Module): def __init__(self, epsilon=1e-8): """ @notice: avoid in-place ops. https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 """ - super(InstanceNorm, self).__init__() + super(Demodule, self).__init__() self.epsilon = epsilon def forward(self, x): - x = x - torch.mean(x, (2, 3), True) tmp = torch.mul(x, x) # or x ** 2 tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) return x * tmp @@ -46,9 +45,22 @@ class ApplyStyle(nn.Module): x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1 return x -class ResnetBlock_Adain(nn.Module): +class Modulation(nn.Module): + def __init__(self, latent_size, channels): + super(Modulation, self).__init__() + self.linear = nn.Linear(latent_size, channels) + + def forward(self, x, latent): + style = self.linear(latent) # style => [batch_size, n_channels*2] + shape = [-1, x.size(1), 1, 1] + style = style.view(shape) # [batch_size, 2, n_channels, ...] + #x = x * (style[:, 0] + 1.) + style[:, 1] + x = x * style + return x + +class ResnetBlock_Modulation(nn.Module): def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)): - super(ResnetBlock_Adain, self).__init__() + super(ResnetBlock_Modulation, self).__init__() p = 0 conv1 = [] @@ -60,9 +72,9 @@ class ResnetBlock_Adain(nn.Module): p = 1 else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) - conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()] + conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), Demodule()] self.conv1 = nn.Sequential(*conv1) - self.style1 = ApplyStyle(latent_size, dim) + self.style1 = Modulation(latent_size, dim) self.act1 = activation p = 0 @@ -75,52 +87,9 @@ class ResnetBlock_Adain(nn.Module): p = 1 else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) - conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()] + conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), Demodule()] self.conv2 = nn.Sequential(*conv2) - self.style2 = ApplyStyle(latent_size, dim) - - - def forward(self, x, dlatents_in_slice): - y = self.conv1(x) - y = self.style1(y, dlatents_in_slice) - y = self.act1(y) - y = self.conv2(y) - y = self.style2(y, dlatents_in_slice) - out = x + y - return out - -class newres(nn.Module): - def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)): - super(ResnetBlock_Adain, self).__init__() - - p = 0 - conv1 = [] - if padding_type == 'reflect': - conv1 += [nn.ReflectionPad2d(1)] - elif padding_type == 'replicate': - conv1 += [nn.ReplicationPad2d(1)] - elif padding_type == 'zero': - p = 1 - else: - raise NotImplementedError('padding [%s] is not implemented' % padding_type) - conv1 += [nn.Conv2d(dim, dim, kernel_size=1), InstanceNorm()] - self.conv1 = nn.Sequential(*conv1) - self.style1 = ApplyStyle(latent_size, dim) - self.act1 = activation - - p = 0 - conv2 = [] - if padding_type == 'reflect': - conv2 += [nn.ReflectionPad2d(1)] - elif padding_type == 'replicate': - conv2 += [nn.ReplicationPad2d(1)] - elif padding_type == 'zero': - p = 1 - else: - raise NotImplementedError('padding [%s] is not implemented' % padding_type) - conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()] - self.conv2 = nn.Sequential(*conv2) - self.style2 = ApplyStyle(latent_size, dim) + self.style2 = Modulation(latent_size, dim) def forward(self, x, dlatents_in_slice): @@ -148,7 +117,7 @@ class Generator(nn.Module): activation = nn.ReLU(True) - self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1), + self.stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), activation) ### downsample self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), @@ -167,7 +136,7 @@ class Generator(nn.Module): BN = [] for i in range(res_num): BN += [ - ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type, activation=activation)] + ResnetBlock_Modulation(512, latent_size=chn, padding_type=padding_type, activation=activation)] self.BottleNeck = nn.Sequential(*BN) self.up4 = nn.Sequential( @@ -210,7 +179,7 @@ class Generator(nn.Module): def forward(self, input, id): x = input # 3*224*224 - skip1 = self.first_layer(x) + skip1 = self.stem(x) skip2 = self.down1(skip1) skip3 = self.down2(skip2) skip4 = self.down3(skip3) diff --git a/losses/PatchNCE.py b/losses/PatchNCE.py index b1dc8d5..aabdedb 100644 --- a/losses/PatchNCE.py +++ b/losses/PatchNCE.py @@ -5,9 +5,225 @@ # Created Date: Friday January 21st 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Friday, 21st January 2022 5:04:43 pm +# Last Modified: Monday, 24th January 2022 9:56:24 am # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# +import torch +import torch.nn as nn +from torch.nn import init +import torch.nn.functional as F +from packaging import version +import numpy as np + + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. + """ + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + # if not amp: + # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training + if initialize_weights: + init_weights(net, init_type, init_gain=init_gain, debug=debug) + return net + +def init_weights(net, init_type='normal', init_gain=0.02, debug=False): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if debug: + print(classname) + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + net.apply(init_func) # apply the initialization function + +class Normalize(nn.Module): + + def __init__(self, power=2): + super(Normalize, self).__init__() + self.power = power + + def forward(self, x): + norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) + out = x.div(norm + 1e-7) + return out + +class PatchSampleF(nn.Module): + def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]): + # potential issues: currently, we use the same patch_ids for multiple images in the batch + super(PatchSampleF, self).__init__() + self.l2norm = Normalize(2) + self.use_mlp = use_mlp + self.nc = nc # hard-coded + self.mlp_init = False + self.init_type = init_type + self.init_gain = init_gain + self.gpu_ids = gpu_ids + + def create_mlp(self, feats): + for mlp_id, feat in enumerate(feats): + input_nc = feat.shape[1] + mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)]) + if len(self.gpu_ids) > 0: + mlp.cuda() + setattr(self, 'mlp_%d' % mlp_id, mlp) + init_net(self, self.init_type, self.init_gain, self.gpu_ids) + self.mlp_init = True + + def forward(self, feats, num_patches=64, patch_ids=None): + return_ids = [] + return_feats = [] + if self.use_mlp and not self.mlp_init: + self.create_mlp(feats) + for feat_id, feat in enumerate(feats): + B, H, W = feat.shape[0], feat.shape[2], feat.shape[3] + feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2) + if num_patches > 0: + if patch_ids is not None: + patch_id = patch_ids[feat_id] + else: + # torch.randperm produces cudaErrorIllegalAddress for newer versions of PyTorch. https://github.com/taesungp/contrastive-unpaired-translation/issues/83 + #patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device) + patch_id = np.random.permutation(feat_reshape.shape[1]) + patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device) + x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1]) + else: + x_sample = feat_reshape + patch_id = [] + if self.use_mlp: + mlp = getattr(self, 'mlp_%d' % feat_id) + x_sample = mlp(x_sample) + return_ids.append(patch_id) + x_sample = self.l2norm(x_sample) + + if num_patches == 0: + x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W]) + return_feats.append(x_sample) + return return_feats, return_ids + + + +# def calculate_NCE_loss(src, tgt): +# n_layers = len(self.nce_layers) +# feat_q = self.netG(tgt, self.nce_layers, encode_only=True) + +# if self.opt.flip_equivariance and self.flipped_for_equivariance: +# feat_q = [torch.flip(fq, [3]) for fq in feat_q] + +# feat_k = self.netG(src, self.nce_layers, encode_only=True) +# feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None) +# feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids) + +# total_nce_loss = 0.0 +# for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers): +# loss = crit(f_q, f_k) * self.opt.lambda_NCE +# total_nce_loss += loss.mean() + + # return total_nce_loss / n_layers + +class PatchNCELoss(nn.Module): + def __init__(self,batch_size, nce_T = 0.07): + super().__init__() + self.nce_T = nce_T + self.batch_size = batch_size + self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none') + self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool + + def forward(self, feat_q, feat_k): + num_patches = feat_q.shape[0] + dim = feat_q.shape[1] + feat_k = feat_k.detach() + + # pos logit + l_pos = torch.bmm( + feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1)) + l_pos = l_pos.view(num_patches, 1) + + # neg logit + + # Should the negatives from the other samples of a minibatch be utilized? + # In CUT and FastCUT, we found that it's best to only include negatives + # from the same image. Therefore, we set + # --nce_includes_all_negatives_from_minibatch as False + # However, for single-image translation, the minibatch consists of + # crops from the "same" high-resolution image. + # Therefore, we will include the negatives from the entire minibatch. + # if self.opt.nce_includes_all_negatives_from_minibatch: + # # reshape features as if they are all negatives of minibatch of size 1. + # batch_dim_for_bmm = 1 + # else: + batch_dim_for_bmm = self.batch_size + + # reshape features to batch size + feat_q = feat_q.view(batch_dim_for_bmm, -1, dim) + feat_k = feat_k.view(batch_dim_for_bmm, -1, dim) + npatches = feat_q.size(1) + l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1)) + + # diagonal entries are similarity between same features, and hence meaningless. + # just fill the diagonal with very small number, which is exp(-10) and almost zero + diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :] + l_neg_curbatch.masked_fill_(diagonal, -10.0) + l_neg = l_neg_curbatch.view(-1, npatches) + + out = torch.cat((l_pos, l_neg), dim=1) / self.nce_T + + loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long, + device=feat_q.device)) + + return loss + +if __name__ == "__main__": + batch = 16 + nc = 256 + num_patches = 64 + netF = PatchSampleF(use_mlp=True, init_type='normal', init_gain=0.02, gpu_ids=["cuda:0"], nc=nc) + + feat_q = [torch.ones((batch,nc,32,32)).cuda()] + + # if self.opt.flip_equivariance and self.flipped_for_equivariance: + # feat_q = [torch.flip(fq, [3]) for fq in feat_q] + crit = PatchNCELoss(batch) + feat_k = [torch.ones((batch,nc,32,32)).cuda()] + feat_k_pool, sample_ids = netF(feat_k, num_patches, None) + print(feat_k_pool[0].shape) + feat_q_pool, _ = netF(feat_q, num_patches, sample_ids) + print(feat_q_pool[0].shape) + loss = crit(feat_q_pool[0], feat_k_pool[0]) + print(loss) \ No newline at end of file diff --git a/start_train.sh b/start_train.sh new file mode 100644 index 0000000..0f7a2e0 --- /dev/null +++ b/start_train.sh @@ -0,0 +1,3 @@ + + +nohup python train.py > GramFM.log 2>&1 & \ No newline at end of file diff --git a/train.py b/train.py index dbcd33e..46a2fac 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,7 @@ # Created Date: Tuesday April 28th 2020 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Monday, 17th January 2022 4:42:53 pm +# Last Modified: Monday, 24th January 2022 3:32:47 pm # Modified By: Chen Xuanhong # Copyright (c) 2020 Shanghai Jiao Tong University ############################################################# @@ -31,9 +31,9 @@ def getParameters(): parser = argparse.ArgumentParser() # general settings - parser.add_argument('-v', '--version', type=str, default='2layerFM', + parser.add_argument('-v', '--version', type=str, default='GramFM', help="version name for train, test, finetune") - parser.add_argument('-t', '--tag', type=str, default='Feature_match', + parser.add_argument('-t', '--tag', type=str, default='Gram_Feature_match', help="tag for current experiment") parser.add_argument('-p', '--phase', type=str, default="train", @@ -46,9 +46,9 @@ def getParameters(): # training parser.add_argument('--experiment_description', type=str, - default="减小重建和feature match的权重,使用2和3的feature作为feature") + default="使用3作为feature, 尝试使用gram矩阵来计算feature matching") - parser.add_argument('--train_yaml', type=str, default="train_2layer_FM.yaml") + parser.add_argument('--train_yaml', type=str, default="train_GramFM.yaml") # system logger parser.add_argument('--logger', type=str, diff --git a/train_scripts/trainer_FM.py b/train_scripts/trainer_FM.py index b817bef..f838550 100644 --- a/train_scripts/trainer_FM.py +++ b/train_scripts/trainer_FM.py @@ -5,7 +5,7 @@ # Created Date: Sunday January 9th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Saturday, 22nd January 2022 12:45:09 pm +# Last Modified: Monday, 24th January 2022 6:56:17 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -28,6 +28,9 @@ class Trainer(TrainerBase): config, reporter): super(Trainer, self).__init__(config, reporter) + + import inspect + print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe())) self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) @@ -276,25 +279,27 @@ class Trainer(TrainerBase): elapsed = str(datetime.timedelta(seconds=elapsed)) epochinformation="[{}], Elapsed [{}], Step [{}/{}], \ - G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \ - D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ - format(self.config["version"], elapsed, step, total_step, \ - loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \ - loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \ + D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + format(self.config["version"], elapsed, step, total_step, \ + loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \ + loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) print(epochinformation) self.reporter.writeInfo(epochinformation) if self.config["logger"] == "tensorboard": self.logger.add_scalar('G/G_loss', loss_G.item(), step) - self.logger.add_scalar('G/Rec_loss', loss_G_Rec.item(), step) - self.logger.add_scalar('G/Fm_loss', feat_match_loss.item(), step) + self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step) + self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step) + self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step) self.logger.add_scalar('D/D_loss', loss_D.item(), step) self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step) self.logger.add_scalar('D/D_real', loss_Dreal.item(), step) elif self.config["logger"] == "wandb": self.logger.log({"G_loss": loss_G.item()}, step = step) - self.logger.log({"Rec_loss": loss_G_Rec.item()}, step = step) - self.logger.log({"Fm_loss": feat_match_loss.item()}, step = step) + self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step) + self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step) + self.logger.log({"G_ID": loss_G_ID.item()}, step = step) self.logger.log({"D_loss": loss_D.item()}, step = step) self.logger.log({"D_fake": loss_Dgen.item()}, step = step) self.logger.log({"D_real": loss_Dreal.item()}, step = step) diff --git a/train_scripts/trainer_GramFM.py b/train_scripts/trainer_GramFM.py new file mode 100644 index 0000000..27f0c37 --- /dev/null +++ b/train_scripts/trainer_GramFM.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_naiv512.py +# Created Date: Sunday January 9th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Monday, 24th January 2022 6:23:16 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os +import time +import random + +import numpy as np + +import torch +import torch.nn.functional as F +from utilities.plot import plot_batch + +from train_scripts.trainer_base import TrainerBase + +from utilities.utilities import Gram + +class Trainer(TrainerBase): + + def __init__(self, + config, + reporter): + super(Trainer, self).__init__(config, reporter) + import inspect + print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe())) + + self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + # TODO modify this function to build your models + def init_framework(self): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + + model_config = self.config["model_configs"] + + if self.config["phase"] == "train": + gscript_name = "components." + model_config["g_model"]["script"] + dscript_name = "components." + model_config["d_model"]["script"] + + elif self.config["phase"] == "finetune": + gscript_name = self.config["com_base"] + model_config["g_model"]["script"] + dscript_name = self.config["com_base"] + model_config["d_model"]["script"] + + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + self.gen = gen_class(**model_config["g_model"]["module_params"]) + + # print and recorde model structure + self.reporter.writeInfo("Generator structure:") + self.reporter.writeModel(self.gen.__str__()) + + class_name = model_config["d_model"]["class_name"] + package = __import__(dscript_name, fromlist=True) + dis_class = getattr(package, class_name) + self.dis = dis_class(**model_config["d_model"]["module_params"]) + self.dis.feature_network.requires_grad_(False) + + # print and recorde model structure + self.reporter.writeInfo("Discriminator structure:") + self.reporter.writeModel(self.dis.__str__()) + arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu")) + self.arcface = arcface1['model'].module + + # train in GPU + if self.config["cuda"] >=0: + self.gen = self.gen.cuda() + self.dis = self.dis.cuda() + self.arcface= self.arcface.cuda() + + self.arcface.eval() + self.arcface.requires_grad_(False) + + # if in finetune phase, load the pretrained checkpoint + if self.config["phase"] == "finetune": + model_path = os.path.join(self.config["project_checkpoints"], + "step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["generator_name"])) + self.gen.load_state_dict(torch.load(model_path)) + + model_path = os.path.join(self.config["project_checkpoints"], + "step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["discriminator_name"])) + self.dis.load_state_dict(torch.load(model_path)) + + print('loaded trained backbone model step {}...!'.format(self.config["project_checkpoints"])) + + # TODO modify this function to configurate the optimizer of your pipeline + def __setup_optimizers__(self): + g_train_opt = self.config['g_optim_config'] + d_train_opt = self.config['d_optim_config'] + + g_optim_params = [] + d_optim_params = [] + for k, v in self.gen.named_parameters(): + if v.requires_grad: + g_optim_params.append(v) + else: + self.reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + for k, v in self.dis.named_parameters(): + if v.requires_grad: + d_optim_params.append(v) + else: + self.reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + optim_type = self.config['optim_type'] + + if optim_type == 'Adam': + self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) + self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt) + else: + raise NotImplementedError( + f'optimizer {optim_type} is not supperted yet.') + # self.optimizers.append(self.optimizer_g) + if self.config["phase"] == "finetune": + opt_path = os.path.join(self.config["project_checkpoints"], + "step%d_optim_%s.pth"%(self.config["checkpoint_step"], + self.config["optimizer_names"]["generator_name"])) + self.g_optimizer.load_state_dict(torch.load(opt_path)) + + opt_path = os.path.join(self.config["project_checkpoints"], + "step%d_optim_%s.pth"%(self.config["checkpoint_step"], + self.config["optimizer_names"]["discriminator_name"])) + self.d_optimizer.load_state_dict(torch.load(opt_path)) + + print('loaded trained optimizer step {}...!'.format(self.config["project_checkpoints"])) + + + # TODO modify this function to evaluate your model + # Evaluate the checkpoint + def __evaluation__(self, + step = 0, + **kwargs + ): + src_image1 = kwargs["src1"] + src_image2 = kwargs["src2"] + batch_size = self.batch_size + self.gen.eval() + with torch.no_grad(): + imgs = [] + zero_img = (torch.zeros_like(src_image1[0,...])) + imgs.append(zero_img.cpu().numpy()) + save_img = ((src_image1.cpu())* self.img_std + self.img_mean).numpy() + for r in range(batch_size): + imgs.append(save_img[r,...]) + arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic') + id_vector_src1 = self.arcface(arcface_112) + id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1) + + for i in range(batch_size): + + imgs.append(save_img[i,...]) + image_infer = src_image1[i, ...].repeat(batch_size, 1, 1, 1) + img_fake = self.gen(image_infer, id_vector_src1).cpu() + + img_fake = img_fake * self.img_std + img_fake = img_fake + self.img_mean + img_fake = img_fake.numpy() + for j in range(batch_size): + imgs.append(img_fake[j,...]) + print("Save test data") + imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1) + plot_batch(imgs, os.path.join(self.sample_dir, 'step_'+str(step+1)+'.jpg')) + + + + + def train(self): + + ckpt_dir = self.config["project_checkpoints"] + log_freq = self.config["log_step"] + model_freq = self.config["model_save_step"] + sample_freq = self.config["sample_step"] + total_step = self.config["total_step"] + random_seed = self.config["dataset_params"]["random_seed"] + + self.batch_size = self.config["batch_size"] + self.sample_dir = self.config["project_samples"] + self.arcface_ckpt= self.config["arcface_ckpt"] + + + # prep_weights= self.config["layersWeight"] + id_w = self.config["id_weight"] + rec_w = self.config["reconstruct_weight"] + feat_w = self.config["feature_match_weight"] + + + + super().train() + + #===============build losses===================# + # TODO replace below lines to build your losses + MSE_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + cos_loss = torch.nn.CosineSimilarity() + + + start_time = time.time() + + # Caculate the epoch number + print("Total step = %d"%total_step) + random.seed(random_seed) + randindex = [i for i in range(self.batch_size)] + random.shuffle(randindex) + import datetime + for step in range(self.start, total_step): + self.gen.train() + self.dis.train() + for interval in range(2): + random.shuffle(randindex) + src_image1, src_image2 = self.train_loader.next() + + if step%2 == 0: + img_id = src_image2 + else: + img_id = src_image2[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = self.arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + if interval: + + img_fake = self.gen(src_image1, latent_id) + gen_logits,_ = self.dis(img_fake.detach(), None) + loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean() + + real_logits,_ = self.dis(src_image2,None) + loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean() + + loss_D = loss_Dgen + loss_Dreal + self.d_optimizer.zero_grad() + loss_D.backward() + self.d_optimizer.step() + else: + + # model.netD.requires_grad_(True) + img_fake = self.gen(src_image1, latent_id) + # G loss + gen_logits,feat = self.dis(img_fake, None) + + loss_Gmain = (-gen_logits).mean() + img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = self.arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean() + real_feat = self.dis.get_feature(src_image1) + feat_match_loss = l1_loss(Gram(feat["3"]), Gram(real_feat["3"])) + \ + l1_loss(Gram(feat["2"]), Gram(real_feat["2"])) + loss_G = loss_Gmain + loss_G_ID * id_w + \ + feat_match_loss * feat_w + if step%2 == 0: + #G_Rec + loss_G_Rec = l1_loss(img_fake, src_image1) + loss_G += loss_G_Rec * rec_w + + self.g_optimizer.zero_grad() + loss_G.backward() + self.g_optimizer.step() + + # Print out log info + if (step + 1) % log_freq == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + + epochinformation="[{}], Elapsed [{}], Step [{}/{}], \ + G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \ + D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + format(self.config["version"], elapsed, step, total_step, \ + loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \ + loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + print(epochinformation) + self.reporter.writeInfo(epochinformation) + + if self.config["logger"] == "tensorboard": + self.logger.add_scalar('G/G_loss', loss_G.item(), step) + self.logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step) + self.logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step) + self.logger.add_scalar('G/G_ID', loss_G_ID.item(), step) + self.logger.add_scalar('D/D_loss', loss_D.item(), step) + self.logger.add_scalar('D/D_fake', loss_Dgen.item(), step) + self.logger.add_scalar('D/D_real', loss_Dreal.item(), step) + elif self.config["logger"] == "wandb": + self.logger.log({"G_loss": loss_G.item()}, step = step) + self.logger.log({"G_Rec": loss_G_Rec.item()}, step = step) + self.logger.log({"G_feat_match": feat_match_loss.item()}, step = step) + self.logger.log({"G_ID": loss_G_ID.item()}, step = step) + self.logger.log({"D_loss": loss_D.item()}, step = step) + self.logger.log({"D_fake": loss_Dgen.item()}, step = step) + self.logger.log({"D_real": loss_Dreal.item()}, step = step) + + if (step + 1) % sample_freq == 0: + self.__evaluation__( + step = step, + **{ + "src1": src_image1, + "src2": src_image2 + }) + + + + #===============adjust learning rate============# + # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: + # print("Learning rate decay") + # for p in self.optimizer.param_groups: + # p['lr'] *= self.config["lr_decay"] + # print("Current learning rate is %f"%p['lr']) + + #===============save checkpoints================# + if (step+1) % model_freq==0: + + torch.save(self.gen.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + self.config["checkpoint_names"]["generator_name"]))) + torch.save(self.dis.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + self.config["checkpoint_names"]["discriminator_name"]))) + + torch.save(self.g_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + self.config["checkpoint_names"]["generator_name"]))) + + torch.save(self.d_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + self.config["checkpoint_names"]["discriminator_name"]))) + print("Save step %d model checkpoint!"%(step+1)) + torch.cuda.empty_cache() + + self.__evaluation__( + step = step, + **{ + "src1": src_image1, + "src2": src_image2 + }) \ No newline at end of file diff --git a/train_yamls/train_512FM.yaml b/train_yamls/train_512FM.yaml index 59dc115..d7fe3b4 100644 --- a/train_yamls/train_512FM.yaml +++ b/train_yamls/train_512FM.yaml @@ -50,7 +50,7 @@ d_optim_config: eps: !!float 1e-8 id_weight: 20.0 -reconstruct_weight: 1.0 +reconstruct_weight: 10.0 feature_match_weight: 10.0 # Log diff --git a/train_yamls/train_512FM_Modulation.yaml b/train_yamls/train_512FM_Modulation.yaml new file mode 100644 index 0000000..7d119ad --- /dev/null +++ b/train_yamls/train_512FM_Modulation.yaml @@ -0,0 +1,63 @@ +# Related scripts +train_script_name: FM + +# models' scripts +model_configs: + g_model: + script: Generator_reduce + class_name: Generator + module_params: + g_conv_dim: 512 + g_kernel_size: 3 + res_num: 9 + + d_model: + script: projected_discriminator + class_name: ProjectedDiscriminator + module_params: + diffaug: False + interp224: False + backbone_kwargs: {} + +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 12 + +# Dataset +dataloader: VGGFace2HQ +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 8 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +id_weight: 20.0 +reconstruct_weight: 10.0 +feature_match_weight: 10.0 + +# Log +log_step: 300 +model_save_step: 10000 +sample_step: 1000 +total_step: 1000000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_GramFM.yaml b/train_yamls/train_GramFM.yaml new file mode 100644 index 0000000..2962cc6 --- /dev/null +++ b/train_yamls/train_GramFM.yaml @@ -0,0 +1,63 @@ +# Related scripts +train_script_name: GramFM + +# models' scripts +model_configs: + g_model: + script: Generator + class_name: Generator + module_params: + g_conv_dim: 512 + g_kernel_size: 3 + res_num: 9 + + d_model: + script: projected_discriminator + class_name: ProjectedDiscriminator + module_params: + diffaug: False + interp224: False + backbone_kwargs: {} + +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 12 + +# Dataset +dataloader: VGGFace2HQ +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 8 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +id_weight: 20.0 +reconstruct_weight: 10.0 +feature_match_weight: 100.0 + +# Log +log_step: 300 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file