update multi-GPU
@@ -5,7 +5,7 @@
 # Created Date: Sunday January 9th 2022
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Tuesday, 8th February 2022 2:29:34 pm
+# Last Modified: Tuesday, 8th February 2022 10:48:58 pm
 # Modified By: Chen Xuanhong
 # Copyright (c) 2022 Shanghai Jiao Tong University
 #############################################################
@@ -13,6 +13,7 @@
 import os
 import time
 import random
 import shutil
+import tempfile
 
 import numpy as np
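tempfile is a common import in multi-GPU code for file-based rendezvous: torch.distributed can initialize its process group through a shared temporary file instead of a TCP address. The hunk that actually uses the new import is not shown in this diff, so the sketch below is only an illustration of that pattern, not the repo's code:

    import os
    import tempfile

    import torch.distributed as dist
    import torch.multiprocessing as mp

    def _worker(rank, world_size, init_file):
        # Every process points at the same temp file to rendezvous.
        dist.init_process_group(
            backend="gloo",  # use "nccl" for GPU training
            init_method="file://" + init_file,
            rank=rank,
            world_size=world_size,
        )
        print("rank %d/%d initialized" % (rank, world_size))
        dist.destroy_process_group()

    if __name__ == "__main__":
        world_size = 2
        # delete=False so the path outlives this handle and can be shared
        init_file = tempfile.NamedTemporaryFile(delete=False).name
        mp.spawn(_worker, args=(world_size, init_file), nprocs=world_size)
        os.unlink(init_file)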
@@ -63,7 +64,13 @@ def init_framework(config, reporter, device, rank):
 
     if config["phase"] == "train":
         gscript_name = "components." + model_config["g_model"]["script"]
+        file1 = os.path.join("components", model_config["g_model"]["script"] + ".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"] + ".py")
+        shutil.copyfile(file1, tgtfile1)
         dscript_name = "components." + model_config["d_model"]["script"]
+        file1 = os.path.join("components", model_config["d_model"]["script"] + ".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"] + ".py")
+        shutil.copyfile(file1, tgtfile1)
 
     elif config["phase"] == "finetune":
         gscript_name = config["com_base"] + model_config["g_model"]["script"]
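The six added lines snapshot the generator and discriminator definition files into the run's project_scripts directory, so every training run keeps a copy of the exact model code it was launched with. The same idea as a small helper (the function name is illustrative, not from the repo):

    import os
    import shutil

    def snapshot_scripts(model_config, project_scripts):
        # Copy each component's source file next to the run's outputs,
        # so the run stays reproducible even if components/ changes later.
        os.makedirs(project_scripts, exist_ok=True)
        for key in ("g_model", "d_model"):
            script = model_config[key]["script"]
            src = os.path.join("components", script + ".py")
            dst = os.path.join(project_scripts, script + ".py")
            shutil.copyfile(src, dst)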
@@ -92,6 +99,20 @@ def init_framework(config, reporter, device, rank):
 
     # train in GPU
 
+    # if in finetune phase, load the pretrained checkpoint
+    if config["phase"] == "finetune":
+        model_path = os.path.join(config["project_checkpoints"],
+                        "step%d_%s.pth" % (config["ckpt"],
+                        config["checkpoint_names"]["generator_name"]))
+        gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        model_path = os.path.join(config["project_checkpoints"],
+                        "step%d_%s.pth" % (config["ckpt"],
+                        config["checkpoint_names"]["discriminator_name"]))
+        dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        print('loaded trained backbone model step {}...!'.format(config["ckpt"]))
+
     gen = gen.to(device)
     dis = dis.to(device)
     arcface = arcface.to(device)
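Loading the checkpoint onto the CPU before the .to(device) calls is the important multi-GPU change here: without map_location, every rank would first deserialize the tensors onto the GPU they were saved from (usually cuda:0), spiking memory on that one card. Note that map_location is a keyword argument of torch.load, not of load_state_dict. A minimal sketch of the pattern:

    import torch

    def load_ckpt(model, path, device):
        # Deserialize to CPU first; each rank then moves the weights
        # to its own device instead of crowding cuda:0.
        state = torch.load(path, map_location=torch.device("cpu"))
        model.load_state_dict(state)
        return model.to(device)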
@@ -99,19 +120,7 @@ def init_framework(config, reporter, device, rank):
     arcface.eval()
 
 
-    # if in finetune phase, load the pretrained checkpoint
-    if config["phase"] == "finetune":
-        model_path = os.path.join(config["project_checkpoints"],
-                        "step%d_%s.pth" % (config["checkpoint_step"],
-                        config["checkpoint_names"]["generator_name"]))
-        gen.load_state_dict(torch.load(model_path))
-
-        model_path = os.path.join(config["project_checkpoints"],
-                        "step%d_%s.pth" % (config["checkpoint_step"],
-                        config["checkpoint_names"]["discriminator_name"]))
-        dis.load_state_dict(torch.load(model_path))
-
-        print('loaded trained backbone model step {}...!'.format(config["project_checkpoints"]))
 
     return gen, dis, arcface
 
 
 # TODO modify this function to configurate the optimizer of your pipeline
@@ -149,12 +158,12 @@ def setup_optimizers(config, reporter, gen, dis, rank):
     # self.optimizers.append(self.optimizer_g)
     if config["phase"] == "finetune":
         opt_path = os.path.join(config["project_checkpoints"],
-                        "step%d_optim_%s.pth" % (config["checkpoint_step"],
+                        "step%d_optim_%s.pth" % (config["ckpt"],
                         config["optimizer_names"]["generator_name"]))
         g_optimizer.load_state_dict(torch.load(opt_path))
 
         opt_path = os.path.join(config["project_checkpoints"],
-                        "step%d_optim_%s.pth" % (config["checkpoint_step"],
+                        "step%d_optim_%s.pth" % (config["ckpt"],
                         config["optimizer_names"]["discriminator_name"]))
         d_optimizer.load_state_dict(torch.load(opt_path))
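Resuming the optimizers mirrors resuming the models, with one subtlety: Optimizer.load_state_dict casts the loaded state (e.g. Adam's moment buffers) to the device and dtype of the current parameters, so the model should already be on its target device when the optimizer state is restored. A hedged sketch of that order (names are illustrative):

    import os
    import torch

    def resume(model, optimizer, ckpt_dir, step, name, device):
        model_path = os.path.join(ckpt_dir, "step%d_%s.pth" % (step, name))
        opt_path = os.path.join(ckpt_dir, "step%d_optim_%s.pth" % (step, name))

        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model.to(device)
        # load_state_dict remaps the optimizer state onto the parameters'
        # current device, so this must come after .to(device).
        optimizer.load_state_dict(torch.load(opt_path, map_location="cpu"))
        return model, optimizer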
@@ -270,19 +279,19 @@ def train_loop(
     # TODO replace below lines to build your losses
     # MSE_loss = torch.nn.MSELoss()
     l1_loss = torch.nn.L1Loss()
-    cos_loss = cosin_metric
+    cos_loss = torch.nn.CosineSimilarity()
 
     g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank)
 
     # Initialize logs.
+    if rank == 0:
+        print('Initializing logs...')
     if rank == 0:
         #==============build tensorboard=================#
         if config["logger"] == "tensorboard":
             import torch.utils.tensorboard as tensorboard
-            tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"])
-            logger = tensorboard_writer
+            tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"])
+            logger = tensorboard_writer
 
         elif config["logger"] == "wandb":
             import wandb
             wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
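Two things are worth noting in this hunk. First, torch.nn.CosineSimilarity() returns a similarity in [-1, 1] per pair rather than a loss, so callers typically form an identity loss as (1 - cos_loss(a, b)).mean(). Second, guarding logger construction with `if rank == 0:` is the standard DDP pattern: only one process writes summaries, so log files are not clobbered and metrics are not duplicated once per GPU. A minimal sketch of that guard, assuming non-zero ranks simply skip logging:

    def build_logger(config, rank):
        # Only rank 0 gets a real logger; other ranks get None and
        # must guard their logging calls accordingly.
        if rank != 0:
            return None
        if config["logger"] == "tensorboard":
            from torch.utils.tensorboard import SummaryWriter
            return SummaryWriter(config["project_summary"])
        if config["logger"] == "wandb":
            import wandb
            wandb.init(project="Simswap_HQ", entity="xhchen")
            return wandb
        raise ValueError("unknown logger: %s" % config["logger"])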
@@ -297,11 +306,10 @@ def train_loop(
 
     random.seed(random_seed)
     randindex = [i for i in range(batch_gpu)]
     random.shuffle(randindex)
-
     # set the start point for training loop
     if config["phase"] == "finetune":
-        start = config["checkpoint_step"]
+        start = config["ckpt"]
     else:
         start = 0
     if rank == 0:
@@ -315,11 +323,12 @@ def train_loop(
 
     from utilities.logo_class import logo_class
     logo_class.print_start_training()
 
+    dis.feature_network.requires_grad_(False)
 
     for step in range(start, total_step):
         gen.train()
         dis.train()
+        dis.feature_network.eval()
         for interval in range(2):
             random.shuffle(randindex)
             src_image1, src_image2 = dataloader.next()
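The two dis.feature_network lines implement a frozen feature extractor inside the discriminator: requires_grad_(False) blocks weight updates once, before the loop, while eval() has to be re-asserted every step because the preceding dis.train() flips all submodules (including any BatchNorm/Dropout in the feature network) back to training behavior. The pattern in isolation, with an illustrative module:

    import torch
    import torch.nn as nn

    class Discriminator(nn.Module):
        def __init__(self):
            super().__init__()
            self.feature_network = nn.Sequential(
                nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
            self.head = nn.Conv2d(8, 1, 1)

        def forward(self, x):
            return self.head(self.feature_network(x))

    dis = Discriminator()
    dis.feature_network.requires_grad_(False)  # freeze weights once

    for step in range(2):
        dis.train()                 # sets *all* submodules to train mode
        dis.feature_network.eval()  # keep BN running stats frozen each step
        out = dis(torch.randn(4, 3, 32, 32))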