This commit is contained in:
XHChen0528
2022-03-20 16:57:23 +08:00
parent 9efbe03d3f
commit be1f9e6f71
6 changed files with 76 additions and 43 deletions
+4 -4
View File
@@ -54,7 +54,7 @@
"train_scripts\\trainer_GramFM.py": 1647657822.9582758, "train_scripts\\trainer_GramFM.py": 1647657822.9582758,
"utilities\\ImagenetNorm.py": 1647657822.9642777, "utilities\\ImagenetNorm.py": 1647657822.9642777,
"utilities\\reverse2original.py": 1647657822.9662776, "utilities\\reverse2original.py": 1647657822.9662776,
"train_yamls\\train_cycleloss.yaml": 1647700399.7641768, "train_yamls\\train_cycleloss.yaml": 1647704989.2310138,
"train_yamls\\train_GramFM.yaml": 1647657822.9622767, "train_yamls\\train_GramFM.yaml": 1647657822.9622767,
"train_yamls\\train_512FM_Modulation.yaml": 1647657822.961277, "train_yamls\\train_512FM_Modulation.yaml": 1647657822.961277,
"face_crop.py": 1647657822.9422722, "face_crop.py": 1647657822.9422722,
@@ -181,10 +181,10 @@
"arcface_torch\\utils\\utils_config.py": 1647657822.927269, "arcface_torch\\utils\\utils_config.py": 1647657822.927269,
"arcface_torch\\utils\\utils_logging.py": 1647657822.927269, "arcface_torch\\utils\\utils_logging.py": 1647657822.927269,
"arcface_torch\\utils\\__init__.py": 1647657822.9262686, "arcface_torch\\utils\\__init__.py": 1647657822.9262686,
"components\\LSTU.py": 1647697688.593807, "components\\LSTU.py": 1647702612.240765,
"test_scripts\\tester_ID_Pose.py": 1647657822.946273, "test_scripts\\tester_ID_Pose.py": 1647657822.946273,
"train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1647657822.9592762, "train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1647657822.9592762,
"train_scripts\\trainer_multi_gpu_CUT.py": 1647676964.475, "train_scripts\\trainer_multi_gpu_CUT.py": 1647676964.475,
"train_scripts\\trainer_multi_gpu_cycle.py": 1647699496.9083836, "train_scripts\\trainer_multi_gpu_cycle.py": 1647705628.7020626,
"components\\Generator_LSTU_config.py": 1647697793.0348723 "components\\Generator_LSTU_config.py": 1647704615.1532204
} }
+44 -22
View File
@@ -13,7 +13,7 @@
import torch import torch
from torch import nn from torch import nn
from LSTU import LSTU from components.LSTU import LSTU
# from components.DeConv_Invo import DeConv # from components.DeConv_Invo import DeConv
class InstanceNorm(nn.Module): class InstanceNorm(nn.Module):
@@ -48,7 +48,12 @@ class ApplyStyle(nn.Module):
return x return x
class ResnetBlock_Adain(nn.Module): class ResnetBlock_Adain(nn.Module):
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)): def __init__(self,
dim,
latent_size,
padding_type,
activation=nn.ReLU(True),
res_mode="depthwise"):
super(ResnetBlock_Adain, self).__init__() super(ResnetBlock_Adain, self).__init__()
p = 0 p = 0
@@ -61,7 +66,16 @@ class ResnetBlock_Adain(nn.Module):
p = 1 p = 1
else: else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type) raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()] if res_mode.lower() == "conv":
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
elif res_mode.lower() == "depthwise":
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,groups=dim, bias=False),
nn.Conv2d(dim, dim, kernel_size=1),
InstanceNorm()]
elif res_mode.lower() == "depthwise_eca":
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,groups=dim, bias=False),
nn.Conv2d(dim, dim, kernel_size=1),
InstanceNorm()]
self.conv1 = nn.Sequential(*conv1) self.conv1 = nn.Sequential(*conv1)
self.style1 = ApplyStyle(latent_size, dim) self.style1 = ApplyStyle(latent_size, dim)
self.act1 = activation self.act1 = activation
@@ -76,7 +90,16 @@ class ResnetBlock_Adain(nn.Module):
p = 1 p = 1
else: else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type) raise NotImplementedError('padding [%s] is not implemented' % padding_type)
if res_mode.lower() == "conv":
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()] conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
elif res_mode.lower() == "depthwise":
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,groups=dim, bias=False),
nn.Conv2d(dim, dim, kernel_size=1),
InstanceNorm()]
elif res_mode.lower() == "depthwise_eca":
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p,groups=dim, bias=False),
nn.Conv2d(dim, dim, kernel_size=1),
InstanceNorm()]
self.conv2 = nn.Sequential(*conv2) self.conv2 = nn.Sequential(*conv2)
self.style2 = ApplyStyle(latent_size, dim) self.style2 = ApplyStyle(latent_size, dim)
@@ -104,7 +127,7 @@ class Generator(nn.Module):
up_mode = kwargs["up_mode"] up_mode = kwargs["up_mode"]
aggregator = kwargs["aggregator"] aggregator = kwargs["aggregator"]
res_mode = aggregator res_mode = kwargs["res_mode"]
padding_size= int((k_size -1)/2) padding_size= int((k_size -1)/2)
padding_type= 'reflect' padding_type= 'reflect'
@@ -122,28 +145,24 @@ class Generator(nn.Module):
# nn.BatchNorm2d(64), activation) # nn.BatchNorm2d(64), activation)
### downsample ### downsample
self.down1 = nn.Sequential( self.down1 = nn.Sequential(
nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1, groups=in_channel), nn.Conv2d(in_channel, in_channel*2, stride=2, kernel_size=3, padding=1, bias=False),
nn.Conv2d(in_channel, in_channel*2, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channel*2), nn.BatchNorm2d(in_channel*2),
activation) activation)
self.down2 = nn.Sequential( self.down2 = nn.Sequential(
nn.Conv2d(in_channel*2, in_channel*2, kernel_size=3, padding=1, groups=in_channel*2), nn.Conv2d(in_channel*2, in_channel*4, stride=2, kernel_size=3, padding=1, bias=False),
nn.Conv2d(in_channel*2, in_channel*4, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channel*4), nn.BatchNorm2d(in_channel*4),
activation) activation)
self.lstu = LSTU(in_channel*4,in_channel*4,in_channel*8,4) # self.lstu = LSTU(in_channel*4,in_channel*4,in_channel*8,4)
self.down3 = nn.Sequential( self.down3 = nn.Sequential(
nn.Conv2d(in_channel*4, in_channel*4, kernel_size=3, padding=1, groups=in_channel*4), nn.Conv2d(in_channel*4, in_channel*8, stride=2, kernel_size=3, padding=1, bias=False),
nn.Conv2d(in_channel*4, in_channel*8, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channel*8), nn.BatchNorm2d(in_channel*8),
activation) activation)
self.down4 = nn.Sequential( self.down4 = nn.Sequential(
nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, padding=1, groups=in_channel*8), nn.Conv2d(in_channel*8, in_channel*8, stride=2, kernel_size=3, padding=1, bias=False),
nn.Conv2d(in_channel*8, in_channel*8, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channel*8), nn.BatchNorm2d(in_channel*8),
activation) activation)
@@ -158,25 +177,29 @@ class Generator(nn.Module):
self.BottleNeck = nn.Sequential(*BN) self.BottleNeck = nn.Sequential(*BN)
self.up4 = nn.Sequential( self.up4 = nn.Sequential(
DeConv(in_channel*8,in_channel*8,3,up_mode=up_mode), nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(in_channel*8), nn.BatchNorm2d(in_channel*8),
activation activation
) )
self.up3 = nn.Sequential( self.up3 = nn.Sequential(
DeConv(in_channel*8,in_channel*4,3,up_mode=up_mode), nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(in_channel*8, in_channel*4, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(in_channel*4), nn.BatchNorm2d(in_channel*4),
activation activation
) )
self.up2 = nn.Sequential( self.up2 = nn.Sequential(
DeConv(in_channel*4,in_channel*2,3,up_mode=up_mode), nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(in_channel*4, in_channel*2, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(in_channel*2), nn.BatchNorm2d(in_channel*2),
activation activation
) )
self.up1 = nn.Sequential( self.up1 = nn.Sequential(
DeConv(in_channel*2,in_channel,3,up_mode=up_mode), nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(in_channel), nn.BatchNorm2d(in_channel),
activation activation
) )
@@ -201,17 +224,16 @@ class Generator(nn.Module):
def forward(self, img, id): def forward(self, img, id):
res = self.first_layer(img) res = self.first_layer(img)
res = self.down1(res) res = self.down1(res)
res1 = self.down2(res) res = self.down2(res)
res = self.down3(res1) res = self.down3(res)
res = self.down4(res) res = self.down4(res)
for i in range(len(self.BottleNeck)): for i in range(len(self.BottleNeck)):
res = self.BottleNeck[i](res, id) res = self.BottleNeck[i](res, id)
# skip = self.lstu(res1, res)
res = self.up4(res) res = self.up4(res)
res = self.up3(res) res = self.up3(res)
skip = self.lstu(res1) res = self.up2(res) # + skip
res = self.up2(res + skip)
res = self.up1(res) res = self.up1(res)
res = self.last_layer(res) res = self.last_layer(res)
+1 -1
View File
@@ -24,7 +24,7 @@ class LSTU(nn.Module):
): ):
super().__init__() super().__init__()
sig = nn.Sigmoid() sig = nn.Sigmoid()
self.relu = nn.Relu() self.relu = nn.ReLU(True)
self.up_sample = nn.Sequential(nn.ConvTranspose2d(latent_channel, out_channel, kernel_size=4, stride=scale, padding=0, bias=False), self.up_sample = nn.Sequential(nn.ConvTranspose2d(latent_channel, out_channel, kernel_size=4, stride=scale, padding=0, bias=False),
nn.BatchNorm2d(out_channel), sig) nn.BatchNorm2d(out_channel), sig)
+1 -1
View File
@@ -1,3 +1,3 @@
nohup python train_multigpu.py > depthwise_config0.log 2>&1 & nohup python train_multigpu.py > cycle_lstu1.log 2>&1 &
+17 -11
View File
@@ -100,9 +100,11 @@ def init_framework(config, reporter, device, rank):
# arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu")) # arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
# arcface = arcface1['model'].module # arcface = arcface1['model'].module
arcface = iresnet100(pretrained=False, fp16=False) # arcface = iresnet100(pretrained=False, fp16=False)
arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu')) # arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
arcface.eval() # arcface.eval()
arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
arcface = arcface1['model'].module
# train in GPU # train in GPU
@@ -402,7 +404,7 @@ def train_loop(
latent_fake = arcface(img_fake_down) latent_fake = arcface(img_fake_down)
latent_fake = F.normalize(latent_fake, p=2, dim=1) latent_fake = F.normalize(latent_fake, p=2, dim=1)
loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean() loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
loss_G = loss_Gmain + loss_G_ID * id_w
if step%2 == 0: if step%2 == 0:
#G_Rec #G_Rec
rec_fm = l1_loss(feat["3"],real_feat["3"]) rec_fm = l1_loss(feat["3"],real_feat["3"])
@@ -418,7 +420,7 @@ def train_loop(
cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"])
loss_G += cycle_loss * cycle_w + cycle_fm * cycle_fm_w loss_G += cycle_loss * cycle_w + cycle_fm * cycle_fm_w
loss_G = loss_Gmain + loss_G_ID * id_w
g_optimizer.zero_grad(set_to_none=True) g_optimizer.zero_grad(set_to_none=True)
loss_G.backward() loss_G.backward()
with torch.autograd.profiler.record_function('generator_opt'): with torch.autograd.profiler.record_function('generator_opt'):
@@ -447,18 +449,20 @@ def train_loop(
# torch.distributed.all_reduce(ID_Total) # torch.distributed.all_reduce(ID_Total)
epochinformation="[{}], Elapsed [{}], Step [{}/{}], \ epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \ G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \
D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
format(version, elapsed, step, total_step, \ format(version, elapsed, step, total_step, \
loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \ loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \
loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
print(epochinformation) print(epochinformation)
reporter.writeInfo(epochinformation) reporter.writeInfo(epochinformation)
if config["logger"] == "tensorboard": if config["logger"] == "tensorboard":
logger.add_scalar('G/G_loss', loss_G.item(), step) logger.add_scalar('G/G_loss', loss_G.item(), step)
logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step) logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step) logger.add_scalar('G/cycle_loss', cycle_loss.item(), step)
logger.add_scalar('G/cycle_fm', cycle_fm.item(), step)
logger.add_scalar('G/rec_fm', rec_fm.item(), step)
logger.add_scalar('G/G_ID', loss_G_ID.item(), step) logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
logger.add_scalar('D/D_loss', loss_D.item(), step) logger.add_scalar('D/D_loss', loss_D.item(), step)
logger.add_scalar('D/D_fake', loss_Dgen.item(), step) logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
@@ -466,7 +470,9 @@ def train_loop(
elif config["logger"] == "wandb": elif config["logger"] == "wandb":
logger.log({"G_Loss": loss_G.item()}, step = step) logger.log({"G_Loss": loss_G.item()}, step = step)
logger.log({"G_Rec": loss_G_Rec.item()}, step = step) logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
logger.log({"G_feat_match": feat_match_loss.item()}, step = step) logger.log({"cycle_loss": cycle_loss.item()}, step = step)
logger.log({"cycle_fm": cycle_fm.item()}, step = step)
logger.log({"rec_fm": rec_fm.item()}, step = step)
logger.log({"G_ID": loss_G_ID.item()}, step = step) logger.log({"G_ID": loss_G_ID.item()}, step = step)
logger.log({"D_loss": loss_D.item()}, step = step) logger.log({"D_loss": loss_D.item()}, step = step)
logger.log({"D_fake": loss_Dgen.item()}, step = step) logger.log({"D_fake": loss_Dgen.item()}, step = step)
+8 -3
View File
@@ -7,9 +7,13 @@ model_configs:
script: Generator_LSTU_config script: Generator_LSTU_config
class_name: Generator class_name: Generator
module_params: module_params:
g_conv_dim: 512 id_dim: 512
g_kernel_size: 3 g_kernel_size: 3
in_channel: 64
res_num: 9 res_num: 9
up_mode: bilinear
aggregator: "conv"
res_mode: "conv"
d_model: d_model:
script: projected_discriminator script: projected_discriminator
@@ -19,17 +23,18 @@ model_configs:
interp224: False interp224: False
backbone_kwargs: {} backbone_kwargs: {}
# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar
# Training information # Training information
batch_size: 24 batch_size: 20
# Dataset # Dataset
dataloader: VGGFace2HQ_multigpu dataloader: VGGFace2HQ_multigpu
dataset_name: vggface2_hq dataset_name: vggface2_hq
dataset_params: dataset_params:
random_seed: 1234 random_seed: 1234
dataloader_workers: 8 dataloader_workers: 6
eval_dataloader: DIV2K_hdf5 eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval eval_dataset_name: DF2K_H5_Eval