Update multi-GPU training

chenxuanhong
2022-02-09 12:46:34 +08:00
parent 6f92dbc896
commit 5271a9f3c2
4 changed files with 49 additions and 100 deletions
+32 -23
@@ -5,7 +5,7 @@
 # Created Date: Sunday January 9th 2022
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Tuesday, 8th February 2022 2:29:34 pm
+# Last Modified: Tuesday, 8th February 2022 10:48:58 pm
 # Modified By: Chen Xuanhong
 # Copyright (c) 2022 Shanghai Jiao Tong University
 #############################################################
@@ -13,6 +13,7 @@
 import os
 import time
 import random
+import shutil
 import tempfile
 import numpy as np
@@ -63,7 +64,13 @@ def init_framework(config, reporter, device, rank):
     if config["phase"] == "train":
         gscript_name = "components." + model_config["g_model"]["script"]
+        file1 = os.path.join("components", model_config["g_model"]["script"]+".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py")
+        shutil.copyfile(file1, tgtfile1)
         dscript_name = "components." + model_config["d_model"]["script"]
+        file1 = os.path.join("components", model_config["d_model"]["script"]+".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py")
+        shutil.copyfile(file1, tgtfile1)
     elif config["phase"] == "finetune":
         gscript_name = config["com_base"] + model_config["g_model"]["script"]
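The added block above snapshots the generator and discriminator component scripts into the project folder so the run's exact model definitions can be recovered later. A minimal sketch of the same pattern with the duplication factored into a helper (the function name and directory defaults are illustrative, not from this repo):

    import os
    import shutil

    def backup_scripts(script_names, src_dir="components", dst_dir="project_scripts"):
        # Copy each component script next to the experiment outputs so the
        # exact model code used for this run is preserved for reproducibility.
        os.makedirs(dst_dir, exist_ok=True)
        for name in script_names:
            shutil.copyfile(os.path.join(src_dir, name + ".py"),
                            os.path.join(dst_dir, name + ".py"))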
@@ -92,6 +99,20 @@ def init_framework(config, reporter, device, rank):
     # train in GPU
+    # if in finetune phase, load the pretrained checkpoint
+    if config["phase"] == "finetune":
+        model_path = os.path.join(config["project_checkpoints"],
+                        "step%d_%s.pth"%(config["ckpt"],
+                        config["checkpoint_names"]["generator_name"]))
+        gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+        model_path = os.path.join(config["project_checkpoints"],
+                        "step%d_%s.pth"%(config["ckpt"],
+                        config["checkpoint_names"]["discriminator_name"]))
+        dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+        print('loaded trained backbone model step {}...!'.format(config["ckpt"]))
     gen = gen.to(device)
     dis = dis.to(device)
     arcface = arcface.to(device)
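Note that map_location is an argument of torch.load, not of load_state_dict (the added lines are corrected accordingly above). Deserializing to CPU first matters in a multi-GPU job: a checkpoint saved from GPU 0 would otherwise be materialized on GPU 0 by every rank before the .to(device) calls. A minimal sketch of the pattern (the helper name and signature are illustrative):

    import torch

    def load_for_rank(model, ckpt_path, device):
        # Deserialize on CPU so no rank touches GPU 0's memory,
        # then move the weights to this rank's own device.
        state = torch.load(ckpt_path, map_location=torch.device("cpu"))
        model.load_state_dict(state)
        return model.to(device)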
@@ -99,19 +120,7 @@ def init_framework(config, reporter, device, rank):
     arcface.eval()
-    # if in finetune phase, load the pretrained checkpoint
-    if config["phase"] == "finetune":
-        model_path = os.path.join(config["project_checkpoints"],
-                        "step%d_%s.pth"%(config["checkpoint_step"],
-                        config["checkpoint_names"]["generator_name"]))
-        gen.load_state_dict(torch.load(model_path))
-        model_path = os.path.join(config["project_checkpoints"],
-                        "step%d_%s.pth"%(config["checkpoint_step"],
-                        config["checkpoint_names"]["discriminator_name"]))
-        dis.load_state_dict(torch.load(model_path))
-        print('loaded trained backbone model step {}...!'.format(config["project_checkpoints"]))
     return gen, dis, arcface

 # TODO modify this function to configure the optimizer of your pipeline
@@ -149,12 +158,12 @@ def setup_optimizers(config, reporter, gen, dis, rank):
     # self.optimizers.append(self.optimizer_g)
     if config["phase"] == "finetune":
         opt_path = os.path.join(config["project_checkpoints"],
-                        "step%d_optim_%s.pth"%(config["checkpoint_step"],
+                        "step%d_optim_%s.pth"%(config["ckpt"],
                         config["optimizer_names"]["generator_name"]))
         g_optimizer.load_state_dict(torch.load(opt_path))
         opt_path = os.path.join(config["project_checkpoints"],
-                        "step%d_optim_%s.pth"%(config["checkpoint_step"],
+                        "step%d_optim_%s.pth"%(config["ckpt"],
                         config["optimizer_names"]["discriminator_name"]))
         d_optimizer.load_state_dict(torch.load(opt_path))
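The optimizer states restored here are the counterpart of saving them at checkpoint time; resuming without them would reset Adam's moment estimates and momentum buffers. A hedged sketch of the matching save side (only the step%d_optim_%s.pth naming is taken from the code above; the helper itself is an assumption):

    import os
    import torch

    def save_optim(optimizer, checkpoints_dir, step, name):
        # Persist momentum/moment buffers so finetuning resumes exactly
        # where training stopped instead of warming up from scratch.
        path = os.path.join(checkpoints_dir, "step%d_optim_%s.pth" % (step, name))
        torch.save(optimizer.state_dict(), path)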
@@ -270,19 +279,19 @@ def train_loop(
     # TODO replace below lines to build your losses
     # MSE_loss = torch.nn.MSELoss()
     l1_loss = torch.nn.L1Loss()
-    cos_loss = cosin_metric
+    cos_loss = torch.nn.CosineSimilarity()
     g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank)
     # Initialize logs.
     if rank == 0:
         print('Initializing logs...')
     if rank == 0:
         #==============build tensorboard=================#
         if config["logger"] == "tensorboard":
             import torch.utils.tensorboard as tensorboard
             tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"])
             logger = tensorboard_writer
         elif config["logger"] == "wandb":
             import wandb
             wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
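Logger construction above is guarded by rank == 0 so that only one process writes summaries in a multi-GPU run. A minimal sketch of that pattern, reduced to the TensorBoard branch (the helper name is illustrative):

    import torch.utils.tensorboard as tensorboard

    def build_logger(summary_dir, rank):
        # Only rank 0 creates a writer; the other ranks get None, so event
        # files are not duplicated once per GPU process.
        if rank == 0:
            return tensorboard.SummaryWriter(summary_dir)
        return None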
@@ -297,11 +306,10 @@ def train_loop(
     random.seed(random_seed)
     randindex = [i for i in range(batch_gpu)]
     random.shuffle(randindex)
-    # set the start point for training loop
     if config["phase"] == "finetune":
-        start = config["checkpoint_step"]
+        start = config["ckpt"]
     else:
         start = 0
     if rank == 0:
@@ -315,11 +323,12 @@
     from utilities.logo_class import logo_class
     logo_class.print_start_training()
+    dis.feature_network.requires_grad_(False)
     for step in range(start, total_step):
         gen.train()
         dis.train()
+        dis.feature_network.eval()
         for interval in range(2):
             random.shuffle(randindex)
             src_image1, src_image2 = dataloader.next()
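The two added dis.feature_network lines follow the usual frozen-submodule pattern: requires_grad_(False) stops gradient tracking for its parameters once before the loop, while eval() must be reapplied every step because dis.train() recursively puts the feature network back into training mode, which would change BatchNorm/Dropout behavior. A minimal sketch (freeze_submodule is an illustrative name):

    import torch

    def freeze_submodule(module: torch.nn.Module):
        # Disable gradient tracking once, before the training loop starts.
        module.requires_grad_(False)
        # Put normalization/dropout layers into inference mode; a later
        # .train() on the parent module undoes this, so the training loop
        # calls eval() on the frozen part again at each step.
        module.eval()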