From 5271a9f3c28d7a971be6ab8426bd7e92c23d77ed Mon Sep 17 00:00:00 2001 From: chenxuanhong Date: Wed, 9 Feb 2022 12:46:34 +0800 Subject: [PATCH] update multi gpu --- GUI/file_sync/filestate_machine0.json | 8 +-- data_tools/data_loader_VGGFace2HQ_multigpu.py | 66 +------------------ train_multigpu.py | 20 +++--- train_scripts/trainer_multi_gpu.py | 55 +++++++++------- 4 files changed, 49 insertions(+), 100 deletions(-) diff --git a/GUI/file_sync/filestate_machine0.json b/GUI/file_sync/filestate_machine0.json index bc46d99..80d4bd2 100644 --- a/GUI/file_sync/filestate_machine0.json +++ b/GUI/file_sync/filestate_machine0.json @@ -60,10 +60,10 @@ "face_crop.py": 1643789609.1834445, "face_crop_video.py": 1643815024.5516832, "similarity.py": 1643269705.1073737, - "train_multigpu.py": 1644296706.054128, + "train_multigpu.py": 1644331842.3490777, "components\\arcface_decoder.py": 1643396144.2575414, "components\\Generator_nobias.py": 1643179001.810856, - "data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1644299401.8480241, + "data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1644330414.9587426, "data_tools\\data_loader_VGGFace2HQ_Rec.py": 1643398754.86898, "test_scripts\\tester_arcface_Rec.py": 1643431261.9333818, "test_scripts\\tester_image.py": 1643428951.5532105, @@ -83,9 +83,9 @@ "torch_utils\\ops\\__init__.py": 1640773190.0, "train_scripts\\trainer_arcface_rec.py": 1643399647.0182135, "train_scripts\\trainer_multigpu_base.py": 1644131205.772292, - "train_scripts\\trainer_multi_gpu.py": 1644301774.3077753, + "train_scripts\\trainer_multi_gpu.py": 1644331738.7729652, "train_yamls\\train_arcface_rec.yaml": 1643398807.3434353, - "train_yamls\\train_multigpu.yaml": 1644301838.3615713, + "train_yamls\\train_multigpu.yaml": 1644331809.0680442, "wandb\\run-20220129_032741-340btp9k\\files\\conda-environment.yaml": 1643398065.409959, "wandb\\run-20220129_032741-340btp9k\\files\\config.yaml": 1643398069.2392955, "wandb\\run-20220129_032939-2nmaozxq\\files\\conda-environment.yaml": 
1643398182.647548, diff --git a/data_tools/data_loader_VGGFace2HQ_multigpu.py b/data_tools/data_loader_VGGFace2HQ_multigpu.py index b96e2fc..8e71fbf 100644 --- a/data_tools/data_loader_VGGFace2HQ_multigpu.py +++ b/data_tools/data_loader_VGGFace2HQ_multigpu.py @@ -5,7 +5,7 @@ # Created Date: Sunday February 6th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 8th February 2022 1:50:00 pm +# Last Modified: Tuesday, 8th February 2022 10:26:54 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -181,66 +181,4 @@ def GetLoader( dataset_roots, def denorm(x): out = (x + 1) / 2 - return out.clamp_(0, 1) - -if __name__ == "__main__": - from torchvision.utils import save_image - style_class = ["vangogh","picasso","samuel"] - categories_names = \ - ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor', - 'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field', - 'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse', - 'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus', - 'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet', - 'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse', - 'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway', - 'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor', - 'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior', - 'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge', - 'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe', - 'i/ice_shelf', 'i/iceberg', 
'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah', - 'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse', - 'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain', - 'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park', - 'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track', - 'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge', - 'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole', - 's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house', - 'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct', - 'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave', - 'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft', - 'b/building_facade', 'c/cemetery'] - - s_datapath = "D:\\F_Disk\\data_set\\Art_Data\\data_art_backup" - c_datapath = "D:\\Downloads\\data_large" - savepath = "D:\\PatchFace\\PleaseWork\\multi-style-gan\\StyleTransfer\\dataloader_test" - - imsize = 512 - s_datasetloader= getLoader(s_datapath,c_datapath, - style_class, categories_names, - crop_size=imsize, batch_size=16, num_workers=4) - wocao = iter(s_datasetloader) - for i in range(500): - print("new batch") - s_image,c_image,label = next(wocao) - print(label) - # print(label) - # saved_image1 = torch.cat([denorm(image.data),denorm(hahh.data)],3) - # save_image(denorm(image), "%s\\%d-label-%d.jpg"%(savepath,i), nrow=1, padding=1) - pass - # import cv2 - # import os - # for dir_item in categories_names: - # join_path = Path(contentdatapath,dir_item) - # if join_path.exists(): - # print("processing %s"%dir_item,end='\r') - # images = join_path.glob('*.%s'%("jpg")) - # for item in images: - # temp_path = str(item) - # # temp = cv2.imread(temp_path) - 
# temp = Image.open(temp_path) - # if temp.layers<3: - # print("remove broken image...") - # print("image name:%s"%temp_path) - # del temp - # os.remove(item) \ No newline at end of file + return out.clamp_(0, 1) \ No newline at end of file diff --git a/train_multigpu.py b/train_multigpu.py index 5b64ced..d5ba081 100644 --- a/train_multigpu.py +++ b/train_multigpu.py @@ -5,7 +5,7 @@ # Created Date: Tuesday April 28th 2020 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 8th February 2022 1:05:05 pm +# Last Modified: Tuesday, 8th February 2022 10:50:37 pm # Modified By: Chen Xuanhong # Copyright (c) 2020 Shanghai Jiao Tong University ############################################################# @@ -31,7 +31,7 @@ def getParameters(): parser = argparse.ArgumentParser() # general settings - parser.add_argument('-v', '--version', type=str, default='multigpu2', + parser.add_argument('-v', '--version', type=str, default='multigpu3', help="version name for train, test, finetune") parser.add_argument('-t', '--tag', type=str, default='multigpu', help="tag for current experiment") @@ -225,13 +225,15 @@ def main(): # print some important information # TODO - print("Start to run training script: {}".format(moduleName)) - print("Traning version: %s"%sys_state["version"]) - print("Dataloader Name: %s"%sys_state["dataloader"]) - # print("Image Size: %d"%sys_state["imsize"]) - print("Batch size: %d"%(sys_state["batch_size"])) - print("GPUs:", gpus) - + # print("Start to run training script: {}".format(moduleName)) + # print("Traning version: %s"%sys_state["version"]) + # print("Dataloader Name: %s"%sys_state["dataloader"]) + # # print("Image Size: %d"%sys_state["imsize"]) + # print("Batch size: %d"%(sys_state["batch_size"])) + # print("GPUs:", gpus) + print("\n========================================================================\n") + print(sys_state) + print("\n========================================================================\n") # 
Load the training script and start to train diff --git a/train_scripts/trainer_multi_gpu.py b/train_scripts/trainer_multi_gpu.py index 0e6bc95..de56e85 100644 --- a/train_scripts/trainer_multi_gpu.py +++ b/train_scripts/trainer_multi_gpu.py @@ -5,7 +5,7 @@ # Created Date: Sunday January 9th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 8th February 2022 2:29:34 pm +# Last Modified: Tuesday, 8th February 2022 10:48:58 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -13,6 +13,7 @@ import os import time import random +import shutil import tempfile import numpy as np @@ -63,7 +64,13 @@ def init_framework(config, reporter, device, rank): if config["phase"] == "train": gscript_name = "components." + model_config["g_model"]["script"] + file1 = os.path.join("components", model_config["g_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) dscript_name = "components." 
+ model_config["d_model"]["script"] + file1 = os.path.join("components", model_config["d_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) elif config["phase"] == "finetune": gscript_name = config["com_base"] + model_config["g_model"]["script"] @@ -92,6 +99,20 @@ def init_framework(config, reporter, device, rank): # train in GPU + # if in finetune phase, load the pretrained checkpoint + if config["phase"] == "finetune": + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["generator_name"])) + gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["discriminator_name"])) + dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + + print('loaded trained backbone model step {}...!'.format(config["ckpt"])) + gen = gen.to(device) dis = dis.to(device) arcface= arcface.to(device) @@ -99,19 +120,7 @@ def init_framework(config, reporter, device, rank): arcface.eval() - # if in finetune phase, load the pretrained checkpoint - if config["phase"] == "finetune": - model_path = os.path.join(config["project_checkpoints"], - "step%d_%s.pth"%(config["checkpoint_step"], - config["checkpoint_names"]["generator_name"])) - gen.load_state_dict(torch.load(model_path)) - - model_path = os.path.join(config["project_checkpoints"], - "step%d_%s.pth"%(config["checkpoint_step"], - config["checkpoint_names"]["discriminator_name"])) - dis.load_state_dict(torch.load(model_path)) - - print('loaded trained backbone model step {}...!'.format(config["project_checkpoints"])) + return gen, dis, arcface # TODO modify this function to configurate the optimizer of your pipeline @@ -149,12 +158,12 @@ def setup_optimizers(config, reporter, gen, dis, 
rank): # self.optimizers.append(self.optimizer_g) if config["phase"] == "finetune": opt_path = os.path.join(config["project_checkpoints"], - "step%d_optim_%s.pth"%(config["checkpoint_step"], + "step%d_optim_%s.pth"%(config["ckpt"], config["optimizer_names"]["generator_name"])) g_optimizer.load_state_dict(torch.load(opt_path)) opt_path = os.path.join(config["project_checkpoints"], - "step%d_optim_%s.pth"%(config["checkpoint_step"], + "step%d_optim_%s.pth"%(config["ckpt"], config["optimizer_names"]["discriminator_name"])) d_optimizer.load_state_dict(torch.load(opt_path)) @@ -270,19 +279,19 @@ def train_loop( # TODO replace below lines to build your losses # MSE_loss = torch.nn.MSELoss() l1_loss = torch.nn.L1Loss() - cos_loss = cosin_metric + cos_loss = torch.nn.CosineSimilarity() g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank) # Initialize logs. if rank == 0: print('Initializing logs...') - if rank == 0: #==============build tensorboard=================# if config["logger"] == "tensorboard": import torch.utils.tensorboard as tensorboard - tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"]) - logger = tensorboard_writer + tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"]) + logger = tensorboard_writer + elif config["logger"] == "wandb": import wandb wandb.init(project="Simswap_HQ", entity="xhchen", notes="512", @@ -297,11 +306,10 @@ def train_loop( random.seed(random_seed) randindex = [i for i in range(batch_gpu)] - random.shuffle(randindex) # set the start point for training loop if config["phase"] == "finetune": - start = config["checkpoint_step"] + start = config["ckpt"] else: start = 0 if rank == 0: @@ -315,11 +323,12 @@ def train_loop( from utilities.logo_class import logo_class logo_class.print_start_training() + dis.feature_network.requires_grad_(False) + for step in range(start, total_step): gen.train() dis.train() - dis.feature_network.eval() for interval in range(2): 
random.shuffle(randindex) src_image1, src_image2 = dataloader.next()