diff --git a/GUI.py b/GUI.py
index f205cbf..a4ffb86 100644
--- a/GUI.py
+++ b/GUI.py
@@ -5,14 +5,17 @@
 # Created Date: Wednesday December 22nd 2021
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Thursday, 17th February 2022 10:47:35 pm
+# Last Modified: Friday, 22nd April 2022 11:23:17 am
 # Modified By: Chen Xuanhong
 # Copyright (c) 2021 Shanghai Jiao Tong University
 #############################################################

+from glob import glob
+from ipaddress import ip_address
 import os
+import re
 import sys
 import time
 import json
@@ -37,6 +40,7 @@ import tkinter.ttk as ttk
 import subprocess
 from pathlib import Path
+from tkinter.filedialog import askopenfilename

@@ -163,6 +167,27 @@ class fileUploaderClass(object):
         self.__ssh__.close()
         return roots

+    def sshScpGetRNamesBySuffix(self, remoteDir, suffix):
+        self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
+        sftp = self.__ssh__.open_sftp()
+        # list the remote directory, keep only entries matching the suffix,
+        # and record each entry's mtime and full remote path
+        names = sftp.listdir(remoteDir)
+        roots = {}
+        for item in names:
+            if not item.endswith(suffix):
+                continue
+            attr = sftp.stat(remoteDir+"/"+item)
+            roots[item] = {
+                "t":attr.st_mtime,
+                "p":remoteDir+"/"+item
+            }
+        sftp.close()
+        self.__ssh__.close()
+        return roots
+
     def sshScpGet(self, remoteFile, localFile, showProgress=False):
         self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
         sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
@@ -308,16 +333,16 @@ class Application(tk.Frame):
         "config_json_name":"model_config.json"
     }
     machine_text = {
-        "ip": "0.0.0.0",
+        "ip": "localhost",
         "user": "username",
         "port": 22,
         "passwd": "12345678",
-        "path": "/path/to/remote_host",
+        "path": ".",
         "ckp_path":"save",
-        "logfilename": "filestate_machine0.json"
+        "logfilename": "filestate_machine_localhost.json"
     }
     current_log = {}
-
+    current_ckpt = {}

     def __init__(self, master=None):
         tk.Frame.__init__(self, master,bg='black')
@@ -435,7 +460,7 @@ class Application(tk.Frame):
         config_frame.pack(fill="both", padx=5,pady=5)
         config_frame.columnconfigure(0, weight=1)
         config_frame.columnconfigure(1, weight=1)
-        # config_frame.columnconfigure(2, weight=1)
+        config_frame.columnconfigure(2, weight=1)

         machine_btn = tk.Button(config_frame, text = "Ignore Conf",
                     font=font_list, command = self.IgnoreConfig, bg='#660099', fg='#F5F5F5')
         machine_btn.grid(row=0,column=0,sticky=tk.EW)

         machine_btn2 = tk.Button(config_frame, text = "Env Conf",
                     font=font_list, command = self.EnvConfig, bg='#660099', fg='#F5F5F5')
         machine_btn2.grid(row=0,column=1,sticky=tk.EW)

+        machine_btn2 = tk.Button(config_frame, text = "Test Conf",
+                    font=font_list, command = self.TestConfig, bg='#660099', fg='#F5F5F5')
+        machine_btn2.grid(row=0,column=2,sticky=tk.EW)
+
         #################################################################################################
         log_frame = tk.Frame(self.master)
         log_frame.pack(fill="both", padx=5,pady=5)
         log_frame.columnconfigure(0, weight=1)
         log_frame.columnconfigure(1, weight=1)
         log_frame.columnconfigure(2, weight=1)
+        log_frame.columnconfigure(3, weight=1)

         self.log_var = tkinter.StringVar()

         self.update_ckpt_task()
         self.log_com.bind("<>",select_log)

+        self.test_var = tkinter.StringVar()
+
+        self.test_com = ttk.Combobox(log_frame, textvariable=self.test_var)
+
self.test_com.grid(row=0,column=1,sticky=tk.EW) log_update_button = tk.Button(log_frame, text = "Fresh", font=font_list, command = self.UpdateLog, bg='#F4A460', fg='#F5F5F5') - log_update_button.grid(row=0,column=1,sticky=tk.EW) - - log_update_button = tk.Button(log_frame, text = "Pull Log", - font=font_list, command = self.PullLog, bg='#F4A460', fg='#F5F5F5') log_update_button.grid(row=0,column=2,sticky=tk.EW) + # log_update_button = tk.Button(log_frame, text = "Pull Log", + # font=font_list, command = self.PullLog, bg='#F4A460', fg='#F5F5F5') + # log_update_button.grid(row=0,column=2,sticky=tk.EW) + + log_update_button = tk.Button(log_frame, text = "Fresh CKPT", + font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5') + log_update_button.grid(row=0,column=3,sticky=tk.EW) + ################################################################################################# test_frame = tk.Frame(self.master) test_frame.pack(fill="both", padx=5,pady=5) test_frame.columnconfigure(0, weight=1) test_frame.columnconfigure(1, weight=1) test_frame.columnconfigure(2, weight=1) + # test_frame.columnconfigure(3, weight=1) - self.test_var = tkinter.StringVar() + self.testscript_var = tkinter.StringVar() - self.test_com = ttk.Combobox(test_frame, textvariable=self.test_var) - self.test_com.grid(row=0,column=0,sticky=tk.EW) - + self.testscript_com = ttk.Combobox(test_frame, textvariable=self.testscript_var) + self.testscript_com.grid(row=0,column=0,sticky=tk.EW) - test_update_button = tk.Button(test_frame, text = "Fresh CKPT", - font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5') - test_update_button.grid(row=0,column=1,sticky=tk.EW) + testscript_files = Path("./test_scripts").glob("*.py") + testscript_list = [] + for item in testscript_files: + basename = item.name + basename = os.path.splitext(basename)[0] + testscript_list.append(basename) + self.testscript_com["value"] = testscript_list + + # test_update_button = tk.Button(test_frame, text = "Fresh CKPT", + # font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5') + # test_update_button.grid(row=0,column=1,sticky=tk.EW) test_update_button = tk.Button(test_frame, text = "Test", font=font_list, command = self.Test, bg='#F4A460', fg='#F5F5F5') + test_update_button.grid(row=0,column=1,sticky=tk.EW) + + # test_update_button = tk.Button(test_frame, text = "Test Config", + # font=font_list, command = self.TestConfig, bg='#660099', fg='#F5F5F5') + # test_update_button.grid(row=0,column=2,sticky=tk.EW) + + test_update_button = tk.Button(test_frame, text = "Sample", + font=font_list, command = self.OpenSample, bg='#0033FF', fg='#F5F5F5') test_update_button.grid(row=0,column=2,sticky=tk.EW) + # ################################################################################################# + + select_frame = tk.Frame(self.master) + select_frame.pack(fill="both", padx=5,pady=5) + + select_frame.columnconfigure(0, weight=2) + select_frame.columnconfigure(1, weight=5) + select_frame.columnconfigure(2, weight=1) + select_frame.columnconfigure(3, weight=2) + select_frame.columnconfigure(4, weight=5) + select_frame.columnconfigure(5, weight=1) + select_frame.columnconfigure(6, weight=3) + + self.preprocess_var = tkinter.StringVar() + + self.preprocess_com = ttk.Combobox(select_frame, textvariable=self.preprocess_var) + self.preprocess_com.grid(row=0,column=6,sticky=tk.EW) + self.preprocess_com["value"] = ["True", "False"] + self.preprocess_com.current(1) + self.ID_path = tkinter.StringVar() + self.ID_path.set("...") + 
tk.Label(select_frame, text="ID:",font=font_list,justify="left")\ + .grid(row=0,column=0,sticky=tk.EW) + + tk.Entry(select_frame, textvariable= self.ID_path, font=font_list)\ + .grid(row=0,column=1,sticky=tk.EW) + + tk.Button(select_frame, text = "...", font=font_list, + command = self.Select_ID_path, bg='#F4A460', fg='#F5F5F5')\ + .grid(row=0,column=2,sticky=tk.EW) + + + self.Attr_path = tkinter.StringVar() + self.Attr_path.set("...") + tk.Label(select_frame, text="Attr:",font=font_list,justify="left")\ + .grid(row=0,column=3,sticky=tk.EW) + + tk.Entry(select_frame, textvariable= self.Attr_path, font=font_list)\ + .grid(row=0,column=4,sticky=tk.EW) + + tk.Button(select_frame, text = "...", font=font_list, + command = self.Select_Attr_path, bg='#F4A460', fg='#F5F5F5')\ + .grid(row=0,column=5,sticky=tk.EW) + # ################################################################################################# # tbtext_frame = tk.Frame(self.master) @@ -527,23 +625,61 @@ class Application(tk.Frame): self.master.protocol("WM_DELETE_WINDOW", self.on_closing) - # def __scaning_logs__(self): + + def Select_ID_path(self): + thread_update = threading.Thread(target=self.select_ID_task) + thread_update.start() + + def select_ID_task(self): + path = askopenfilename() + print("Selected ID: %s"%path) + self.ID_path.set(path) + + def Select_Attr_path(self): + thread_update = threading.Thread(target=self.select_Attr_task) + thread_update.start() + + def select_Attr_task(self): + path = askopenfilename() + print("Selected Attibutes: %s"%path) + self.Attr_path.set(path) + def UpdateCKPT(self): thread_update = threading.Thread(target=self.update_ckpt_task) thread_update.start() - + def update_ckpt_task(self): - ip = self.list_com.get() - log = self.log_com.get() - cur_mac = self.machine_dict[ip] - files = Path('.',cur_mac["ckp_path"], log) - files = files.glob('*.pth') - all_files = [] - for one_file in files: - all_files.append(one_file.name) - self.test_com["value"] =all_files - if len(all_files): + print("Loading checkpoints..........................") + remotemachine, mac = self.connection() + log = self.log_com.get() + remote_path = os.path.join(mac["path"],mac["ckp_path"], log, "checkpoints").replace("\\", "/") + + if remotemachine == []: + files = Path(remote_path).glob("*/") + first_level = {} + for one_file in files: + first_level[one_file.name] = { + "t":"", + "p":"" + } + else: + first_level = remotemachine.sshScpGetNames(remote_path) + + if len(first_level) == 0: + self.test_com["value"] = [""] self.test_com.current(0) + print("No checkpoint found!") + return + logs = [] + for k,v in first_level.items(): + logs.append([k,v["t"]]) + # logs = sorted(logs) + logs = sorted(logs, key= lambda logs : logs[1],reverse=True) + self.test_com["value"] =[item[0] for item in logs] + + self.test_com.current(0) + self.current_ckpt = first_level + print("Checkpoints list update success!") def CopyPasswd(self): def copy(): @@ -557,13 +693,25 @@ class Application(tk.Frame): def Test(self): def test_task(): + ip = self.list_com.get() log = self.log_com.get() ckpt = self.test_com.get() + script_name = self.testscript_com.get() + id_path = self.ID_path.get() + attr_path = self.Attr_path.get() + preprocess = self.preprocess_com.get() + # if preprocess == "Preprocess-Off": + # preprocess = "off" + # else: + # preprocess = "on" + ckpt = re.sub("\D", "", ckpt) cwd = os.getcwd() - files = str(Path(log, ckpt)) - print(files) + # files = str(Path(log, ckpt)) + # print(files) + print("start cmd /k \"cd /d %s && conda activate 
+                && python test.py -v %s -s %s -t %s -n %s --preprocess %s -i %s -a %s\""%(cwd, log, ckpt, script_name, ip, preprocess, id_path, attr_path))
             subprocess.check_call("start cmd /k \"cd /d %s && conda activate base \
-                && python test.py --model %s\""%(cwd, files), shell=True)
+                && python test.py -v %s -s %s -t %s -n %s --preprocess %s -i %s -a %s\""%(cwd, log, ckpt, script_name, ip, preprocess, id_path, attr_path), shell=True)
         thread_update = threading.Thread(target=test_task)
         thread_update.start()
@@ -595,7 +743,7 @@
         ssh_username = cur_mac["user"]
         ssh_passwd = cur_mac["passwd"]
         ssh_port = int(cur_mac["port"])
-        print(ssh_ip)
+        print("Processing IP: %s."%ssh_ip)
         if ip.lower() == "local" or ip.lower() == "localhost":
             print("localhost no need to connect!")
             return [], cur_mac
@@ -607,6 +755,7 @@
         print(cells)

     def update_log_task(self):
+        print("Processing! Do not touch!")
         remotemachine,mac = self.connection()
         remote_path = os.path.join(mac["path"],mac["ckp_path"]).replace("\\", "/")
         if remotemachine == []:
@@ -617,16 +766,23 @@
                     "t":"",
                     "p":""
                 }
         else:
             first_level = remotemachine.sshScpGetNames(remote_path)
+        if len(first_level) == 0:
+            print("No training log found!")
+            return
         logs = []
         for k,v in first_level.items():
-            logs.append(k)
-        logs = sorted(logs)
-        self.log_com["value"] =logs
+            logs.append([k,v["t"]])
+        logs = sorted(logs, key=lambda entry: entry[1],reverse=True)
+        self.log_com["value"] = [item[0] for item in logs]
         self.log_com.current(0)
         self.current_log = first_level
         self.update_ckpt_task()
+        print("Done!")

     def UpdateLog(self):
         thread_update = threading.Thread(target=self.update_log_task)
@@ -666,6 +822,21 @@
         thread_update = threading.Thread(target=pull_log_task)
         thread_update.start()
+
+    def TestConfig(self):
+        def test_config_task():
+            subprocess.call("start %s"%"test.py", shell=True)
+        thread_update = threading.Thread(target=test_config_task)
+        thread_update.start()
+
+    def OpenSample(self):
+        def open_cmd_task():
+            log = self.log_com.get()
+            cwd = os.getcwd()
+            sample = os.path.join(cwd,"test_logs",log,"samples")
+            subprocess.call("explorer "+sample, shell=False)
+        thread_update = threading.Thread(target=open_cmd_task)
+        thread_update.start()

     def OpenCMD(self):
         def open_cmd_task():
@@ -685,6 +856,7 @@
             ip = self.list_com.get()
             if ip.lower() == "local" or ip.lower() == "localhost":
                 print("localhost no need to connect!")
+                return
             cur_mac = self.machine_dict[ip]
             ssh_ip = cur_mac["ip"]
             ssh_username = cur_mac["user"]
@@ -707,6 +879,10 @@
     def GPUUsage(self):
         def gpu_usage_task():
             remotemachine,_ = self.connection()
+            # connection() returns [] for the local machine, so test against that
+            if remotemachine == []:
+                print("localhost no need to connect!")
+                return
+
             results = remotemachine.sshExec("nvidia-smi")
             print(results)

diff --git a/GUI/file_sync/filestate_machine0.json b/GUI/file_sync/filestate_machine0.json
index 260ae6c..cba068f 100644
--- a/GUI/file_sync/filestate_machine0.json
+++ b/GUI/file_sync/filestate_machine0.json
@@ -1,91 +1,91 @@
 {
-    "GUI.py": 1647657822.9152665,
-    "test.py": 1647879709.2723496,
-    "train.py": 1647657822.9562755,
-    "components\\Generator.py": 1647657822.93127,
-    "components\\projected_discriminator.py": 1647657822.938272,
-    "components\\pg_modules\\blocks.py": 1647657822.9362714,
-    "components\\pg_modules\\diffaug.py": 1647657822.9362714,
-
"components\\pg_modules\\discriminator.py": 1647657822.937271, - "components\\pg_modules\\networks_fastgan.py": 1647657822.937271, - "components\\pg_modules\\networks_stylegan2.py": 1647657822.937271, - "components\\pg_modules\\projector.py": 1647657822.938272, - "data_tools\\data_loader.py": 1647657822.9392715, - "data_tools\\data_loader_condition.py": 1647657822.9402719, - "data_tools\\data_loader_VGGFace2HQ.py": 1647657822.9392715, - "data_tools\\StyleResize.py": 1647657822.9392715, - "data_tools\\test_dataloader_dir.py": 1647657822.941272, - "losses\\PerceptualLoss.py": 1647657822.9432724, - "losses\\SliceWassersteinDistance.py": 1647657822.9432724, - "models\\arcface_models.py": 1647657822.9442725, - "models\\config.py": 1647657822.9442725, - "models\\__init__.py": 1647657822.9442725, - "test_scripts\\tester_common.py": 1647657822.9472733, + "GUI.py": 1649868392.5891902, + "test.py": 1649641910.555175, + "train.py": 1643397924.974299, + "components\\Generator.py": 1644689001.9005148, + "components\\projected_discriminator.py": 1642348101.4661522, + "components\\pg_modules\\blocks.py": 1640773190.0, + "components\\pg_modules\\diffaug.py": 1640773190.0, + "components\\pg_modules\\discriminator.py": 1642349784.9407308, + "components\\pg_modules\\networks_fastgan.py": 1640773190.0, + "components\\pg_modules\\networks_stylegan2.py": 1640773190.0, + "components\\pg_modules\\projector.py": 1642349764.3896568, + "data_tools\\data_loader.py": 1611123530.660446, + "data_tools\\data_loader_condition.py": 1625411562.8217106, + "data_tools\\data_loader_VGGFace2HQ.py": 1644234949.3769877, + "data_tools\\StyleResize.py": 1624954084.7176485, + "data_tools\\test_dataloader_dir.py": 1634041792.6743984, + "losses\\PerceptualLoss.py": 1615020169.668723, + "losses\\SliceWassersteinDistance.py": 1634022704.6082795, + "models\\arcface_models.py": 1642390690.623, + "models\\config.py": 1632643596.2908099, + "models\\__init__.py": 1642390864.8828168, + "test_scripts\\tester_common.py": 1625369535.199175, "test_scripts\\tester_FastNST.py": 1634041357.607633, - "train_scripts\\trainer_base.py": 1647657822.9582758, - "train_scripts\\trainer_FM.py": 1647657822.957276, - "train_scripts\\trainer_naiv512.py": 1647657822.9602764, - "utilities\\checkpoint_manager.py": 1647657822.9652774, - "utilities\\figure.py": 1647657822.9652774, - "utilities\\json_config.py": 1647657822.9652774, - "utilities\\learningrate_scheduler.py": 1647657822.9652774, - "utilities\\logo_class.py": 1647657822.9662776, - "utilities\\plot.py": 1647657822.9662776, - "utilities\\reporter.py": 1647657822.9662776, - "utilities\\save_heatmap.py": 1647657822.967278, - "utilities\\sshupload.py": 1647657822.967278, - "utilities\\transfer_checkpoint.py": 1647657822.967278, - "utilities\\utilities.py": 1647657822.9682784, - "utilities\\yaml_config.py": 1647657822.9682784, - "train_yamls\\train_512FM.yaml": 1647657822.961277, - "train_scripts\\trainer_2layer_FM.py": 1647657822.957276, - "train_yamls\\train_2layer_FM.yaml": 1647657822.961277, - "components\\Generator_reduce.py": 1647657822.934271, + "train_scripts\\trainer_base.py": 1642396105.3868554, + "train_scripts\\trainer_FM.py": 1643021959.3577182, + "train_scripts\\trainer_naiv512.py": 1642315674.9740853, + "utilities\\checkpoint_manager.py": 1611123530.6624403, + "utilities\\figure.py": 1611123530.6634378, + "utilities\\json_config.py": 1611123530.6614666, + "utilities\\learningrate_scheduler.py": 1611123530.675422, + "utilities\\logo_class.py": 1633883995.3093486, + "utilities\\plot.py": 
1641911100.7995758, + "utilities\\reporter.py": 1646311333.3067005, + "utilities\\save_heatmap.py": 1611123530.679439, + "utilities\\sshupload.py": 1649910787.5441866, + "utilities\\transfer_checkpoint.py": 1642397157.0163105, + "utilities\\utilities.py": 1649907294.9180465, + "utilities\\yaml_config.py": 1611123530.6614666, + "train_yamls\\train_512FM.yaml": 1643021615.8106658, + "train_scripts\\trainer_2layer_FM.py": 1642826548.2530458, + "train_yamls\\train_2layer_FM.yaml": 1642411635.5534878, + "components\\Generator_reduce.py": 1645020911.0651233, "insightface_func\\face_detect_crop_multi.py": 1643796928.6362474, "insightface_func\\face_detect_crop_single.py": 1638370471.7967434, "insightface_func\\__init__.py": 1624197300.011183, "insightface_func\\utils\\face_align_ffhqandnewarc.py": 1638370471.850638, - "losses\\PatchNCE.py": 1647677173.0239084, + "losses\\PatchNCE.py": 1647769120.567006, "parsing_model\\model.py": 1626745709.554252, "parsing_model\\resnet.py": 1626745709.554252, - "test_scripts\\tester_common copy.py": 1647657822.9472733, - "test_scripts\\tester_video.py": 1647657822.9482737, - "train_scripts\\trainer_cycleloss.py": 1647657822.9592762, - "train_scripts\\trainer_GramFM.py": 1647657822.9582758, - "utilities\\ImagenetNorm.py": 1647657822.9642777, - "utilities\\reverse2original.py": 1647657822.9662776, - "train_yamls\\train_cycleloss.yaml": 1647704989.2310138, - "train_yamls\\train_GramFM.yaml": 1647657822.9622767, - "train_yamls\\train_512FM_Modulation.yaml": 1647657822.961277, - "face_crop.py": 1647657822.9422722, - "face_crop_video.py": 1647657822.9422722, - "similarity.py": 1647657822.945273, - "train_multigpu.py": 1647967698.603863, - "components\\arcface_decoder.py": 1647657822.9352713, + "test_scripts\\tester_common copy.py": 1625369535.199175, + "test_scripts\\tester_video.py": 1649749352.843031, + "train_scripts\\trainer_cycleloss.py": 1642580463.495596, + "train_scripts\\trainer_GramFM.py": 1643095575.2628715, + "utilities\\ImagenetNorm.py": 1642732910.5280058, + "utilities\\reverse2original.py": 1648533907.6187606, + "train_yamls\\train_cycleloss.yaml": 1647769120.6110919, + "train_yamls\\train_GramFM.yaml": 1643398791.363959, + "train_yamls\\train_512FM_Modulation.yaml": 1643022022.3165789, + "face_crop.py": 1649079350.7075238, + "face_crop_video.py": 1649988435.4005315, + "similarity.py": 1643269705.1073737, + "train_multigpu.py": 1650004781.6307411, + "components\\arcface_decoder.py": 1643396144.2575414, "components\\Generator_nobias.py": 1643179001.810856, - "data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1647657822.9402719, - "data_tools\\data_loader_VGGFace2HQ_Rec.py": 1647657822.9392715, - "test_scripts\\tester_arcface_Rec.py": 1647657822.946273, - "test_scripts\\tester_image.py": 1647657822.9472733, - "torch_utils\\custom_ops.py": 1647657822.9482737, - "torch_utils\\misc.py": 1647657822.9492736, - "torch_utils\\persistence.py": 1647657822.9552753, - "torch_utils\\training_stats.py": 1647657822.9562755, - "torch_utils\\utils_spectrum.py": 1647657822.9562755, - "torch_utils\\__init__.py": 1647657822.9482737, - "torch_utils\\ops\\bias_act.py": 1647657822.9502747, - "torch_utils\\ops\\conv2d_gradfix.py": 1647657822.9512744, - "torch_utils\\ops\\conv2d_resample.py": 1647657822.9512744, - "torch_utils\\ops\\filtered_lrelu.py": 1647657822.9532747, - "torch_utils\\ops\\fma.py": 1647657822.9532747, - "torch_utils\\ops\\grid_sample_gradfix.py": 1647657822.9542756, - "torch_utils\\ops\\upfirdn2d.py": 1647657822.9552753, - "torch_utils\\ops\\__init__.py": 
1647657822.9492736, - "train_scripts\\trainer_arcface_rec.py": 1647657822.9582758, - "train_scripts\\trainer_multigpu_base.py": 1647657822.9602764, - "train_scripts\\trainer_multi_gpu.py": 1647657822.9592762, - "train_yamls\\train_arcface_rec.yaml": 1647657822.9622767, - "train_yamls\\train_multigpu.yaml": 1647657822.963277, + "data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1649177633.2426238, + "data_tools\\data_loader_VGGFace2HQ_Rec.py": 1643398754.86898, + "test_scripts\\tester_arcface_Rec.py": 1643431261.9333818, + "test_scripts\\tester_image.py": 1648028827.923354, + "torch_utils\\custom_ops.py": 1640773190.0, + "torch_utils\\misc.py": 1640773190.0, + "torch_utils\\persistence.py": 1640773190.0, + "torch_utils\\training_stats.py": 1640773190.0, + "torch_utils\\utils_spectrum.py": 1640773190.0, + "torch_utils\\__init__.py": 1640773190.0, + "torch_utils\\ops\\bias_act.py": 1640773190.0, + "torch_utils\\ops\\conv2d_gradfix.py": 1640773190.0, + "torch_utils\\ops\\conv2d_resample.py": 1640773190.0, + "torch_utils\\ops\\filtered_lrelu.py": 1640773190.0, + "torch_utils\\ops\\fma.py": 1640773190.0, + "torch_utils\\ops\\grid_sample_gradfix.py": 1640773190.0, + "torch_utils\\ops\\upfirdn2d.py": 1640773190.0, + "torch_utils\\ops\\__init__.py": 1640773190.0, + "train_scripts\\trainer_arcface_rec.py": 1643399647.0182135, + "train_scripts\\trainer_multigpu_base.py": 1644131205.772292, + "train_scripts\\trainer_multi_gpu.py": 1648285132.309124, + "train_yamls\\train_arcface_rec.yaml": 1643398807.3434353, + "train_yamls\\train_multigpu.yaml": 1644549590.0652373, "wandb\\run-20220129_032741-340btp9k\\files\\conda-environment.yaml": 1643398065.409959, "wandb\\run-20220129_032741-340btp9k\\files\\config.yaml": 1643398069.2392955, "wandb\\run-20220129_032939-2nmaozxq\\files\\conda-environment.yaml": 1643398182.647548, @@ -100,93 +100,208 @@ "wandb\\run-20220129_034859-2puk6sph\\files\\config.yaml": 1643399477.881678, "wandb\\run-20220129_035624-3hmwgcgw\\files\\conda-environment.yaml": 1643399787.8899708, "wandb\\run-20220129_035624-3hmwgcgw\\files\\config.yaml": 1643426465.6088357, - "dnnlib\\util.py": 1647657822.941272, - "dnnlib\\__init__.py": 1647657822.941272, - "components\\Generator_ori.py": 1647657822.9332705, - "losses\\cos.py": 1647657822.9442725, - "data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1647657822.9402719, - "speed_test.py": 1647657822.945273, - "components\\DeConv_Invo.py": 1647657822.9302697, + "dnnlib\\util.py": 1640773190.0, + "dnnlib\\__init__.py": 1640773190.0, + "components\\Generator_ori.py": 1644689174.414655, + "losses\\cos.py": 1644229583.4023254, + "data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1644860106.943826, + "speed_test.py": 1648982366.0803514, + "components\\DeConv_Invo.py": 1644426607.1588645, "components\\Generator_reduce_up.py": 1644688655.2096283, - "components\\Generator_upsample.py": 1647657822.9352713, - "components\\misc\\Involution.py": 1647657822.9352713, - "train_yamls\\train_Invoup.yaml": 1647657822.9622767, - "flops.py": 1647657822.9422722, - "detection_test.py": 1647657822.941272, - "components\\DeConv_Depthwise.py": 1647657822.9292698, - "components\\DeConv_Depthwise1.py": 1647657822.9292698, + "components\\Generator_upsample.py": 1644689723.8293872, + "components\\misc\\Involution.py": 1644509321.5267963, + "train_yamls\\train_Invoup.yaml": 1644689981.9794765, + "flops.py": 1649040334.6186154, + "detection_test.py": 1644935512.6830947, + "components\\DeConv_Depthwise.py": 1645064447.4379447, + "components\\DeConv_Depthwise1.py": 
1644946969.5054545, "components\\Generator_modulation_depthwise.py": 1644861291.4467516, - "components\\Generator_modulation_depthwise_config.py": 1647657822.93227, - "components\\Generator_modulation_up.py": 1647657822.9332705, - "components\\Generator_oriae_modulation.py": 1647657822.934271, - "components\\Generator_ori_config.py": 1647657822.934271, - "train_scripts\\trainer_multi_gpu1.py": 1647657822.9602764, - "train_yamls\\train_Depthwise.yaml": 1647657822.961277, - "train_yamls\\train_depthwise_modulation.yaml": 1647657822.963277, - "train_yamls\\train_oriae_modulation.yaml": 1647657822.9642777, - "train_distillation_mgpu.py": 1647657822.9562755, - "components\\DeConv.py": 1647657822.9292698, - "components\\DeConv_Depthwise_ECA.py": 1647657822.9292698, - "components\\ECA.py": 1647657822.9302697, - "components\\ECA_Depthwise_Conv.py": 1647657822.93127, - "components\\Generator_eca_depthwise.py": 1647657822.93227, - "losses\\KA.py": 1647657822.9432724, - "train_scripts\\trainer_distillation_mgpu.py": 1647657822.9592762, - "train_yamls\\train_distillation.yaml": 1647657822.963277, - "annotation.py": 1647657822.9172668, - "components\\DeConv_ECA_Invo.py": 1647657822.9302697, - "components\\DeConv_Invobn.py": 1647657822.9302697, - "components\\Generator_Invobn_config.py": 1647657822.93127, - "components\\Generator_Invobn_config1.py": 1647657822.93127, - "components\\misc\\Involution_BN.py": 1647657822.9362714, - "components\\misc\\Involution_ECA.py": 1647657822.9362714, - "train_yamls\\train_Invobn_config.yaml": 1647657822.9622767, - "components\\Generator_Invobn_config2.py": 1647657822.93227, - "components\\Generator_Invobn_config3.py": 1647657822.93227, - "components\\Generator_ori_modulation_config.py": 1647657822.934271, - "test_scripts\\tester_image_allstep.py": 1647657822.9482737, - "train_yamls\\train_ori_modulation_config.yaml": 1647657822.9642777, - "test_arcface.py": 1647657822.946273, - "arcface_torch\\dataset.py": 1647657822.9222684, - "arcface_torch\\eval_ijbc.py": 1647657822.9242685, - "arcface_torch\\inference.py": 1647657822.9242685, - "arcface_torch\\losses.py": 1647657822.9242685, - "arcface_torch\\lr_scheduler.py": 1647657822.9242685, - "arcface_torch\\onnx_helper.py": 1647657822.9252684, - "arcface_torch\\onnx_ijbc.py": 1647657822.9252684, - "arcface_torch\\partial_fc.py": 1647657822.9252684, - "arcface_torch\\torch2onnx.py": 1647657822.9262686, - "arcface_torch\\train.py": 1647657822.9262686, - "arcface_torch\\backbones\\iresnet.py": 1647657822.918267, - "arcface_torch\\backbones\\iresnet2060.py": 1647657822.9192681, - "arcface_torch\\backbones\\mobilefacenet.py": 1647657822.9192681, - "arcface_torch\\backbones\\__init__.py": 1647657822.918267, - "arcface_torch\\configs\\3millions.py": 1647657822.9192681, - "arcface_torch\\configs\\base.py": 1647657822.9192681, - "arcface_torch\\configs\\glint360k_mobileface_lr02_bs4k.py": 1647657822.9202676, - "arcface_torch\\configs\\glint360k_r100_lr02_bs4k_16gpus.py": 1647657822.9202676, - "arcface_torch\\configs\\ms1mv3_mobileface_lr02.py": 1647657822.9202676, - "arcface_torch\\configs\\ms1mv3_r100_lr02.py": 1647657822.9202676, - "arcface_torch\\configs\\ms1mv3_r50_lr02.py": 1647657822.9202676, - "arcface_torch\\configs\\webface42m_mobilefacenet_pfc02_bs8k_16gpus.py": 1647657822.9212687, - "arcface_torch\\configs\\webface42m_r100_lr01_pfc02_bs4k_16gpus.py": 1647657822.9212687, - "arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_32gpus.py": 1647657822.9212687, - 
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_8gpus.py": 1647657822.9212687, - "arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs8k_16gpus.py": 1647657822.9212687, - "arcface_torch\\configs\\__init__.py": 1647657822.9192681, - "arcface_torch\\eval\\verification.py": 1647657822.923268, - "arcface_torch\\eval\\__init__.py": 1647657822.923268, - "arcface_torch\\utils\\plot.py": 1647657822.927269, - "arcface_torch\\utils\\utils_callbacks.py": 1647657822.927269, - "arcface_torch\\utils\\utils_config.py": 1647657822.927269, - "arcface_torch\\utils\\utils_logging.py": 1647657822.927269, - "arcface_torch\\utils\\__init__.py": 1647657822.9262686, - "components\\LSTU.py": 1647702612.240765, - "test_scripts\\tester_ID_Pose.py": 1647657822.946273, - "train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1647657822.9592762, - "train_scripts\\trainer_multi_gpu_CUT.py": 1647676964.475, - "train_scripts\\trainer_multi_gpu_cycle.py": 1647705628.7020626, - "components\\Generator_LSTU_config.py": 1647954099.1135788, - "components\\Generator_Res_config.py": 1648006159.4385264, - "train_yamls\\train_cycleloss_res.yaml": 1648006232.5734456 + "components\\Generator_modulation_depthwise_config.py": 1645262162.9779513, + "components\\Generator_modulation_up.py": 1644946498.7005584, + "components\\Generator_oriae_modulation.py": 1644897798.1987727, + "components\\Generator_ori_config.py": 1646329319.6131227, + "train_scripts\\trainer_multi_gpu1.py": 1644859528.8428593, + "train_yamls\\train_Depthwise.yaml": 1644860961.099242, + "train_yamls\\train_depthwise_modulation.yaml": 1645035964.9551077, + "train_yamls\\train_oriae_modulation.yaml": 1644897891.2576747, + "train_distillation_mgpu.py": 1645554603.908166, + "components\\DeConv.py": 1645263338.9001615, + "components\\DeConv_Depthwise_ECA.py": 1645265769.1076133, + "components\\ECA.py": 1614848426.9604986, + "components\\ECA_Depthwise_Conv.py": 1645265754.2023985, + "components\\Generator_eca_depthwise.py": 1645266338.9750814, + "losses\\KA.py": 1646388425.4841197, + "train_scripts\\trainer_distillation_mgpu.py": 1645601961.4139585, + "train_yamls\\train_distillation.yaml": 1645600099.540936, + "annotation.py": 1648654581.017103, + "components\\DeConv_ECA_Invo.py": 1645869347.379311, + "components\\DeConv_Invobn.py": 1645862876.018001, + "components\\Generator_Invobn_config.py": 1645929418.6924264, + "components\\Generator_Invobn_config1.py": 1645862695.8743145, + "components\\misc\\Involution_BN.py": 1645867197.3984175, + "components\\misc\\Involution_ECA.py": 1645869012.4927464, + "train_yamls\\train_Invobn_config.yaml": 1646101598.499709, + "components\\Generator_Invobn_config2.py": 1645962618.7056074, + "components\\Generator_Invobn_config3.py": 1646302561.1984286, + "components\\Generator_ori_modulation_config.py": 1646329636.719998, + "test_scripts\\tester_image_allstep.py": 1646312637.9363256, + "train_yamls\\train_ori_modulation_config.yaml": 1646330406.200162, + "test_arcface.py": 1647448497.69041, + "arcface_torch\\dataset.py": 1647445446.261035, + "arcface_torch\\eval_ijbc.py": 1647445446.2630043, + "arcface_torch\\inference.py": 1647445446.2630043, + "arcface_torch\\losses.py": 1647445446.2630043, + "arcface_torch\\lr_scheduler.py": 1647445446.2630043, + "arcface_torch\\onnx_helper.py": 1647445446.2630043, + "arcface_torch\\onnx_ijbc.py": 1647445446.2640254, + "arcface_torch\\partial_fc.py": 1647445446.2640254, + "arcface_torch\\torch2onnx.py": 1647445446.2640254, + "arcface_torch\\train.py": 1647445446.2649992, + 
"arcface_torch\\backbones\\iresnet.py": 1647445446.2580183, + "arcface_torch\\backbones\\iresnet2060.py": 1647445446.2580183, + "arcface_torch\\backbones\\mobilefacenet.py": 1647445446.2580183, + "arcface_torch\\backbones\\__init__.py": 1647445446.25702, + "arcface_torch\\configs\\3millions.py": 1647445446.2580183, + "arcface_torch\\configs\\base.py": 1647445446.259039, + "arcface_torch\\configs\\glint360k_mobileface_lr02_bs4k.py": 1647445446.259039, + "arcface_torch\\configs\\glint360k_r100_lr02_bs4k_16gpus.py": 1647445446.259039, + "arcface_torch\\configs\\ms1mv3_mobileface_lr02.py": 1647445446.259039, + "arcface_torch\\configs\\ms1mv3_r100_lr02.py": 1647445446.259039, + "arcface_torch\\configs\\ms1mv3_r50_lr02.py": 1647445446.260039, + "arcface_torch\\configs\\webface42m_mobilefacenet_pfc02_bs8k_16gpus.py": 1647445446.260039, + "arcface_torch\\configs\\webface42m_r100_lr01_pfc02_bs4k_16gpus.py": 1647445446.260039, + "arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_32gpus.py": 1647445446.260039, + "arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_8gpus.py": 1647445446.260039, + "arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs8k_16gpus.py": 1647445446.260039, + "arcface_torch\\configs\\__init__.py": 1647445446.2580183, + "arcface_torch\\eval\\verification.py": 1647445446.2620306, + "arcface_torch\\eval\\__init__.py": 1647445446.2620306, + "arcface_torch\\utils\\plot.py": 1647445446.2649992, + "arcface_torch\\utils\\utils_callbacks.py": 1647445446.2649992, + "arcface_torch\\utils\\utils_config.py": 1647445446.2649992, + "arcface_torch\\utils\\utils_logging.py": 1647445446.2659965, + "arcface_torch\\utils\\__init__.py": 1647445446.2649992, + "components\\LSTU.py": 1648482475.4378786, + "test_scripts\\tester_ID_Pose.py": 1646558809.85301, + "train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1646391740.1106014, + "train_scripts\\trainer_multi_gpu_CUT.py": 1647769120.5968685, + "train_scripts\\trainer_multi_gpu_cycle.py": 1648313934.2140906, + "components\\Generator_LSTU_config.py": 1648028831.4087331, + "components\\Generator_Res_config.py": 1648053382.053794, + "train_yamls\\train_cycleloss_res.yaml": 1648103614.4515965, + "clear_dataset.py": 1648920044.7807672, + "id_cos.py": 1648569510.5593822, + "translation_list2json.py": 1648106406.478007, + "components\\Generator_ResSkip_config.py": 1648526573.2056787, + "components\\Generator_Res_config1.py": 1648092270.7609806, + "components\\Generator_Res_config2.py": 1648103885.6257715, + "test_scripts\\tester_image_list.py": 1648145245.0948818, + "test_scripts\\tester_image_nofusion.py": 1648096849.0402405, + "train_yamls\\train_cycleloss_resskip.yaml": 1648313962.968481, + "check_list.txt": 1648657338.5051336, + "test_imgs_list.txt": 1649574560.1498592, + "vggface2hq_failed.txt": 1648926017.999394, + "arcface_torch\\requirement.txt": 1647445446.2640254, + "wandb\\run-20220129_032741-340btp9k\\files\\requirements.txt": 1643398065.409959, + "wandb\\run-20220129_032939-2nmaozxq\\files\\requirements.txt": 1643398182.647548, + "wandb\\run-20220129_033051-21z19tyg\\files\\requirements.txt": 1643398254.926299, + "wandb\\run-20220129_033202-16la4gpu\\files\\requirements.txt": 1643398325.8784783, + "wandb\\run-20220129_034327-1bmseytq\\files\\requirements.txt": 1643399010.865907, + "wandb\\run-20220129_034859-2puk6sph\\files\\requirements.txt": 1643399343.3508356, + "wandb\\run-20220129_035624-3hmwgcgw\\files\\requirements.txt": 1643399787.8869605, + "components\\Generator_featout_config.py": 1648877243.4813964, + 
"components\\Generator_ResSkip_config1.py": 1648530486.7970698, + "components\\LSTU_Config.py": 1648528200.7229428, + "components\\Nonstau_Discriminator.py": 1648476236.8430562, + "components\\Nonstau_Discriminator_FM.py": 1649833901.6153808, + "metrics\\equivariance.py": 1640773190.0, + "metrics\\frechet_inception_distance.py": 1640773190.0, + "metrics\\inception_score.py": 1640773190.0, + "metrics\\kernel_inception_distance.py": 1640773190.0, + "metrics\\metric_main.py": 1640773190.0, + "metrics\\metric_utils.py": 1640773190.0, + "metrics\\perceptual_path_length.py": 1640773190.0, + "metrics\\precision_recall.py": 1640773190.0, + "test_scripts\\tester_image_list_w_mask.py": 1649074097.8263073, + "test_scripts\\tester_image_w_mask.py": 1648529828.1194224, + "train_scripts\\trainer_mgpu_fm.py": 1648878513.5066502, + "train_scripts\\trainer_multi_gpu_cycle_nonstatue_dis.py": 1648559801.6695006, + "train_yamls\\train_cycleloss_fm_nonstatu.yaml": 1648572102.661056, + "train_yamls\\train_cycleloss_resskip_nonstatu.yaml": 1648527833.9132054, + "components\\Generator_Res_config3.py": 1648628068.1252878, + "data_tools\\data_loader_FFHQ_multigpu.py": 1648640139.408, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\align_and_crop_dir.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\generate_mask.py": 1648652827.224725, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\test_enhance_dir_align.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\test_enhance_dir_unalign.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\test_enhance_single_unalign.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\train.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\base_dataset.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\celebahqmask_dataset.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\ffhq_dataset.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\image_folder.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\single_dataset.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\__init__.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\base_model.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\blocks.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\enhance_model.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\loss.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\networks.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\parse_model.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\psfrnet.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\__init__.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\base_options.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\test_options.py": 1648647365.0084722, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\train_options.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\__init__.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\utils\\logger.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\utils\\timer.py": 1640774152.0, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\utils\\utils.py": 1648653194.9661705, + "face_parse\\PSFRGAN-master\\PSFRGAN-master\\requirements.txt": 1640774152.0, 
+ "face_parse\\PSFRGAN-master\\PSFRGAN-master\\check_points\\experiment_name\\test_opt.txt": 1648653211.0647676, + "components\\Generator_maskhead_config.py": 1648919192.4128494, + "data_tools\\data_loader_VGGFace2HQ_multigpu_w_mask.py": 1648950658.8917239, + "train_scripts\\trainer_mgpu_fm_w_mask.py": 1649842622.1032736, + "train_scripts\\trainer_mgpu_maskloss.py": 1650003673.00616, + "train_yamls\\train_maskhead_fm_nonstatu.yaml": 1648918857.2325976, + "train_yamls\\train_maskhead_maskloss.yaml": 1648919545.869737, + "dataset.check.py": 1648925868.9100032, + "components\\Generator_involution_maskhead_config.py": 1648919192.4128494, + "components\\Generator_VGGStyle_maskhead_config.py": 1649177890.0438528, + "train_yamls\\train_maskhead_fm_vggstyle.yaml": 1649177995.515654, + "train_yamls\\train_maskloss.yaml": 1648920140.455, + "components\\Generator_maskhead_config1.py": 1649248730.7064345, + "train_yamls\\train_maskhead_hififace.yaml": 1649345218.4856737, + "filter.py": 1649836163.761168, + "test_json.py": 1649232568.5223434, + "components\\Generator_2maskhead_config copy.py": 1649816572.443475, + "components\\Generator_2maskhead_config.py": 1649816572.443475, + "components\\Generator_maskhead_config2.py": 1649833973.4127545, + "components\\Generator_starganv2.py": 1649845840.4322417, + "face_enhancer\\gfpgan\\train.py": 1644942770.0, + "face_enhancer\\gfpgan\\utils.py": 1644942770.0, + "face_enhancer\\gfpgan\\version.py": 1648618484.3356814, + "face_enhancer\\gfpgan\\__init__.py": 1644942770.0, + "face_enhancer\\gfpgan\\archs\\arcface_arch.py": 1644942770.0, + "face_enhancer\\gfpgan\\archs\\gfpganv1_arch.py": 1644942770.0, + "face_enhancer\\gfpgan\\archs\\gfpganv1_clean_arch.py": 1644942770.0, + "face_enhancer\\gfpgan\\archs\\gfpgan_bilinear_arch.py": 1644942770.0, + "face_enhancer\\gfpgan\\archs\\stylegan2_bilinear_arch.py": 1644942770.0, + "face_enhancer\\gfpgan\\archs\\stylegan2_clean_arch.py": 1644942770.0, + "face_enhancer\\gfpgan\\archs\\__init__.py": 1644942770.0, + "face_enhancer\\gfpgan\\data\\ffhq_degradation_dataset.py": 1644942770.0, + "face_enhancer\\gfpgan\\data\\__init__.py": 1644942770.0, + "face_enhancer\\gfpgan\\models\\gfpgan_model.py": 1644942770.0, + "face_enhancer\\gfpgan\\models\\__init__.py": 1644942770.0, + "face_enhancer\\scripts\\convert_gfpganv_to_clean.py": 1644942770.0, + "face_enhancer\\scripts\\parse_landmark.py": 1644942770.0, + "test_scripts\\tester_image_list_w_2mask.py": 1649768648.90081, + "test_scripts\\tester_image_w_2mask.py": 1649729361.2799067, + "test_scripts\\tester_image_w_mask_gfpgan.py": 1649872098.8747168, + "test_scripts\\tester_video_gfpgan.py": 1649907748.9359775, + "train_scripts\\trainer_mgpu_2maskloss.py": 1649699504.6697042, + "train_yamls\\train_1maskhead.yaml": 1650003571.40854, + "train_yamls\\train_2maskhead.yaml": 1649953481.822017, + "train_yamls\\train_maskhead_hififace1.yaml": 1649471700.4954374, + "components\\Generator_2mask.py": 1649953828.8860412 } \ No newline at end of file diff --git a/GUI/guiignore.json b/GUI/guiignore.json index 7b212a2..43b2f1c 100644 --- a/GUI/guiignore.json +++ b/GUI/guiignore.json @@ -2,7 +2,8 @@ "white_list": { "extension": [ "py", - "yaml" + "yaml", + "txt" ], "file": [], "path": [] diff --git a/GUI/machines.json b/GUI/machines.json index b6bad2a..313a9a2 100644 --- a/GUI/machines.json +++ b/GUI/machines.json @@ -8,6 +8,15 @@ "ckp_path": "train_logs", "logfilename": "filestate_machine0.json" }, + { + "ip": "119.29.91.52", + "user": "ubuntu", + "port": 22, + "passwd": "zpKlOW0sMlyt!xhE", + 
"path": "/home/ubuntu/CXH/simswap_plus", + "ckp_path": "train_logs", + "logfilename": "filestate_machine3.json" + }, { "ip": "2001:da8:8000:6880:f284:d61c:3c76:f9cb", "user": "ps", diff --git a/README.md b/README.md index fde62be..52d0739 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Simswap++ ## Dependencies +- moviepy - python >= 3.7 - yaml (pip install pyyaml) - paramiko (For ssh file transportation) diff --git a/annotation.py b/annotation.py index d4e5a36..81b0c9d 100644 --- a/annotation.py +++ b/annotation.py @@ -5,7 +5,7 @@ # Created Date: Saturday February 26th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Sunday, 27th February 2022 11:03:58 am +# Last Modified: Wednesday, 30th March 2022 11:36:20 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -33,7 +33,7 @@ def str2bool(v): def getParameters(): parser = argparse.ArgumentParser() # general - parser.add_argument('--image_dir', type=str, default="G:\\VGGFace2-HQ\\VGGface2_None_norm_512_true_bygfpgan") + parser.add_argument('--image_dir', type=str, default="G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan") parser.add_argument('--savetxt', type=str, default="./check_list.txt") parser.add_argument('--winWidth', type=int, default=512) parser.add_argument('--winHeight', type=int, default=512) diff --git a/breakpoint.json b/breakpoint.json index fa8f0bf..6726984 100644 --- a/breakpoint.json +++ b/breakpoint.json @@ -1,6 +1,6 @@ { "breakpoint": [ - 31, - 110 + 54, + 101 ] } \ No newline at end of file diff --git a/check_list.txt b/check_list.txt index 8ea0a9d..54858f9 100644 --- a/check_list.txt +++ b/check_list.txt @@ -6984,3 +6984,1281 @@ n000035\0159_02.jpg n000035\0167_01.jpg n000035\0170_01.jpg n000035\0200_01.jpg +n000002\0013_01.jpg +n000002\0018_01.jpg +n000002\0023_01.jpg +n000002\0027_01.jpg +n000002\0031_06.jpg +n000002\0031_08.jpg +n000002\0042_01.jpg +n000002\0058_01.jpg +n000002\0068_01.jpg +n000002\0075_01.jpg +n000002\0078_01.jpg +n000002\0094_01.jpg +n000002\0095_01.jpg +n000002\0110_03.jpg +n000002\0125_01.jpg +n000002\0141_01.jpg +n000002\0142_01.jpg +n000002\0152_02.jpg +n000002\0170_01.jpg +n000002\0171_01.jpg +n000002\0179_01.jpg +n000002\0180_01.jpg +n000002\0184_01.jpg +n000002\0193_01.jpg +n000002\0197_01.jpg +n000002\0199_01.jpg +n000002\0201_01.jpg +n000002\0209_01.jpg +n000002\0210_01.jpg +n000002\0216_01.jpg +n000002\0217_01.jpg +n000002\0218_01.jpg +n000002\0227_02.jpg +n000002\0231_01.jpg +n000002\0233_01.jpg +n000002\0237_01.jpg +n000002\0240_01.jpg +n000002\0239_01.jpg +n000002\0245_01.jpg +n000002\0249_01.jpg +n000002\0257_01.jpg +n000002\0259_01.jpg +n000002\0261_01.jpg +n000002\0262_01.jpg +n000002\0265_01.jpg +n000002\0268_01.jpg +n000002\0270_01.jpg +n000002\0275_01.jpg +n000002\0277_01.jpg +n000002\0279_01.jpg +n000002\0284_02.jpg +n000002\0298_01.jpg +n000002\0304_01.jpg +n000002\0305_01.jpg +n000002\0311_01.jpg +n000002\0312_01.jpg +n000002\0316_01.jpg +n000002\0317_01.jpg +n000002\0321_01.jpg +n000002\0323_01.jpg +n000003\0006_01.jpg +n000003\0010_01.jpg +n000003\0011_02.jpg +n000003\0013_01.jpg +n000003\0013_01.jpg +n000003\0021_01.jpg +n000003\0026_01.jpg +n000003\0027_02.jpg +n000003\0036_01.jpg +n000003\0038_01.jpg +n000003\0044_02.jpg +n000003\0054_01.jpg +n000003\0055_01.jpg +n000003\0064_02.jpg +n000003\0073_01.jpg +n000003\0074_02.jpg +n000003\0083_01.jpg +n000003\0085_01.jpg +n000003\0086_01.jpg +n000003\0099_03.jpg +n000003\0097_01.jpg 
+n000003\0100_03.jpg +n000003\0101_01.jpg +n000003\0102_01.jpg +n000003\0103_01.jpg +n000003\0104_01.jpg +n000003\0108_01.jpg +n000003\0115_01.jpg +n000003\0116_02.jpg +n000003\0118_01.jpg +n000003\0120_01.jpg +n000003\0122_06.jpg +n000003\0124_03.jpg +n000003\0125_01.jpg +n000003\0129_01.jpg +n000003\0130_01.jpg +n000003\0131_01.jpg +n000003\0133_01.jpg +n000003\0136_01.jpg +n000003\0137_01.jpg +n000003\0143_01.jpg +n000003\0144_01.jpg +n000003\0149_01.jpg +n000003\0155_01.jpg +n000003\0157_02.jpg +n000003\0161_01.jpg +n000003\0162_01.jpg +n000003\0163_02.jpg +n000003\0164_02.jpg +n000003\0165_01.jpg +n000003\0167_02.jpg +n000003\0168_02.jpg +n000003\0170_01.jpg +n000003\0172_02.jpg +n000003\0173_01.jpg +n000003\0177_02.jpg +n000003\0181_01.jpg +n000003\0183_02.jpg +n000003\0200_02.jpg +n000003\0201_01.jpg +n000003\0202_01.jpg +n000003\0206_01.jpg +n000003\0207_02.jpg +n000003\0222_01.jpg +n000003\0226_01.jpg +n000003\0240_01.jpg +n000003\0241_01.jpg +n000003\0244_02.jpg +n000003\0245_01.jpg +n000003\0246_01.jpg +n000003\0249_01.jpg +n000003\0253_03.jpg +n000004\0018_01.jpg +n000004\0040_01.jpg +n000004\0041_01.jpg +n000004\0057_03.jpg +n000004\0060_01.jpg +n000004\0073_01.jpg +n000004\0090_01.jpg +n000004\0097_01.jpg +n000004\0124_01.jpg +n000004\0131_01.jpg +n000004\0165_01.jpg +n000004\0171_03.jpg +n000004\0175_01.jpg +n000004\0178_01.jpg +n000004\0182_02.jpg +n000004\0184_02.jpg +n000004\0224_02.jpg +n000004\0225_01.jpg +n000004\0228_01.jpg +n000004\0235_01.jpg +n000004\0241_01.jpg +n000004\0243_01.jpg +n000004\0248_01.jpg +n000004\0252_01.jpg +n000004\0251_01.jpg +n000004\0253_02.jpg +n000004\0255_02.jpg +n000004\0260_04.jpg +n000004\0268_02.jpg +n000004\0272_01.jpg +n000004\0274_01.jpg +n000004\0276_01.jpg +n000004\0277_01.jpg +n000004\0279_01.jpg +n000004\0290_02.jpg +n000004\0296_02.jpg +n000004\0315_01.jpg +n000004\0324_02.jpg +n000004\0328_01.jpg +n000004\0334_01.jpg +n000004\0340_01.jpg +n000004\0343_01.jpg +n000004\0350_01.jpg +n000004\0354_01.jpg +n000004\0391_01.jpg +n000004\0393_01.jpg +n000004\0396_01.jpg +n000004\0402_01.jpg +n000004\0420_01.jpg +n000005\0025_01.jpg +n000005\0045_01.jpg +n000005\0052_01.jpg +n000005\0063_01.jpg +n000005\0078_01.jpg +n000005\0080_01.jpg +n000005\0087_01.jpg +n000005\0101_01.jpg +n000005\0102_01.jpg +n000005\0104_01.jpg +n000005\0105_01.jpg +n000005\0106_01.jpg +n000005\0108_01.jpg +n000005\0117_01.jpg +n000005\0124_01.jpg +n000005\0130_01.jpg +n000005\0136_01.jpg +n000005\0142_01.jpg +n000005\0143_01.jpg +n000005\0146_01.jpg +n000005\0148_01.jpg +n000005\0150_02.jpg +n000005\0156_01.jpg +n000005\0160_02.jpg +n000005\0163_02.jpg +n000005\0164_01.jpg +n000005\0165_01.jpg +n000005\0167_02.jpg +n000005\0174_01.jpg +n000005\0175_01.jpg +n000005\0180_01.jpg +n000005\0181_02.jpg +n000005\0182_01.jpg +n000005\0185_01.jpg +n000005\0190_01.jpg +n000005\0192_01.jpg +n000005\0194_03.jpg +n000005\0195_01.jpg +n000005\0197_02.jpg +n000005\0203_01.jpg +n000005\0205_01.jpg +n000005\0210_02.jpg +n000005\0213_01.jpg +n000005\0219_01.jpg +n000005\0220_01.jpg +n000005\0221_01.jpg +n000005\0222_02.jpg +n000005\0226_01.jpg +n000005\0229_01.jpg +n000005\0233_01.jpg +n000005\0241_01.jpg +n000005\0284_02.jpg +n000005\0306_01.jpg +n000005\0350_01.jpg +n000005\0406_01.jpg +n000005\0413_01.jpg +n000005\0424_01.jpg +n000005\0430_02.jpg +n000005\0431_01.jpg +n000006\0001_01.jpg +n000006\0004_04.jpg +n000006\0051_01.jpg +n000006\0101_01.jpg +n000006\0146_01.jpg +n000006\0156_01.jpg +n000006\0165_02.jpg +n000006\0174_01.jpg +n000006\0183_01.jpg +n000006\0185_02.jpg 
+n000006\0187_03.jpg +n000006\0187_04.jpg +n000006\0189_01.jpg +n000006\0198_01.jpg +n000006\0206_01.jpg +n000006\0225_01.jpg +n000006\0231_01.jpg +n000006\0235_01.jpg +n000006\0242_01.jpg +n000006\0248_01.jpg +n000006\0249_01.jpg +n000006\0252_01.jpg +n000006\0257_01.jpg +n000006\0258_03.jpg +n000006\0262_01.jpg +n000006\0264_01.jpg +n000006\0268_01.jpg +n000006\0275_04.jpg +n000006\0279_01.jpg +n000006\0283_01.jpg +n000006\0284_01.jpg +n000006\0314_01.jpg +n000006\0316_01.jpg +n000006\0319_01.jpg +n000006\0323_01.jpg +n000006\0324_01.jpg +n000006\0325_01.jpg +n000006\0326_01.jpg +n000006\0328_02.jpg +n000006\0329_01.jpg +n000006\0332_01.jpg +n000006\0333_01.jpg +n000006\0334_01.jpg +n000006\0335_01.jpg +n000006\0336_01.jpg +n000006\0337_01.jpg +n000006\0338_01.jpg +n000006\0341_01.jpg +n000006\0347_01.jpg +n000006\0349_02.jpg +n000006\0284_01.jpg +n000006\0283_01.jpg +n000006\0314_01.jpg +n000006\0315_01.jpg +n000006\0316_01.jpg +n000006\0319_01.jpg +n000006\0324_01.jpg +n000006\0333_01.jpg +n000006\0332_01.jpg +n000006\0334_01.jpg +n000006\0335_01.jpg +n000006\0336_01.jpg +n000006\0340_01.jpg +n000006\0341_01.jpg +n000006\0343_01.jpg +n000006\0349_02.jpg +n000006\0350_01.jpg +n000006\0352_01.jpg +n000006\0353_01.jpg +n000006\0354_01.jpg +n000006\0356_01.jpg +n000006\0358_01.jpg +n000006\0359_03.jpg +n000006\0360_01.jpg +n000006\0361_01.jpg +n000006\0362_01.jpg +n000006\0363_01.jpg +n000006\0367_01.jpg +n000006\0368_02.jpg +n000006\0369_01.jpg +n000006\0372_01.jpg +n000006\0374_01.jpg +n000006\0377_01.jpg +n000006\0380_01.jpg +n000006\0384_01.jpg +n000006\0388_01.jpg +n000006\0389_01.jpg +n000006\0396_01.jpg +n000006\0397_01.jpg +n000006\0399_04.jpg +n000006\0400_04.jpg +n000006\0404_01.jpg +n000006\0406_01.jpg +n000006\0411_05.jpg +n000006\0413_01.jpg +n000006\0418_01.jpg +n000006\0419_01.jpg +n000006\0420_01.jpg +n000006\0426_01.jpg +n000006\0432_01.jpg +n000006\0433_03.jpg +n000006\0457_03.jpg +n000006\0467_01.jpg +n000006\0475_01.jpg +n000006\0480_02.jpg +n000006\0521_01.jpg +n000006\0523_02.jpg +n000006\0524_01.jpg +n000006\0526_01.jpg +n000006\0528_01.jpg +n000006\0532_01.jpg +n000006\0533_01.jpg +n000006\0536_01.jpg +n000006\0538_01.jpg +n000006\0542_01.jpg +n000006\0543_01.jpg +n000006\0544_02.jpg +n000006\0545_02.jpg +n000006\0548_03.jpg +n000006\0549_02.jpg +n000006\0552_01.jpg +n000006\0554_01.jpg +n000007\0002_01.jpg +n000007\0006_02.jpg +n000007\0007_01.jpg +n000007\0011_01.jpg +n000007\0012_01.jpg +n000007\0017_01.jpg +n000007\0018_01.jpg +n000007\0022_02.jpg +n000007\0023_01.jpg +n000007\0028_01.jpg +n000007\0033_02.jpg +n000007\0039_01.jpg +n000007\0040_02.jpg +n000007\0044_01.jpg +n000007\0052_01.jpg +n000007\0053_01.jpg +n000007\0054_02.jpg +n000007\0055_01.jpg +n000007\0057_01.jpg +n000007\0058_02.jpg +n000007\0060_01.jpg +n000007\0070_01.jpg +n000007\0071_01.jpg +n000007\0081_01.jpg +n000007\0085_03.jpg +n000007\0096_01.jpg +n000007\0099_01.jpg +n000007\0113_02.jpg +n000007\0118_04.jpg +n000007\0121_01.jpg +n000007\0122_01.jpg +n000007\0123_01.jpg +n000007\0124_01.jpg +n000007\0138_03.jpg +n000007\0141_03.jpg +n000007\0142_02.jpg +n000007\0145_04.jpg +n000007\0146_02.jpg +n000007\0151_03.jpg +n000007\0152_01.jpg +n000007\0153_01.jpg +n000007\0160_03.jpg +n000007\0160_04.jpg +n000007\0160_05.jpg +n000007\0160_05.jpg +n000007\0165_02.jpg +n000007\0166_01.jpg +n000007\0168_01.jpg +n000007\0169_01.jpg +n000007\0170_01.jpg +n000007\0171_02.jpg +n000007\0171_04.jpg +n000007\0172_01.jpg +n000007\0175_01.jpg +n000007\0176_01.jpg +n000007\0177_04.jpg +n000007\0185_01.jpg 
+n000007\0187_01.jpg +n000007\0188_01.jpg +n000007\0189_01.jpg +n000007\0195_01.jpg +n000007\0196_01.jpg +n000007\0197_02.jpg +n000007\0198_03.jpg +n000007\0200_02.jpg +n000007\0201_01.jpg +n000007\0205_01.jpg +n000007\0208_01.jpg +n000007\0209_01.jpg +n000007\0215_02.jpg +n000007\0218_01.jpg +n000007\0221_02.jpg +n000007\0227_03.jpg +n000007\0233_02.jpg +n000007\0239_01.jpg +n000007\0241_01.jpg +n000007\0246_01.jpg +n000007\0246_02.jpg +n000007\0247_01.jpg +n000007\0271_02.jpg +n000007\0280_01.jpg +n000007\0283_02.jpg +n000007\0311_02.jpg +n000007\0327_05.jpg +n000007\0379_02.jpg +n000007\0381_01.jpg +n000007\0391_01.jpg +n000007\0411_02.jpg +n000007\0419_01.jpg +n000007\0428_01.jpg +n000007\0430_01.jpg +n000008\0003_01.jpg +n000008\0005_02.jpg +n000008\0020_01.jpg +n000008\0079_01.jpg +n000008\0080_01.jpg +n000008\0091_01.jpg +n000008\0094_01.jpg +n000008\0095_01.jpg +n000008\0096_01.jpg +n000008\0098_01.jpg +n000008\0101_01.jpg +n000008\0102_01.jpg +n000008\0111_01.jpg +n000008\0112_01.jpg +n000008\0118_01.jpg +n000008\0121_01.jpg +n000008\0124_01.jpg +n000008\0127_01.jpg +n000008\0143_01.jpg +n000008\0153_01.jpg +n000008\0166_02.jpg +n000008\0174_01.jpg +n000008\0177_01.jpg +n000008\0193_01.jpg +n000008\0195_01.jpg +n000008\0196_01.jpg +n000008\0197_01.jpg +n000008\0199_01.jpg +n000008\0201_01.jpg +n000008\0204_01.jpg +n000008\0205_01.jpg +n000008\0207_01.jpg +n000008\0208_01.jpg +n000008\0212_02.jpg +n000008\0213_01.jpg +n000008\0218_02.jpg +n000008\0227_01.jpg +n000008\0239_01.jpg +n000008\0250_02.jpg +n000008\0251_01.jpg +n000008\0259_01.jpg +n000008\0276_01.jpg +n000008\0277_01.jpg +n000008\0278_01.jpg +n000008\0285_01.jpg +n000008\0288_03.jpg +n000008\0302_01.jpg +n000008\0303_01.jpg +n000008\0308_01.jpg +n000008\0309_01.jpg +n000008\0327_01.jpg +n000008\0347_01.jpg +n000010\0063_01.jpg +n000010\0079_10.jpg +n000010\0080_01.jpg +n000010\0085_01.jpg +n000010\0102_01.jpg +n000010\0105_05.jpg +n000010\0127_01.jpg +n000010\0128_01.jpg +n000010\0130_02.jpg +n000010\0130_04.jpg +n000010\0130_06.jpg +n000010\0131_02.jpg +n000010\0137_01.jpg +n000010\0138_01.jpg +n000010\0138_02.jpg +n000010\0147_01.jpg +n000010\0154_01.jpg +n000010\0157_02.jpg +n000010\0158_01.jpg +n000010\0166_01.jpg +n000010\0167_01.jpg +n000010\0214_01.jpg +n000010\0280_01.jpg +n000011\0021_01.jpg +n000011\0099_02.jpg +n000011\0114_02.jpg +n000011\0128_02.jpg +n000011\0166_01.jpg +n000011\0186_06.jpg +n000011\0210_01.jpg +n000011\0215_01.jpg +n000011\0219_01.jpg +n000011\0221_01.jpg +n000011\0224_02.jpg +n000011\0227_02.jpg +n000011\0228_01.jpg +n000011\0234_01.jpg +n000011\0238_01.jpg +n000011\0247_01.jpg +n000011\0248_01.jpg +n000011\0249_01.jpg +n000011\0267_01.jpg +n000011\0270_01.jpg +n000011\0271_02.jpg +n000011\0273_01.jpg +n000011\0279_02.jpg +n000011\0280_04.jpg +n000011\0281_01.jpg +n000011\0285_06.jpg +n000011\0293_01.jpg +n000011\0296_01.jpg +n000011\0299_01.jpg +n000011\0306_01.jpg +n000011\0306_07.jpg +n000011\0312_02.jpg +n000011\0313_01.jpg +n000011\0316_01.jpg +n000011\0317_02.jpg +n000011\0318_01.jpg +n000011\0324_01.jpg +n000011\0329_02.jpg +n000011\0334_02.jpg +n000011\0382_01.jpg +n000011\0385_01.jpg +n000011\0387_01.jpg +n000011\0397_01.jpg +n000011\0407_06.jpg +n000011\0408_01.jpg +n000011\0417_01.jpg +n000011\0424_01.jpg +n000011\0426_03.jpg +n000012\0012_02.jpg +n000012\0029_03.jpg +n000012\0032_01.jpg +n000012\0046_01.jpg +n000012\0056_02.jpg +n000012\0067_01.jpg +n000012\0068_01.jpg +n000012\0069_01.jpg +n000012\0076_01.jpg +n000012\0078_01.jpg +n000012\0100_02.jpg +n000012\0101_01.jpg 
+n000012\0102_01.jpg +n000012\0103_01.jpg +n000012\0109_01.jpg +n000012\0109_02.jpg +n000012\0109_03.jpg +n000012\0110_01.jpg +n000012\0112_01.jpg +n000012\0114_01.jpg +n000012\0116_01.jpg +n000012\0122_01.jpg +n000012\0141_01.jpg +n000012\0179_01.jpg +n000012\0181_01.jpg +n000012\0194_01.jpg +n000012\0208_01.jpg +n000012\0210_01.jpg +n000012\0210_02.jpg +n000012\0211_01.jpg +n000012\0243_01.jpg +n000012\0253_01.jpg +n000012\0254_01.jpg +n000012\0257_01.jpg +n000012\0263_03.jpg +n000012\0266_01.jpg +n000012\0273_02.jpg +n000012\0277_02.jpg +n000012\0279_01.jpg +n000012\0285_02.jpg +n000012\0288_01.jpg +n000012\0288_02.jpg +n000012\0289_01.jpg +n000012\0291_03.jpg +n000012\0299_01.jpg +n000012\0301_02.jpg +n000012\0304_01.jpg +n000012\0306_02.jpg +n000012\0309_03.jpg +n000012\0309_01.jpg +n000012\0315_02.jpg +n000012\0320_01.jpg +n000012\0320_02.jpg +n000012\0335_02.jpg +n000012\0340_01.jpg +n000012\0350_01.jpg +n000012\0350_02.jpg +n000012\0358_01.jpg +n000012\0360_01.jpg +n000012\0375_01.jpg +n000012\0406_01.jpg +n000012\0406_02.jpg +n000012\0410_01.jpg +n000012\0412_01.jpg +n000012\0414_02.jpg +n000012\0422_01.jpg +n000012\0426_01.jpg +n000012\0426_02.jpg +n000012\0430_01.jpg +n000013\0013_01.jpg +n000013\0014_01.jpg +n000013\0023_01.jpg +n000013\0029_04.jpg +n000013\0030_01.jpg +n000013\0041_01.jpg +n000013\0048_01.jpg +n000013\0057_01.jpg +n000013\0105_01.jpg +n000013\0106_01.jpg +n000013\0112_01.jpg +n000013\0117_01.jpg +n000013\0118_01.jpg +n000013\0123_03.jpg +n000013\0124_01.jpg +n000013\0127_01.jpg +n000013\0131_02.jpg +n000013\0131_03.jpg +n000013\0134_01.jpg +n000013\0141_01.jpg +n000013\0149_01.jpg +n000013\0157_01.jpg +n000013\0160_01.jpg +n000013\0163_01.jpg +n000013\0164_01.jpg +n000013\0165_01.jpg +n000013\0166_01.jpg +n000013\0168_01.jpg +n000013\0175_01.jpg +n000013\0176_02.jpg +n000013\0177_01.jpg +n000013\0181_02.jpg +n000013\0182_01.jpg +n000013\0186_01.jpg +n000013\0192_01.jpg +n000013\0193_02.jpg +n000013\0193_04.jpg +n000013\0196_01.jpg +n000013\0198_01.jpg +n000013\0201_01.jpg +n000013\0203_01.jpg +n000013\0204_01.jpg +n000013\0205_01.jpg +n000013\0209_01.jpg +n000013\0210_01.jpg +n000013\0211_01.jpg +n000013\0212_01.jpg +n000013\0213_01.jpg +n000013\0215_01.jpg +n000013\0220_01.jpg +n000013\0227_01.jpg +n000013\0230_01.jpg +n000013\0233_01.jpg +n000013\0237_01.jpg +n000013\0236_01.jpg +n000013\0238_01.jpg +n000013\0242_01.jpg +n000013\0245_01.jpg +n000013\0246_03.jpg +n000013\0247_01.jpg +n000013\0248_01.jpg +n000013\0249_02.jpg +n000013\0252_01.jpg +n000013\0253_01.jpg +n000013\0254_01.jpg +n000013\0258_01.jpg +n000013\0259_01.jpg +n000013\0261_01.jpg +n000013\0266_01.jpg +n000013\0268_02.jpg +n000013\0273_01.jpg +n000013\0274_02.jpg +n000013\0283_01.jpg +n000013\0293_01.jpg +n000013\0305_02.jpg +n000013\0316_01.jpg +n000013\0320_01.jpg +n000013\0323_01.jpg +n000013\0330_01.jpg +n000013\0331_01.jpg +n000013\0332_01.jpg +n000013\0340_01.jpg +n000014\0049_01.jpg +n000014\0067_06.jpg +n000014\0130_08.jpg +n000014\0130_09.jpg +n000014\0130_10.jpg +n000014\0130_12.jpg +n000014\0130_13.jpg +n000014\0130_14.jpg +n000014\0130_15.jpg +n000014\0130_19.jpg +n000014\0130_20.jpg +n000014\0130_21.jpg +n000014\0130_22.jpg +n000014\0130_25.jpg +n000014\0130_28.jpg +n000014\0130_30.jpg +n000014\0130_31.jpg +n000014\0130_32.jpg +n000014\0130_33.jpg +n000014\0130_34.jpg +n000014\0130_35.jpg +n000014\0132_01.jpg +n000014\0134_01.jpg +n000014\0158_01.jpg +n000014\0177_01.jpg +n000014\0200_01.jpg +n000014\0201_01.jpg +n000014\0203_01.jpg +n000014\0206_01.jpg +n000014\0208_01.jpg 
+n000014\0209_01.jpg +n000014\0213_01.jpg +n000014\0214_01.jpg +n000014\0215_01.jpg +n000014\0216_01.jpg +n000014\0217_01.jpg +n000014\0222_01.jpg +n000014\0232_02.jpg +n000014\0233_01.jpg +n000014\0244_02.jpg +n000014\0255_01.jpg +n000014\0289_02.jpg +n000014\0283_01.jpg +n000015\0020_01.jpg +n000015\0021_01.jpg +n000015\0023_01.jpg +n000015\0031_01.jpg +n000015\0034_02.jpg +n000015\0040_01.jpg +n000015\0050_01.jpg +n000015\0050_02.jpg +n000015\0052_02.jpg +n000015\0055_01.jpg +n000015\0056_01.jpg +n000015\0066_01.jpg +n000015\0067_01.jpg +n000015\0068_01.jpg +n000015\0075_01.jpg +n000015\0076_01.jpg +n000015\0078_02.jpg +n000015\0081_02.jpg +n000015\0087_03.jpg +n000015\0088_01.jpg +n000015\0096_01.jpg +n000015\0100_01.jpg +n000015\0101_04.jpg +n000015\0102_01.jpg +n000015\0103_04.jpg +n000015\0104_03.jpg +n000015\0110_01.jpg +n000015\0111_01.jpg +n000015\0112_01.jpg +n000015\0113_03.jpg +n000015\0115_01.jpg +n000015\0116_01.jpg +n000015\0117_01.jpg +n000015\0118_01.jpg +n000015\0119_03.jpg +n000015\0122_01.jpg +n000015\0123_01.jpg +n000015\0126_01.jpg +n000015\0130_01.jpg +n000015\0131_01.jpg +n000015\0134_01.jpg +n000015\0138_02.jpg +n000015\0139_02.jpg +n000015\0140_01.jpg +n000015\0142_01.jpg +n000015\0147_01.jpg +n000015\0151_01.jpg +n000015\0153_02.jpg +n000015\0155_03.jpg +n000015\0161_01.jpg +n000015\0163_03.jpg +n000015\0167_04.jpg +n000015\0169_01.jpg +n000015\0173_01.jpg +n000015\0174_05.jpg +n000015\0175_03.jpg +n000015\0181_03.jpg +n000015\0185_02.jpg +n000015\0186_01.jpg +n000015\0190_02.jpg +n000015\0192_02.jpg +n000015\0194_01.jpg +n000015\0201_01.jpg +n000015\0201_03.jpg +n000015\0206_01.jpg +n000015\0288_03.jpg +n000015\0314_01.jpg +n000015\0344_06.jpg +n000015\0356_01.jpg +n000015\0372_01.jpg +n000015\0393_04.jpg +n000015\0391_01.jpg +n000015\0395_01.jpg +n000015\0415_01.jpg +n000015\0434_02.jpg +n000015\0438_01.jpg +n000015\0438_02.jpg +n000017\0036_01.jpg +n000017\0047_01.jpg +n000017\0236_01.jpg +n000017\0237_01.jpg +n000017\0262_01.jpg +n000017\0269_01.jpg +n000018\0108_01.jpg +n000018\0173_01.jpg +n000018\0206_02.jpg +n000018\0304_01.jpg +n000019\0085_01.jpg +n000019\0089_01.jpg +n000019\0106_03.jpg +n000019\0170_01.jpg +n000019\0234_02.jpg +n000019\0249_01.jpg +n000019\0273_01.jpg +n000019\0275_01.jpg +n000019\0276_01.jpg +n000019\0306_01.jpg +n000019\0309_01.jpg +n000019\0313_01.jpg +n000019\0328_01.jpg +n000019\0331_01.jpg +n000019\0333_01.jpg +n000019\0334_01.jpg +n000019\0337_01.jpg +n000019\0347_01.jpg +n000019\0350_02.jpg +n000020\0243_01.jpg +n000020\0290_01.jpg +n000020\0334_01.jpg +n000020\0400_01.jpg +n000020\0384_01.jpg +n000020\0409_01.jpg +n000020\0418_01.jpg +n000021\0046_01.jpg +n000021\0052_01.jpg +n000021\0087_01.jpg +n000021\0117_01.jpg +n000021\0143_01.jpg +n000021\0184_01.jpg +n000022\0347_01.jpg +n000022\0415_01.jpg +n000023\0008_01.jpg +n000023\0012_01.jpg +n000023\0156_01.jpg +n000023\0162_01.jpg +n000023\0198_01.jpg +n000023\0207_03.jpg +n000023\0256_01.jpg +n000023\0257_01.jpg +n000023\0269_02.jpg +n000023\0285_01.jpg +n000023\0280_01.jpg +n000023\0294_01.jpg +n000023\0319_01.jpg +n000023\0343_01.jpg +n000023\0352_01.jpg +n000023\0359_01.jpg +n000023\0366_01.jpg +n000023\0389_01.jpg +n000024\0046_02.jpg +n000024\0056_01.jpg +n000024\0188_01.jpg +n000024\0258_01.jpg +n000024\0311_01.jpg +n000024\0325_01.jpg +n000024\0327_01.jpg +n000025\0245_01.jpg +n000026\0038_01.jpg +n000026\0060_01.jpg +n000026\0075_01.jpg +n000026\0078_01.jpg +n000026\0082_02.jpg +n000026\0103_01.jpg +n000026\0104_01.jpg +n000026\0125_01.jpg +n000026\0137_01.jpg 
+n000026\0196_01.jpg +n000026\0280_01.jpg +n000027\0023_02.jpg +n000027\0023_05.jpg +n000027\0097_01.jpg +n000027\0099_01.jpg +n000027\0108_02.jpg +n000027\0115_01.jpg +n000027\0157_02.jpg +n000027\0171_01.jpg +n000027\0182_02.jpg +n000027\0211_02.jpg +n000027\0255_01.jpg +n000027\0256_03.jpg +n000027\0257_01.jpg +n000027\0274_04.jpg +n000027\0318_04.jpg +n000027\0401_01.jpg +n000027\0402_01.jpg +n000027\0438_01.jpg +n000027\0438_02.jpg +n000027\0440_01.jpg +n000027\0442_01.jpg +n000027\0443_01.jpg +n000027\0446_01.jpg +n000027\0456_01.jpg +n000027\0458_01.jpg +n000027\0469_02.jpg +n000027\0493_01.jpg +n000028\0040_04.jpg +n000028\0044_01.jpg +n000028\0056_01.jpg +n000028\0080_01.jpg +n000028\0083_01.jpg +n000028\0088_01.jpg +n000028\0113_01.jpg +n000028\0120_01.jpg +n000028\0138_01.jpg +n000028\0140_02.jpg +n000028\0141_02.jpg +n000028\0144_02.jpg +n000028\0147_01.jpg +n000028\0149_01.jpg +n000028\0156_01.jpg +n000028\0155_01.jpg +n000028\0161_01.jpg +n000028\0175_02.jpg +n000028\0179_01.jpg +n000028\0180_01.jpg +n000028\0205_01.jpg +n000028\0208_02.jpg +n000028\0249_01.jpg +n000028\0300_01.jpg +n000028\0324_02.jpg +n000028\0343_01.jpg +n000028\0392_01.jpg +n000028\0412_02.jpg +n000030\0155_01.jpg +n000030\0157_01.jpg +n000030\0186_01.jpg +n000030\0193_01.jpg +n000030\0203_01.jpg +n000030\0204_01.jpg +n000030\0214_01.jpg +n000030\0218_02.jpg +n000030\0220_01.jpg +n000030\0244_01.jpg +n000031\0080_02.jpg +n000031\0092_01.jpg +n000031\0174_01.jpg +n000031\0180_01.jpg +n000031\0196_01.jpg +n000031\0248_01.jpg +n000031\0319_01.jpg +n000031\0320_03.jpg +n000032\0100_01.jpg +n000032\0100_02.jpg +n000032\0209_01.jpg +n000032\0233_01.jpg +n000032\0236_01.jpg +n000032\0237_01.jpg +n000032\0238_01.jpg +n000032\0309_01.jpg +n000032\0374_01.jpg +n000032\0393_01.jpg +n000032\0401_01.jpg +n000032\0409_01.jpg +n000032\0410_01.jpg +n000032\0420_01.jpg +n000032\0422_01.jpg +n000032\0459_01.jpg +n000032\0465_02.jpg +n000032\0531_01.jpg +n000032\0540_01.jpg +n000032\0556_01.jpg +n000032\0566_01.jpg +n000032\0578_01.jpg +n000032\0580_01.jpg +n000032\0582_01.jpg +n000032\0591_01.jpg +n000032\0605_01.jpg +n000033\0034_02.jpg +n000033\0095_01.jpg +n000033\0100_01.jpg +n000033\0100_02.jpg +n000033\0107_01.jpg +n000033\0170_01.jpg +n000033\0171_01.jpg +n000033\0179_01.jpg +n000033\0207_02.jpg +n000033\0224_01.jpg +n000033\0228_01.jpg +n000033\0231_01.jpg +n000033\0232_01.jpg +n000033\0233_02.jpg +n000033\0234_01.jpg +n000033\0235_01.jpg +n000033\0247_01.jpg +n000033\0250_02.jpg +n000033\0327_01.jpg +n000033\0337_01.jpg +n000033\0344_01.jpg +n000033\0435_01.jpg +n000034\0171_01.jpg +n000035\0069_01.jpg +n000035\0072_01.jpg +n000035\0072_02.jpg +n000035\0072_04.jpg +n000035\0098_03.jpg +n000035\0099_01.jpg +n000035\0132_02.jpg +n000035\0132_03.jpg +n000035\0134_01.jpg +n000035\0150_01.jpg +n000035\0159_02.jpg +n000035\0158_01.jpg +n000035\0161_01.jpg +n000035\0167_01.jpg +n000035\0171_01.jpg +n000035\0180_01.jpg +n000035\0200_01.jpg +n000036\0003_01.jpg +n000036\0066_02.jpg +n000036\0069_01.jpg +n000036\0083_01.jpg +n000036\0117_01.jpg +n000036\0178_02.jpg +n000036\0278_01.jpg +n000036\0279_01.jpg +n000036\0280_04.jpg +n000036\0302_02.jpg +n000036\0303_03.jpg +n000036\0304_02.jpg +n000036\0335_02.jpg +n000036\0558_01.jpg +n000036\0603_02.jpg +n000037\0007_02.jpg +n000037\0016_02.jpg +n000037\0146_03.jpg +n000037\0184_01.jpg +n000037\0166_01.jpg +n000038\0016_02.jpg +n000038\0068_01.jpg +n000038\0110_01.jpg +n000038\0114_01.jpg +n000038\0118_01.jpg +n000038\0155_01.jpg +n000038\0167_02.jpg +n000038\0169_01.jpg 
+n000038\0171_01.jpg +n000038\0172_01.jpg +n000038\0176_01.jpg +n000038\0178_01.jpg +n000038\0210_01.jpg +n000038\0212_01.jpg +n000038\0227_01.jpg +n000038\0236_01.jpg +n000038\0237_01.jpg +n000038\0241_01.jpg +n000038\0249_01.jpg +n000038\0260_01.jpg +n000038\0265_01.jpg +n000038\0275_01.jpg +n000038\0283_01.jpg +n000038\0286_01.jpg +n000038\0290_01.jpg +n000038\0308_01.jpg +n000038\0309_01.jpg +n000038\0336_02.jpg +n000038\0343_01.jpg +n000038\0355_01.jpg +n000038\0357_01.jpg +n000038\0366_01.jpg +n000038\0429_01.jpg +n000039\0174_01.jpg +n000039\0195_03.jpg +n000039\0310_01.jpg +n000039\0311_01.jpg +n000039\0313_01.jpg +n000039\0358_02.jpg +n000041\0089_04.jpg +n000041\0119_01.jpg +n000043\0159_02.jpg +n000043\0169_01.jpg +n000043\0369_01.jpg +n000043\0389_01.jpg +n000043\0391_01.jpg +n000043\0436_01.jpg +n000043\0457_01.jpg +n000043\0458_01.jpg +n000044\0007_01.jpg +n000044\0009_02.jpg +n000044\0078_01.jpg +n000044\0099_01.jpg +n000044\0117_01.jpg +n000044\0258_01.jpg +n000044\0275_01.jpg +n000044\0325_01.jpg +n000044\0350_01.jpg +n000044\0353_02.jpg +n000044\0364_01.jpg +n000044\0374_01.jpg +n000044\0379_01.jpg +n000045\0048_01.jpg +n000045\0048_02.jpg +n000045\0054_03.jpg +n000045\0120_03.jpg +n000045\0120_03.jpg +n000045\0120_02.jpg +n000045\0128_01.jpg +n000045\0128_02.jpg +n000045\0150_02.jpg +n000045\0156_01.jpg +n000045\0170_02.jpg +n000045\0226_02.jpg +n000045\0230_01.jpg +n000045\0254_03.jpg +n000045\0256_01.jpg +n000045\0269_01.jpg +n000045\0270_01.jpg +n000046\0145_01.jpg +n000047\0102_01.jpg +n000047\0191_03.jpg +n000047\0232_02.jpg +n000047\0292_01.jpg +n000047\0324_01.jpg +n000047\0464_01.jpg +n000047\0492_02.jpg +n000047\0484_03.jpg +n000047\0496_02.jpg +n000048\0050_01.jpg +n000048\0199_01.jpg +n000048\0197_01.jpg +n000048\0232_01.jpg +n000049\0046_01.jpg +n000049\0085_01.jpg +n000049\0136_01.jpg +n000049\0155_01.jpg +n000049\0277_01.jpg +n000049\0339_01.jpg +n000049\0342_01.jpg +n000049\0345_01.jpg +n000049\0372_01.jpg +n000049\0397_02.jpg +n000049\0417_01.jpg +n000049\0418_01.jpg +n000049\0472_02.jpg +n000049\0469_01.jpg +n000049\0474_01.jpg +n000050\0098_01.jpg +n000050\0115_01.jpg +n000050\0130_01.jpg +n000050\0158_02.jpg +n000050\0189_01.jpg +n000050\0228_01.jpg +n000050\0321_01.jpg +n000050\0321_02.jpg +n000050\0323_01.jpg +n000050\0332_02.jpg +n000050\0368_01.jpg +n000050\0369_01.jpg +n000050\0444_01.jpg +n000051\0243_02.jpg +n000051\0249_01.jpg +n000051\0250_01.jpg +n000051\0258_01.jpg +n000051\0274_01.jpg +n000051\0342_01.jpg +n000051\0366_01.jpg +n000052\0233_02.jpg +n000052\0290_01.jpg +n000052\0288_02.jpg +n000052\0373_02.jpg +n000052\0387_02.jpg +n000052\0451_01.jpg +n000052\0514_01.jpg +n000053\0136_01.jpg +n000053\0280_01.jpg +n000053\0283_01.jpg +n000053\0287_01.jpg +n000053\0287_02.jpg +n000053\0288_01.jpg +n000053\0291_01.jpg +n000053\0299_02.jpg +n000053\0314_01.jpg +n000053\0329_01.jpg +n000053\0399_01.jpg +n000054\0111_01.jpg +n000054\0258_01.jpg +n000054\0261_01.jpg +n000054\0263_01.jpg +n000054\0273_03.jpg +n000054\0275_01.jpg +n000054\0319_01.jpg +n000054\0322_01.jpg +n000054\0361_01.jpg +n000054\0451_01.jpg +n000054\0453_01.jpg +n000054\0455_01.jpg +n000055\0043_01.jpg +n000055\0167_01.jpg +n000055\0172_01.jpg +n000055\0175_01.jpg +n000055\0181_01.jpg +n000055\0251_01.jpg +n000055\0255_01.jpg +n000056\0158_02.jpg +n000056\0254_01.jpg +n000057\0200_03.jpg +n000057\0293_01.jpg +n000057\0300_01.jpg +n000057\0337_01.jpg +n000057\0341_02.jpg +n000057\0344_06.jpg +n000057\0348_01.jpg +n000057\0353_01.jpg +n000057\0351_01.jpg +n000057\0356_01.jpg 
+n000057\0357_01.jpg
+n000057\0368_01.jpg
+n000057\0373_01.jpg
+n000058\0266_01.jpg
+n000058\0467_03.jpg
+n000058\0468_01.jpg
+n000059\0005_01.jpg
+n000059\0013_01.jpg
+n000059\0046_01.jpg
+n000059\0124_01.jpg
+n000059\0177_01.jpg
+n000059\0177_02.jpg
+n000059\0182_01.jpg
+n000059\0222_01.jpg
diff --git a/clear_dataset.py b/clear_dataset.py
index 7a1205e..fb6da04 100644
--- a/clear_dataset.py
+++ b/clear_dataset.py
@@ -5,7 +5,7 @@
 # Created Date: Thursday March 24th 2022
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Thursday, 24th March 2022 3:32:40 pm
+# Last Modified: Sunday, 3rd April 2022 1:20:44 am
 # Modified By: Chen Xuanhong
 # Copyright (c) 2022 Shanghai Jiao Tong University
 #############################################################
@@ -18,11 +18,14 @@ if __name__ == "__main__":
     savePath = "./vggface2hq_failed.txt"
     env_config = readConfig('env/env.json')
     env_config = env_config["path"]
-    dataset_root = env_config["dataset_paths"]["vggface2_hq"]
+    dataset_root = env_config["dataset_paths"]["vggface2_hq"]["images"]
+    # dataset_root = "G:/VGGFace2-HQ/newversion"
+    print(dataset_root)
     with open(savePath,'r') as logf:
         for line in logf:
-            img_path = os.path.join(dataset_root,line[:-1]).replace("\\","/")
+            img_path = os.path.join(dataset_root,line.replace("\n","")).replace("\\","/")
             try:
                 os.rename(img_path,img_path+".deleted")
-            except: pass
\ No newline at end of file
+            except Exception as e:
+                print(e)
\ No newline at end of file
diff --git a/components/Generator_256.py b/components/Generator_256.py
new file mode 100644
index 0000000..9417ac3
--- /dev/null
+++ b/components/Generator_256.py
@@ -0,0 +1,304 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: Generator_256.py
+# Created Date: Saturday February 26th 2022
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Tuesday, 19th April 2022 7:03:46 pm
+# Modified By: Chen Xuanhong
+# Copyright (c) 2022 Shanghai Jiao Tong University
+#############################################################
+
+import os
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import math
+
+
+class ResBlk(nn.Module):
+    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+                 normalize="in", downsample=False):
+        super().__init__()
+        self.actv = actv
+        self.normalize = normalize
+        self.downsample = downsample
+        self.learned_sc = dim_in != dim_out
+        self.equal_var = math.sqrt(2)
+        self._build_weights(dim_in, dim_out)
+
+    def _build_weights(self, dim_in, dim_out):
+        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
+        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
+        if self.normalize.lower() == "in":
+            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
+            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
+        elif self.normalize.lower() == "bn":
+            self.norm1 = nn.BatchNorm2d(dim_in)
+            self.norm2 = nn.BatchNorm2d(dim_in)
+        if self.learned_sc:
+            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
+
+    def _shortcut(self, x):
+        if self.learned_sc:
+            x = self.conv1x1(x)
+        if self.downsample:
+            x = F.avg_pool2d(x, 2)
+        return x
+
+    def _residual(self, x):
+        if self.normalize:
+            x = self.norm1(x)
+        x = self.actv(x)
+        x = self.conv1(x)
+        if self.downsample:
+            x = F.avg_pool2d(x, 2)
+        if self.normalize:
+            x = self.norm2(x)
+        x = self.actv(x)
+        x = self.conv2(x)
+        return x
+
+    def forward(self, x):
+        x = self._shortcut(x) + self._residual(x)
+        return x / self.equal_var  # unit variance
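
# --- Illustrative aside, not part of the patch above ---
# ResBlk is a StarGAN v2-style pre-activation residual block: the shortcut and
# residual branches are summed and then divided by sqrt(2) (self.equal_var),
# so if the two branches carry roughly equal, uncorrelated variance the output
# variance stays near the input's instead of doubling block after block.
# A minimal smoke test of the shape contract, assuming the patched
# components/Generator_256.py is importable from the repository root:
import torch

from components.Generator_256 import ResBlk

blk = ResBlk(64, 128, normalize="in", downsample=True)  # learned 1x1 shortcut
x = torch.randn(2, 64, 32, 32)
y = blk(x)
print(y.shape)  # torch.Size([2, 128, 16, 16]): channels doubled, spatial halved
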
+class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"].lower() + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + norm_mask= nn.InstanceNorm2d + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + norm_mask = nn.BatchNorm2d + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # 
self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 128 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 64 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 32 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + # self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 1 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 32 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 64 + + # self.maskhead = nn.Sequential( + # nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False), + # norm_mask, # 64 + # activation, + # nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1), + # nn.Sigmoid()) + self.maskhead_lr = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False), + norm_mask(in_channel, affine=True), # 32 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False), + norm_mask(in_channel//4, affine=True), # 64 + activation + ) + self.maskhead_hr = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//4, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False), + norm_mask(in_channel//16, affine=True), # 128 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() # 256 + ) + self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//4, 1, kernel_size=1, stride=1), + nn.Sigmoid()) + + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm) + + # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + mask_feat= self.maskhead_lr(res) + + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up4(res,id) + res = self.up3(res,id) + 
mask_lr= self.maskhead_out(mask_feat) + # res = (1-mask) * self.sigma(skip) + mask * res + res = (1-mask_lr) * skip + mask_lr * res + res = self.up2(res) # + skip + res = self.up1(res) + res = self.to_rgb(res) + mask_hr=self.maskhead_hr(mask_feat) + res = (1-mask_hr) * img + mask_hr * res + return res, mask_lr, mask_hr \ No newline at end of file diff --git a/components/Generator_2mask.py b/components/Generator_2mask.py new file mode 100644 index 0000000..9ea274b --- /dev/null +++ b/components/Generator_2mask.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Friday, 15th April 2022 12:30:27 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, 
dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"].lower() + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + norm_mask= nn.InstanceNorm2d + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + norm_mask = nn.BatchNorm2d + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) 
# 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + # self.maskhead = nn.Sequential( + # nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False), + # norm_mask, # 64 + # activation, + # nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1), + # nn.Sigmoid()) + self.maskhead_lr = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False), + norm_mask(in_channel, affine=True), # 32 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False), + norm_mask(in_channel//4, affine=True), # 64 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1), + norm_mask(in_channel//8, affine=True), # 128 + activation, + ) + self.maskhead_hr = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False), + norm_mask(in_channel//16, affine=True), # 256 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() # 512 + ) + self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1), + nn.Sigmoid()) + + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm) + + # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + mask_feat= self.maskhead_lr(res) + + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + mask_lr= self.maskhead_out(mask_feat) + # res = (1-mask) * self.sigma(skip) + mask * res + res = (1-mask_lr) * skip + mask_lr * res + res = self.up2(res) # + skip + res = self.up1(res) + res = self.to_rgb(res) + mask_hr=self.maskhead_hr(mask_feat) + res = (1-mask_hr) * img + mask_hr * res + return res, mask_lr, mask_hr \ No newline at end of file diff --git a/components/Generator_2mask2.py b/components/Generator_2mask2.py new file mode 100644 index 0000000..7040df7 --- /dev/null +++ b/components/Generator_2mask2.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: 
Tuesday, 19th April 2022 12:45:55 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + 
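
# --- Illustrative aside, not part of the patch above ---
# The AdaIN module repeated in each of these generator variants predicts one
# per-channel (gamma, beta) pair from the identity embedding with a single
# Linear layer and applies it to instance-normalized features:
#     out = (1 + gamma) * IN(x) + beta
# The (1 + gamma) parameterization keeps the modulation close to identity when
# the predicted gamma is near zero. A hypothetical shape check, assuming the
# patched components/Generator_2mask2.py is importable and that style_dim=512
# matches the id embedding dimension passed to the Generator:
import torch

from components.Generator_2mask2 import AdaIN

ada = AdaIN(style_dim=512, num_features=256)
feat = torch.randn(2, 256, 16, 16)  # content feature map
s = torch.randn(2, 512)             # stand-in identity embedding
out = ada(feat, s)
print(out.shape)  # torch.Size([2, 256, 16, 16])
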
def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"].lower() + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + norm_mask= nn.InstanceNorm2d + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + norm_mask = nn.BatchNorm2d + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + # self.maskhead = nn.Sequential( + # nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False), + # norm_mask, # 64 + # activation, + # nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1), + # nn.Sigmoid()) + self.maskhead_lr = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False), + norm_mask(in_channel, affine=True), # 32 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False), + norm_mask(in_channel//4, affine=True), # 64 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + 
nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1), + norm_mask(in_channel//8, affine=True), # 128 + activation, + ) + self.maskhead_hr = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False), + norm_mask(in_channel//16, affine=True), # 256 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() # 512 + ) + self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1), + nn.Sigmoid()) + + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm) + + self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + # self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + mask_feat= self.maskhead_lr(res) + + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + mask_lr= self.maskhead_out(mask_feat) + # res = (1-mask) * self.sigma(skip) + mask * res + res = (1-mask_lr) * skip + mask_lr * res + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + mask_hr=self.maskhead_hr(mask_feat) + res = (1-mask_hr) * img + mask_hr * res + return res, mask_lr, mask_hr \ No newline at end of file diff --git a/components/Generator_2mask_DWConv.py b/components/Generator_2mask_DWConv.py new file mode 100644 index 0000000..8d5c743 --- /dev/null +++ b/components/Generator_2mask_DWConv.py @@ -0,0 +1,453 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Monday, 18th April 2022 10:20:12 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import math + +import torch +from torch import nn +import torch.nn.functional as F + +from components.ModulatedDWConv import ModulatedDWConv2d + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Sequential( + nn.Conv2d(dim_in, dim_in, 3, 1, 1, 
groups=dim_in), + nn.Conv2d(dim_in, dim_in, 1, 1) + ) + + self.conv2 = nn.Sequential( + nn.Conv2d(dim_in, dim_in, 3, 1, 1, groups=dim_in), + nn.Conv2d(dim_in, dim_out, 1, 1) + ) + # self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + # self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +# class ResUpBlk(nn.Module): +# def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): +# super().__init__() +# self.actv = actv +# self.normalize = normalize +# self.learned_sc = dim_in != dim_out +# self.equal_var = math.sqrt(2) +# self._build_weights(dim_in, dim_out) + +# def _build_weights(self, dim_in, dim_out): +# self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) +# self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) +# if self.normalize.lower() == "in": +# self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) +# self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) +# elif self.normalize.lower() == "bn": +# self.norm1 = nn.BatchNorm2d(dim_in) +# self.norm2 = nn.BatchNorm2d(dim_out) +# if self.learned_sc: +# self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) 
+ +# def _shortcut(self, x): +# x = F.interpolate(x, scale_factor=2, mode='nearest') +# if self.learned_sc: +# x = self.conv1x1(x) +# return x + +# def _residual(self, x): +# x = self.norm1(x) +# x = self.actv(x) +# x = F.interpolate(x, scale_factor=2, mode='nearest') +# x = self.conv1(x) +# x = self.norm2(x) +# x = self.actv(x) +# x = self.conv2(x) +# return x + +# def forward(self, x): +# out = self._residual(x) +# out = (out + self._shortcut(x)) / self.equal_var +# return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Sequential( + nn.Conv2d(dim_in, dim_in, 3, 1, 1,groups=dim_in), + nn.Conv2d(dim_in, dim_out, 1, 1) + ) + + self.conv2 = nn.Sequential( + nn.Conv2d(dim_out, dim_out, 3, 1, 1,groups=dim_out), + nn.Conv2d(dim_out, dim_out, 1) + ) + + # self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + # self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class ModulatedResBlk(nn.Module): + def __init__(self, + dim_in, + dim_out, + style_dim=512, + actv=nn.LeakyReLU(0.2), + upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = ModulatedDWConv2d(dim_in, dim_out, style_dim, 3) + self.conv2 = ModulatedDWConv2d(dim_out, dim_out, style_dim, 3) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x,s) + x = self.actv(x) + x = self.conv2(x,s) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"].lower() + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + norm_mask= nn.InstanceNorm2d + elif norm.lower() == "bn": 
+ norm_out = nn.BatchNorm2d(in_channel) + norm_mask = nn.BatchNorm2d + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + # self.maskhead = nn.Sequential( + # nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False), + # norm_mask, # 64 + # activation, + # nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1), + # nn.Sigmoid()) + self.maskhead_lr = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False), + norm_mask(in_channel, affine=True), # 32 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False), + norm_mask(in_channel//4, affine=True), # 64 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1), + norm_mask(in_channel//8, affine=True), # 128 + activation, + ) + + self.maskhead_hr = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False), + norm_mask(in_channel//16, affine=True), # 256 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() # 512 + ) + + self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1), + nn.Sigmoid()) + + # self.maskhead_lr = nn.Sequential( + # nn.UpsamplingNearest2d(scale_factor = 2), + # nn.Conv2d(in_channel*8, in_channel*8, 3, 1, 1,groups=in_channel*8), + # nn.Conv2d(in_channel*8, in_channel, 1, bias=False), + # # nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False), + # norm_mask(in_channel, affine=True), # 32 + # activation, + # nn.UpsamplingNearest2d(scale_factor = 2), + # nn.Conv2d(in_channel, in_channel, 3, 1, 1,groups=in_channel), + # nn.Conv2d(in_channel, in_channel//4, 1, bias=False), + # # nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False), + # norm_mask(in_channel//4, affine=True), # 64 + # activation, + # nn.UpsamplingNearest2d(scale_factor = 2), + # # nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, 
stride=1, padding=1), + # nn.Conv2d(in_channel//4, in_channel//4, 3, 1, 1,groups=in_channel//4), + # nn.Conv2d(in_channel//4, in_channel//8, 1, bias=False), + # norm_mask(in_channel//8, affine=True), # 128 + # activation, + # ) + # self.maskhead_hr = nn.Sequential( + # nn.UpsamplingNearest2d(scale_factor = 2), + # # nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False), + # nn.Conv2d(in_channel//8, in_channel//8, 3, 1, 1,groups=in_channel//8), + # nn.Conv2d(in_channel//8, in_channel//16, 1, bias=False), + # norm_mask(in_channel//16, affine=True), # 256 + # activation, + # nn.UpsamplingNearest2d(scale_factor = 2), + # nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1), + # nn.Sigmoid() # 512 + # ) + # self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1), + # nn.Sigmoid()) + + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel*2, in_channel, normalize=norm) + + # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel, in_channel, normalize=norm) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + mask_feat= self.maskhead_lr(res) + + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + mask_lr= self.maskhead_out(mask_feat) + # res = (1-mask) * self.sigma(skip) + mask * res + res = (1-mask_lr) * skip + mask_lr * res + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + mask_hr=self.maskhead_hr(mask_feat) + res = (1-mask_hr) * img + mask_hr * res + return res, mask_lr, mask_hr \ No newline at end of file diff --git a/components/Generator_2maskhead_config copy.py b/components/Generator_2maskhead_config copy.py new file mode 100644 index 0000000..f44ac61 --- /dev/null +++ b/components/Generator_2maskhead_config copy.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 13th April 2022 10:22:52 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, 
actv=nn.LeakyReLU(0.2), + normalize="bn", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def 
_shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + norm = norm.lower() + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + # self.sigma = ResBlk(in_channel*2,in_channel*2) + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + self.maskhead_lr = nn.Sequential( + nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(in_channel), # 64 + activation, + nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() + ) + + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm) + + # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + self.maskhead_hr = nn.Sequential( + nn.Conv2d(in_channel, in_channel//8, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(in_channel//8), # 64 + activation, + nn.Conv2d(in_channel//8, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() + ) + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 
3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + mask= self.maskhead_lr(res) + # res = (1-mask) * self.sigma(skip) + mask * res + res = (1-mask) * skip + mask * res + res = self.up2(res) # + skip + res = self.up1(res) + mask_hr = self.maskhead_hr(res) + res = self.to_rgb(res) + res = (1-mask_hr)*img + mask_hr*res + return res, mask, mask_hr \ No newline at end of file diff --git a/components/Generator_2maskhead_config.py b/components/Generator_2maskhead_config.py new file mode 100644 index 0000000..f44ac61 --- /dev/null +++ b/components/Generator_2maskhead_config.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 13th April 2022 10:22:52 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="bn", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + 
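
The AdaIN block above is the only path by which the identity embedding reaches the decoder: a single linear layer maps the style vector s to one (gamma, beta) pair per channel, which rescales and shifts the instance-normalized feature map. A minimal self-contained sketch of the mechanism (batch and channel sizes are illustrative; 512 matches the style_dim default used by AdainResBlk):

import torch
from torch import nn

class AdaIN(nn.Module):
    # same structure as the AdaIN class in this diff
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        h = self.fc(s)                                  # (B, 2C)
        h = h.view(h.size(0), h.size(1), 1, 1)          # (B, 2C, 1, 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)   # (B, C, 1, 1) each
        # whiten each channel spatially, then re-style it with
        # statistics predicted from the identity vector s
        return (1 + gamma) * self.norm(x) + beta

x = torch.randn(2, 64, 16, 16)    # decoder feature map
s = torch.randn(2, 512)           # identity embedding
out = AdaIN(512, 64)(x, s)
print(out.shape)                  # torch.Size([2, 64, 16, 16])
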
+class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + norm = norm.lower() + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + # self.sigma = ResBlk(in_channel*2,in_channel*2) + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, 
in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + self.maskhead_lr = nn.Sequential( + nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(in_channel), # 64 + activation, + nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() + ) + + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm) + + # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + self.maskhead_hr = nn.Sequential( + nn.Conv2d(in_channel, in_channel//8, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(in_channel//8), # 64 + activation, + nn.Conv2d(in_channel//8, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() + ) + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + mask= self.maskhead_lr(res) + # res = (1-mask) * self.sigma(skip) + mask * res + res = (1-mask) * skip + mask * res + res = self.up2(res) # + skip + res = self.up1(res) + mask_hr = self.maskhead_hr(res) + res = self.to_rgb(res) + res = (1-mask_hr)*img + mask_hr*res + return res, mask, mask_hr \ No newline at end of file diff --git a/components/Generator_ResSkip_config.py b/components/Generator_ResSkip_config.py new file mode 100644 index 0000000..7255a59 --- /dev/null +++ b/components/Generator_ResSkip_config.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Tuesday, 29th March 2022 12:02:53 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 
Shanghai Jiao Tong University +############################################################# + + +import torch +from torch import nn +import torch.nn.functional as F +import math +from components.LSTU import LSTU + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + 
self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + self.lstu = LSTU(in_channel*2,norm) + + self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel*2, in_channel, normalize="in") # 256 + + # self.lstu = nn.Sequential(nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False), + # nn.BatchNorm2d(in_channel), + # activation, + # nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1), + # nn.Sigmoid() + # ) + + self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def 
__weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + # res = self.down6(res) + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + # res = self.up6(res,id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + res,mask = self.lstu(skip, res) + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + return res,mask \ No newline at end of file diff --git a/components/Generator_ResSkip_config1.py b/components/Generator_ResSkip_config1.py new file mode 100644 index 0000000..3fb0783 --- /dev/null +++ b/components/Generator_ResSkip_config1.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Tuesday, 29th March 2022 1:08:05 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) 
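
Every residual block in these files (ResBlk, ResUpBlk, AdainResBlk) divides the sum of its two branches by equal_var = math.sqrt(2). The reasoning: if the residual and shortcut outputs are roughly independent and each has unit variance, their sum has variance 2, so scaling by 1/sqrt(2) restores unit variance and keeps activation magnitudes stable as blocks stack. A quick numerical check under that independence assumption:

import math
import torch

a = torch.randn(1_000_000)     # residual branch, unit variance
b = torch.randn(1_000_000)     # shortcut branch, unit variance
out = (a + b) / math.sqrt(2)
print(a.var().item(), b.var().item(), out.var().item())  # all three close to 1.0
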
+ + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + + lstu_script = kwargs["lstu_script"] + lstu_class = kwargs["lstu_class"] + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + script_name = "components." 
+ lstu_script + package = __import__(script_name, fromlist=True) + lstu_class = getattr(package, lstu_class) + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + self.lstu = lstu_class(in_channel*2,norm) + + self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel*2, in_channel, normalize="in") # 256 + + # self.lstu = nn.Sequential(nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False), + # nn.BatchNorm2d(in_channel), + # activation, + # nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1), + # nn.Sigmoid() + # ) + + self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + # res = self.down6(res) + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + # res = self.up6(res,id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + res,mask = self.lstu(skip, res) + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + return res \ No newline at end of file diff --git a/components/Generator_Res_config3.py b/components/Generator_Res_config3.py new file mode 100644 index 0000000..41690bf --- /dev/null +++ b/components/Generator_Res_config3.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- 
+############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 30th March 2022 4:14:27 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + 
actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + + lstu_script = kwargs["lstu_script"] + lstu_class = kwargs["lstu_class"] + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + script_name = "components." + lstu_script + package = __import__(script_name, fromlist=True) + lstu_class = getattr(package, lstu_class) + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + self.lstu = lstu_class(in_channel*2,norm) + + self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel*2, in_channel, normalize="in") # 256 + + # self.lstu = nn.Sequential(nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False), + # nn.BatchNorm2d(in_channel), + # activation, + # nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1), + # nn.Sigmoid() + # ) + + self.up1 = AdainResBlk(in_channel, 
in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + # res = self.down6(res) + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + # res = self.up6(res,id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + # res,mask = self.lstu(skip, res) + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + return res \ No newline at end of file diff --git a/components/Generator_VGGStyle_maskhead_config.py b/components/Generator_VGGStyle_maskhead_config.py new file mode 100644 index 0000000..6fffd06 --- /dev/null +++ b/components/Generator_VGGStyle_maskhead_config.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 6th April 2022 12:55:51 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): 
+ out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainUpBlock(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2)): + super().__init__() + self.actv = actv + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.norm = AdaIN(style_dim, dim_out) + + def forward(self, x, s): + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv(x) + x = self.norm(x, s) + x = self.actv(x) + return x + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = nn.Sequential(nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=2, padding=1, bias=False), # 256 + nn.BatchNorm2d(in_channel), activation) + + self.down2 = nn.Sequential(nn.Conv2d(in_channel, in_channel*2, kernel_size=3, stride=2, padding=1, bias=False), # 128 + nn.BatchNorm2d(in_channel*2), activation) + + self.down3 = nn.Sequential(nn.Conv2d(in_channel*2, in_channel*4, kernel_size=3, stride=2, padding=1, bias=False), # 64 + nn.BatchNorm2d(in_channel*4), activation) + + self.down4 = nn.Sequential(nn.Conv2d(in_channel*4, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False), # 32 + nn.BatchNorm2d(in_channel*8), activation) + + self.down5 = nn.Sequential(nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False), # 32 + nn.BatchNorm2d(in_channel*8), activation) + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + self.maskhead = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False), + nn.BatchNorm2d(in_channel), # 32 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(in_channel//2), # 64 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1, bias=False), + nn.Sigmoid() + ) + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainUpBlock(in_channel*8, in_channel*8, style_dim=id_dim) # 32 + + self.up4 = AdainUpBlock(in_channel*8, in_channel*4, style_dim=id_dim) # 64 + + self.up3 = AdainUpBlock(in_channel*4, in_channel*2, style_dim=id_dim) # 128 + + self.up2 = AdainUpBlock(in_channel*2, in_channel, style_dim=id_dim) + + self.up1 = AdainUpBlock(in_channel, in_channel, style_dim=id_dim) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + + self.to_rgb = nn.Sequential(nn.ReflectionPad2d(1), + nn.Conv2d(in_channel, 3, kernel_size=3, 
padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + mask= self.maskhead(res) + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + res = (1-mask) * skip + mask * res + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + return res, mask \ No newline at end of file diff --git a/components/Generator_featout_config.py b/components/Generator_featout_config.py new file mode 100644 index 0000000..ae475a6 --- /dev/null +++ b/components/Generator_featout_config.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Saturday, 2nd April 2022 1:27:23 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + 
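
The generators in this diff that read the lstu_script / lstu_class config keys resolve the skip-fusion class at runtime with __import__(script_name, fromlist=True). Passing a bool as fromlist happens to work here, since __import__ only iterates fromlist when the imported target is a package and components.<script> is a plain module, but importlib is the documented interface for this. A sketch of an equivalent loader, assuming the same "components" package layout as this repo:

import importlib

def load_component(script_name: str, class_name: str):
    # e.g. script_name="LSTU", class_name="LSTU" -> components.LSTU.LSTU
    module = importlib.import_module("components." + script_name)
    return getattr(module, class_name)

# usage inside Generator.__init__:
#   lstu_class = load_component(kwargs["lstu_script"], kwargs["lstu_class"])
#   self.lstu = lstu_class(in_channel * 2, norm)
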
self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + + lstu_script = kwargs["lstu_script"] + lstu_class = kwargs["lstu_class"] + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + script_name = "components." 
+ lstu_script
+        package = __import__(script_name, fromlist=True)
+        lstu_class = getattr(package, lstu_class)
+
+
+        activation = nn.LeakyReLU(0.2)
+        # activation = nn.ReLU()
+
+        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
+        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
+        #                     nn.BatchNorm2d(64), activation)
+        ### downsample
+        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
+
+        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
+
+        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
+
+        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
+
+        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
+
+        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
+        self.maskhead = nn.Sequential(
+                        nn.ConvTranspose2d(in_channel*8, in_channel, kernel_size=4, stride=2, padding=1, bias=False),
+                        nn.BatchNorm2d(in_channel), # 32
+                        activation,
+                        nn.ConvTranspose2d(in_channel, in_channel//2, kernel_size=4, stride=2, padding=1, bias=False),
+                        nn.BatchNorm2d(in_channel//2), # 64
+                        activation,
+                        nn.ConvTranspose2d(in_channel//2, 1, kernel_size=4, stride=2, padding=1, bias=False),
+                        nn.BatchNorm2d(1), # 128
+                        nn.Sigmoid()
+                    )
+
+        ### resnet blocks
+        BN = []
+        for i in range(res_num):
+            BN += [
+                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
+        self.BottleNeck = nn.Sequential(*BN)
+
+        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
+
+        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
+
+        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
+
+        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
+
+        self.lstu = lstu_class(in_channel*2,norm)
+
+        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
+        # ResUpBlk(in_channel*2, in_channel, normalize="in") # 256
+
+        # self.lstu = nn.Sequential(nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
+        #                 nn.BatchNorm2d(in_channel),
+        #                 activation,
+        #                 nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
+        #                 nn.Sigmoid()
+        #             )
+
+        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
+        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
+
+        if norm.lower() == "in":
+            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
+        elif norm.lower() == "bn":
+            norm_out = nn.BatchNorm2d(in_channel)
+
+        self.to_rgb = nn.Sequential(
+            norm_out,
+            activation,
+            nn.Conv2d(in_channel, 3, 3, 1, 1))
+
+        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
+        #                 nn.Conv2d(64, 3, kernel_size=7, padding=0))
+
+
+        # self.__weights_init__()
+
+    # def __weights_init__(self):
+    #     for layer in self.encoder:
+    #         if isinstance(layer,nn.Conv2d):
+    #             nn.init.xavier_uniform_(layer.weight)
+
+    #     for layer in self.encoder2:
+    #         if isinstance(layer,nn.Conv2d):
+    #             nn.init.xavier_uniform_(layer.weight)
+
+    def forward(self, img, id, feat_out=False):
+        res = self.from_rgb(img)
+        res = self.down1(res)
+        skip = self.down2(res)
+        res = self.down3(skip)
+        res = self.down4(res)
+        res = self.down5(res)
+        if feat_out:
+            return res
+        # res = self.down6(res)
+        for i in range(len(self.BottleNeck)):
+            res = self.BottleNeck[i](res, id)
+        # res = self.up6(res,id)
+        res = self.up5(res,id)
+        res =
self.up4(res,id) + res = self.up3(res,id) + res,mask = self.lstu(skip, res) + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + return res \ No newline at end of file diff --git a/components/Generator_involution_maskhead_config.py b/components/Generator_involution_maskhead_config.py new file mode 100644 index 0000000..4acf04d --- /dev/null +++ b/components/Generator_involution_maskhead_config.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 3rd April 2022 1:06:31 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, 
scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + self.maskhead = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False), + nn.BatchNorm2d(in_channel), # 32 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(in_channel//2), # 64 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1, bias=False), + nn.Sigmoid() + ) + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, 
in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + + self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + mask= self.maskhead(res) + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + res = (1-mask) * skip + mask * res + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + return res, mask \ No newline at end of file diff --git a/components/Generator_maskhead_config.py b/components/Generator_maskhead_config.py new file mode 100644 index 0000000..4acf04d --- /dev/null +++ b/components/Generator_maskhead_config.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 3rd April 2022 1:06:31 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) 
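+        # pre-activation ordering: normalize -> activate -> convolve on each
+        # branch step; forward() then adds the shortcut branch and divides by
+        # equal_var = sqrt(2) so the block output keeps roughly unit variance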
+ x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + 
self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + self.maskhead = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False), + nn.BatchNorm2d(in_channel), # 32 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(in_channel//2), # 64 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1, bias=False), + nn.Sigmoid() + ) + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + + self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + mask= self.maskhead(res) + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + res = (1-mask) * skip + mask * res + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + return res, mask \ No newline at end of file diff --git a/components/Generator_maskhead_config1.py b/components/Generator_maskhead_config1.py new file mode 100644 index 0000000..197fc33 --- /dev/null +++ b/components/Generator_maskhead_config1.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# 
+# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 6th April 2022 8:38:50 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="bn", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + 
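+        # AdaIN residual block: the AdaIN layers below instance-normalize the
+        # feature map and re-scale it as (1 + gamma) * x_norm + beta, where
+        # gamma/beta are a linear projection of the identity code `s`; e.g. with
+        # style_dim=512 and dim_in=256 the fc maps 512 -> 512 (= 2 * 256)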
self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"] + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.sigma = ResBlk(in_channel*2,in_channel*2) + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + self.maskhead = nn.Sequential( + nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(in_channel), # 64 + activation, + nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() + ) + + self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + + self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 
3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + mask= self.maskhead(res) + res = (1-mask) * self.sigma(skip) + mask * res + res = self.up2(res,id) # + skip + res = self.up1(res,id) + res = self.to_rgb(res) + return res, mask \ No newline at end of file diff --git a/components/Generator_maskhead_config2.py b/components/Generator_maskhead_config2.py new file mode 100644 index 0000000..c1609ad --- /dev/null +++ b/components/Generator_maskhead_config2.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 13th April 2022 3:12:53 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="bn", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + 
super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"].lower() + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + # self.sigma = ResBlk(in_channel*2,in_channel*2) + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, 
downsample=True)# 8 + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + self.maskhead = nn.Sequential( + nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(in_channel), # 64 + activation, + nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() + ) + + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + self.up2 = ResUpBlk(in_channel*2, in_channel, normalize="bn") + + # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + self.up1 = ResUpBlk(in_channel, in_channel, normalize="bn") + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + + self.to_rgb = nn.Sequential( + norm_out, + activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + mask= self.maskhead(res) + # res = (1-mask) * self.sigma(skip) + mask * res + res = (1-mask) * skip + mask * res + res = self.up2(res) # + skip + res = self.up1(res) + res = self.to_rgb(res) + return res, mask \ No newline at end of file diff --git a/components/Generator_starganv2.py b/components/Generator_starganv2.py new file mode 100644 index 0000000..728f71b --- /dev/null +++ b/components/Generator_starganv2.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Generator_Invobn_config1.py +# Created Date: Saturday February 26th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 13th April 2022 6:30:26 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os + +import torch +from torch import nn +import torch.nn.functional as F +import math + + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) 
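+        # equal_var = sqrt(2): forward() divides (shortcut + residual) by this,
+        # which preserves variance when the two branches are roughly
+        # uncorrelated with equal variance: Var((a+b)/sqrt(2)) = Var(a)
+        # if Var(a) = Var(b) and Cov(a,b) = 0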
+ self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x /self.equal_var # unit variance + +class AdaIN(nn.Module): + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features*2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + +class ResUpBlk(nn.Module): + def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"): + super().__init__() + self.actv = actv + self.normalize = normalize + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_out, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x): + x = self.norm1(x) + x = self.actv(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + out = self._residual(x) + out = (out + self._shortcut(x)) / self.equal_var + return out + +class AdainResBlk(nn.Module): + def __init__(self, dim_in, dim_out, style_dim=512, + actv=nn.LeakyReLU(0.2), upsample=False): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: 
+ x = F.interpolate(x, scale_factor=2, mode='nearest') + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / self.equal_var + return out + + +class Generator(nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + + id_dim = kwargs["id_dim"] + k_size = kwargs["g_kernel_size"] + res_num = kwargs["res_num"] + in_channel = kwargs["in_channel"] + up_mode = kwargs["up_mode"] + norm = kwargs["norm"].lower() + + aggregator = kwargs["aggregator"] + res_mode = kwargs["res_mode"] + + padding_size= int((k_size -1)/2) + padding_type= 'reflect' + + if norm.lower() == "in": + norm_out = nn.InstanceNorm2d(in_channel, affine=True) + norm_mask= nn.InstanceNorm2d + elif norm.lower() == "bn": + norm_out = nn.BatchNorm2d(in_channel) + norm_mask = nn.BatchNorm2d + + + activation = nn.LeakyReLU(0.2) + # activation = nn.ReLU() + + self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0) + # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), + # nn.BatchNorm2d(64), activation) + ### downsample + self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256 + + self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128 + + self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64 + + self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32 + + self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16 + + # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8 + + + ### resnet blocks + BN = [] + for i in range(res_num): + BN += [ + AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)] + self.BottleNeck = nn.Sequential(*BN) + + # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16 + + self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32 + + self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64 + + self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128 + + # self.maskhead = nn.Sequential( + # nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False), + # norm_mask, # 64 + # activation, + # nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1), + # nn.Sigmoid()) + self.maskhead = nn.Sequential( + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False), + norm_mask(in_channel, affine=True), # 32 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False), + norm_mask(in_channel//4, affine=True), # 64 + activation, + nn.UpsamplingNearest2d(scale_factor = 2), + nn.Conv2d(in_channel//4, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() + ) + + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True) + self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm) + + # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True) + self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm) + # ResUpBlk(in_channel, in_channel, normalize="in") # 512 + + + + self.to_rgb = nn.Sequential( + norm_out, + 
activation, + nn.Conv2d(in_channel, 3, 3, 1, 1)) + + # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), + # nn.Conv2d(64, 3, kernel_size=7, padding=0)) + + + # self.__weights_init__() + + # def __weights_init__(self): + # for layer in self.encoder: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + # for layer in self.encoder2: + # if isinstance(layer,nn.Conv2d): + # nn.init.xavier_uniform_(layer.weight) + + def forward(self, img, id): + res = self.from_rgb(img) + res = self.down1(res) + skip = self.down2(res) + res = self.down3(skip) + res = self.down4(res) + res = self.down5(res) + mask= self.maskhead(res) + for i in range(len(self.BottleNeck)): + res = self.BottleNeck[i](res, id) + res = self.up5(res,id) + res = self.up4(res,id) + res = self.up3(res,id) + + # res = (1-mask) * self.sigma(skip) + mask * res + res = (1-mask) * skip + mask * res + res = self.up2(res) # + skip + res = self.up1(res) + res = self.to_rgb(res) + return res, mask \ No newline at end of file diff --git a/components/LSTU.py b/components/LSTU.py index 85f9606..240fb21 100644 --- a/components/LSTU.py +++ b/components/LSTU.py @@ -5,43 +5,115 @@ # Created Date: Sunday January 16th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Sunday, 13th February 2022 2:03:21 am +# Last Modified: Monday, 28th March 2022 11:47:55 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# +import math import torch from torch import nn +import torch.nn.functional as F + +# class LSTU(nn.Module): +# def __init__( +# self, +# in_channel, +# out_channel, +# latent_channel, +# scale = 4 +# ): +# super().__init__() +# sig = nn.Sigmoid() +# self.relu = nn.ReLU(True) + +# self.up_sample = nn.Sequential(nn.Conv2d(latent_channel, out_channel/4, kernel_size=3, stride=1, padding=1, bias=False), +# nn.BatchNorm2d(out_channel/4), +# self.relu, +# nn.Conv2d(latent_channel/4, out_channel, kernel_size=3, stride=1, padding=1), +# ) + +# self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False), +# nn.BatchNorm2d(out_channel), sig) + +# self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False), +# nn.BatchNorm2d(out_channel), sig) + +# self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True)) + +# def forward(self, encoder_in, bottleneck_in): +# h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel` +# h_bar_l = self.conv11(h_hat_l_1) +# f_l = self.forget_gate(h_hat_l_1) +# r_l = self.reset_gate (h_hat_l_1) +# h_hat_l = (1-f_l)*h_bar_l + f_l* encoder_in +# x_hat_l = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1 +# return x_hat_l + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.equal_var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + 
self.norm2 = nn.BatchNorm2d(dim_in)
+
+        if self.learned_sc:
+            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
+
+    def _shortcut(self, x):
+        if self.learned_sc:
+            x = self.conv1x1(x)
+        if self.downsample:
+            x = F.avg_pool2d(x, 2)
+        return x
+
+    def _residual(self, x):
+        if self.normalize:
+            x = self.norm1(x)
+        x = self.actv(x)
+        x = self.conv1(x)
+        if self.downsample:
+            x = F.avg_pool2d(x, 2)
+        if self.normalize:
+            x = self.norm2(x)
+        x = self.actv(x)
+        x = self.conv2(x)
+        return x
+
+    def forward(self, x):
+        x = self._shortcut(x) + self._residual(x)
+        return x / self.equal_var  # unit variance
 
 class LSTU(nn.Module):
     def __init__(
         self,
         in_channel,
-        out_channel,
-        latent_channel,
-        scale = 4
+        norm
     ):
         super().__init__()
-        sig       = nn.Sigmoid()
-        self.relu = nn.ReLU(True)
+        self.sig = nn.Sigmoid()
 
-        self.up_sample = nn.Sequential(nn.ConvTranspose2d(latent_channel, out_channel, kernel_size=4, stride=scale, padding=0, bias=False),
-                            nn.BatchNorm2d(out_channel), sig)
-
-        self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
-                            nn.BatchNorm2d(out_channel), sig)
-
-        self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
-                            nn.BatchNorm2d(out_channel), sig)
-
-        self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True))
+        self.mask_head = ResBlk(in_channel, 1, normalize=norm)
+        # self.forget_gate = ResBlk(in_channel,in_channel, normalize=norm)
 
-    def forward(self, encoder_in, bottleneck_in):
-        h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel`
-        h_bar_l   = self.conv11(h_hat_l_1)
-        f_l       = self.forget_gate(h_hat_l_1)
-        r_l       = self.reset_gate (h_hat_l_1)
-        h_hat_l   = (1-f_l)*h_bar_l + f_l* encoder_in
-        x_hat_l   = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1
-        return x_hat_l
\ No newline at end of file
+    def forward(self, encoder_in, decoder_in):
+        mask = self.sig(self.mask_head(decoder_in))  # predict a single-channel blending mask from the decoder features
+        # enc_feat= self.forget_gate(encoder_in)
+        out  = (1-mask)*encoder_in + mask * decoder_in
+        return out, mask
\ No newline at end of file
diff --git a/components/LSTU_Config.py b/components/LSTU_Config.py
new file mode 100644
index 0000000..05dff5c
--- /dev/null
+++ b/components/LSTU_Config.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: Generator.py
+# Created Date: Sunday January 16th 2022
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Tuesday, 29th March 2022 12:20:26 pm
+# Modified By: Chen Xuanhong
+# Copyright (c) 2022 Shanghai Jiao Tong University
+#############################################################
+
+import math
+import torch
+from torch import nn
+
+import torch.nn.functional as F
+
+# class LSTU(nn.Module):
+#     def __init__(
+#         self,
+#         in_channel,
+#         out_channel,
+#         latent_channel,
+#         scale = 4
+#     ):
+#         super().__init__()
+#         sig       = nn.Sigmoid()
+#         self.relu = nn.ReLU(True)
+
+#         self.up_sample = nn.Sequential(nn.Conv2d(latent_channel, out_channel/4, kernel_size=3, stride=1, padding=1, bias=False),
+#                             nn.BatchNorm2d(out_channel/4),
+#                             self.relu,
+#                             nn.Conv2d(latent_channel/4, out_channel, kernel_size=3, stride=1, padding=1),
+#                             )
+
+#         self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
+#                             nn.BatchNorm2d(out_channel), sig)
+
+#         self.reset_gate = nn.Sequential(nn.Conv2d(in_channel,
out_channel, kernel_size=3, padding=1, bias=False),
+#                             nn.BatchNorm2d(out_channel), sig)
+
+#         self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True))
+
+#     def forward(self, encoder_in, bottleneck_in):
+#         h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel`
+#         h_bar_l   = self.conv11(h_hat_l_1)
+#         f_l       = self.forget_gate(h_hat_l_1)
+#         r_l       = self.reset_gate (h_hat_l_1)
+#         h_hat_l   = (1-f_l)*h_bar_l + f_l* encoder_in
+#         x_hat_l   = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1
+#         return x_hat_l
+
+
+class ResBlk(nn.Module):
+    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+                 normalize="in", downsample=False):
+        super().__init__()
+        self.actv = actv
+        self.normalize = normalize
+        self.downsample = downsample
+        self.learned_sc = dim_in != dim_out
+        self.equal_var = math.sqrt(2)
+        self._build_weights(dim_in, dim_out)
+
+    def _build_weights(self, dim_in, dim_out):
+        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
+        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
+        if self.normalize.lower() == "in":
+            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
+            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
+        elif self.normalize.lower() == "bn":
+            self.norm1 = nn.BatchNorm2d(dim_in)
+            self.norm2 = nn.BatchNorm2d(dim_in)
+
+        if self.learned_sc:
+            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
+
+    def _shortcut(self, x):
+        if self.learned_sc:
+            x = self.conv1x1(x)
+        if self.downsample:
+            x = F.avg_pool2d(x, 2)
+        return x
+
+    def _residual(self, x):
+        if self.normalize:
+            x = self.norm1(x)
+        x = self.actv(x)
+        x = self.conv1(x)
+        if self.downsample:
+            x = F.avg_pool2d(x, 2)
+        if self.normalize:
+            x = self.norm2(x)
+        x = self.actv(x)
+        x = self.conv2(x)
+        return x
+
+    def forward(self, x):
+        x = self._shortcut(x) + self._residual(x)
+        return x / self.equal_var  # unit variance
+
+class LSTU(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        norm
+    ):
+        super().__init__()
+
+        # self.mask_head = ResBlk(in_channel, 1, normalize=norm)
+        self.mask_head = nn.Sequential(nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
+                            nn.BatchNorm2d(in_channel//2),
+                            nn.LeakyReLU(0.2),
+                            nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
+                            nn.Sigmoid()
+                            )
+        # self.forget_gate = ResBlk(in_channel,in_channel, normalize=norm)
+
+    def forward(self, encoder_in, decoder_in):
+        mask = self.mask_head(decoder_in)  # predict a single-channel blending mask from the decoder features
+        # enc_feat= self.forget_gate(encoder_in)
+        out  = (1-mask)*encoder_in + mask * decoder_in
+        return out, mask
\ No newline at end of file
diff --git a/components/ModulatedDWConv.py b/components/ModulatedDWConv.py
new file mode 100644
index 0000000..9461cba
--- /dev/null
+++ b/components/ModulatedDWConv.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: ModulatedDWConv.py
+# Created Date: Monday April 18th 2022
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Monday, 18th April 2022 10:33:48 am
+# Modified By: Chen Xuanhong
+# Modified from: https://github.com/bes-dev/MobileStyleGAN.pytorch
+# Copyright (c) 2022 Shanghai Jiao Tong University
+#############################################################
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class ModulatedDWConv2d(nn.Module):
+    def __init__(
+        self,
+        channels_in,
+        channels_out,
+        style_dim,
+
kernel_size, + demodulate=True + ): + super().__init__() + # create conv + self.weight_dw = nn.Parameter( + torch.randn(channels_in, 1, kernel_size, kernel_size) + ) + self.weight_permute = nn.Parameter( + torch.randn(channels_out, channels_in, 1, 1) + ) + # create modulation network + self.modulation = nn.Linear(style_dim, channels_in, bias=True) + self.modulation.bias.data.fill_(1.0) + # create demodulation parameters + self.demodulate = demodulate + if self.demodulate: + self.register_buffer("style_inv", torch.randn(1, 1, channels_in, 1, 1)) + # some service staff + self.scale = 1.0 / math.sqrt(channels_in * kernel_size ** 2) + self.padding = kernel_size // 2 + + def forward(self, x, style): + modulation = self.get_modulation(style) + x = modulation * x + x = F.conv2d(x, self.weight_dw, padding=self.padding, groups=x.size(1)) + x = F.conv2d(x, self.weight_permute) + if self.demodulate: + demodulation = self.get_demodulation(style) + x = demodulation * x + return x + + def get_modulation(self, style): + style = self.modulation(style).view(style.size(0), -1, 1, 1) + modulation = self.scale * style + return modulation + + def get_demodulation(self, style): + w = (self.weight_dw.transpose(0, 1) * self.weight_permute).unsqueeze(0) + norm = torch.rsqrt((self.scale * self.style_inv * w).pow(2).sum([2, 3, 4]) + 1e-8) + demodulation = norm + return demodulation.view(*demodulation.size(), 1, 1) \ No newline at end of file diff --git a/components/Nonstau_Discriminator.py b/components/Nonstau_Discriminator.py new file mode 100644 index 0000000..638d3dd --- /dev/null +++ b/components/Nonstau_Discriminator.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Nonstau_Discriminator.py +# Created Date: Monday March 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Monday, 28th March 2022 10:03:56 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + elif self.normalize.lower() == "none": + self.normalize = False + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x / self.var # 
unit variance + +class Discriminator(torch.nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + img_size = kwargs["img_size"] + num_domains = 1 + max_conv_dim = kwargs["max_conv_dim"] + norm = kwargs["norm"] + dim_in = 2**14 // img_size + blocks = [] + blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)] + + repeat_num = int(np.log2(img_size)) - 2 + for _ in range(repeat_num): + dim_out = min(dim_in*2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, normalize=norm, downsample=True)] + dim_in = dim_out + + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)] + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)] + self.main = nn.Sequential(*blocks) + + def forward(self, x): + out = self.main(x) + out = out.view(out.size(0), -1) # (batch, num_domains) + return out \ No newline at end of file diff --git a/components/Nonstau_Discriminator_FM.py b/components/Nonstau_Discriminator_FM.py new file mode 100644 index 0000000..8a49c00 --- /dev/null +++ b/components/Nonstau_Discriminator_FM.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: Nonstau_Discriminator.py +# Created Date: Monday March 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 13th April 2022 3:11:40 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize="in", downsample=False): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self.var = math.sqrt(2) + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize.lower() == "in": + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + elif self.normalize.lower() == "bn": + self.norm1 = nn.BatchNorm2d(dim_in) + self.norm2 = nn.BatchNorm2d(dim_in) + elif self.normalize.lower() == "none": + self.normalize = False + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x / self.var # unit variance + +class Discriminator(torch.nn.Module): + def __init__( + self, + **kwargs + ): + super().__init__() + img_size = kwargs["img_size"] + num_domains = 1 + max_conv_dim = kwargs["max_conv_dim"] + norm = kwargs["norm"].lower() + dim_in = 2**14 // img_size + blocks = [] + blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)] + + repeat_num = int(np.log2(img_size)) - 2 + for _ in range(repeat_num-2): + dim_out = min(dim_in*2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, normalize=norm, downsample=True)] + dim_in = dim_out + blocks1 = [] + for 
_ in range(2): # 16 + dim_out = min(dim_in*2, max_conv_dim) + blocks1 += [ResBlk(dim_in, dim_out, normalize=norm, downsample=True)] + dim_in = dim_out + + blocks1 += [nn.LeakyReLU(0.2)] + blocks1 += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)] + blocks1 += [nn.LeakyReLU(0.2)] + blocks1 += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)] + self.main = nn.Sequential(*blocks) + self.tail = nn.Sequential(*blocks1) + + def get_feature(self,x): + mid = self.main(x) + return mid + + def forward(self, x): + mid = self.main(x) + out = self.tail(mid) + out = out.view(out.size(0), -1) # (batch, num_domains) + return out,mid diff --git a/data_tools/data_loader_FFHQ_multigpu.py b/data_tools/data_loader_FFHQ_multigpu.py new file mode 100644 index 0000000..d8e4570 --- /dev/null +++ b/data_tools/data_loader_FFHQ_multigpu.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: data_loader_VGGFace2HQ copy.py +# Created Date: Sunday February 6th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Tuesday, 15th February 2022 1:50:19 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + + +import os +import glob +import torch +import random +import numpy as np +from PIL import Image +from torch.utils import data +from torchvision import transforms as T +# from StyleResize import StyleResize + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +class data_prefetcher(): + def __init__(self, loader, cur_gpu): + torch.cuda.set_device(cur_gpu) # must add this line to avoid excessive use of GPU 0 by the prefetcher + self.loader = loader + self.dataiter = iter(loader) + self.stream = torch.cuda.Stream(device=cur_gpu) + self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1) + self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1) + self.cur_gpu = cur_gpu + # With Amp, it isn't necessary to manually convert data to half. 
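+        # the commented-out fp16 casts below appear to be kept for reference:
+        # with Amp, autocast handles dtype conversion (see the note above);
+        # mean/std are the standard ImageNet statistics, and preload() applies
+        # them on-GPU inside the side stream so normalization overlaps compute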
+ # if args.fp16: + # self.mean = self.mean.half() + # self.std = self.std.half() + # self.num_images = loader.__len__() + self.preload() + + def preload(self): + # try: + self.src_image1, self.src_image2 = next(self.dataiter) + # except StopIteration: + # self.dataiter = iter(self.loader) + # self.src_image1, self.src_image2 = next(self.dataiter) + + with torch.cuda.stream(self.stream): + self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True) + self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std) + self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True) + self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std) + # With Amp, it isn't necessary to manually convert data to half. + # if args.fp16: + # self.next_input = self.next_input.half() + # else: + # self.next_input = self.next_input.float() + # self.next_input = self.next_input.sub_(self.mean).div_(self.std) + def next(self): + torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream) + src_image1 = self.src_image1 + src_image2 = self.src_image2 + self.preload() + return src_image1, src_image2 + + # def __len__(self): + # """Return the number of images.""" + # return self.num_images + +class VGGFace2HQDataset(data.Dataset): + """Dataset class for the Artworks dataset and content dataset.""" + + def __init__(self, + image_dir, + img_transform, + subffix='jpg', + random_seed=1234): + """Initialize and preprocess the VGGFace2 HQ dataset.""" + self.image_dir = image_dir + self.img_transform = img_transform + self.subffix = subffix + self.dataset = [] + self.random_seed = random_seed + self.preprocess() + self.num_images = len(self.dataset) + + def preprocess(self): + """Preprocess the VGGFace2 HQ dataset.""" + print("processing VGGFace2 HQ dataset images...") + + temp_path = os.path.join(self.image_dir,'*/') + pathes = glob.glob(temp_path) + self.dataset = [] + for dir_item in pathes: + join_path = glob.glob(os.path.join(dir_item,'*.jpg')) + print("processing %s"%dir_item,end='\r') + temp_list = [] + for item in join_path: + temp_list.append(item) + self.dataset.append(temp_list) + random.seed(self.random_seed) + random.shuffle(self.dataset) + print('Finished preprocessing the VGGFace2 HQ dataset, total dirs number: %d...'%len(self.dataset)) + + def __getitem__(self, index): + """Return two src domain images and two dst domain images.""" + dir_tmp1 = self.dataset[index] + dir_tmp1_len = len(dir_tmp1) + + filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)] + filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)] + image1 = self.img_transform(Image.open(filename1)) + image2 = self.img_transform(Image.open(filename2)) + return image1, image2 + + def __len__(self): + """Return the number of images.""" + return self.num_images + +def GetLoader( dataset_roots, + rank, + num_gpus, + batch_size=16, + **kwargs + ): + """Build and return a data loader.""" + + data_root = dataset_roots + random_seed = kwargs["random_seed"] + num_workers = kwargs["dataloader_workers"] + + c_transforms = [] + + c_transforms.append(T.ToTensor()) + c_transforms = T.Compose(c_transforms) + + content_dataset = VGGFace2HQDataset( + data_root, + c_transforms, + "jpg", + random_seed) + device = torch.device('cuda', rank) + sampler = InfiniteSampler(dataset=content_dataset, rank=rank, num_replicas=num_gpus, seed=random_seed) + content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size, + drop_last=False,shuffle=False,num_workers=num_workers,pin_memory=True, 
sampler=sampler) + # content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size, + # drop_last=False,shuffle=True,num_workers=num_workers,pin_memory=True) + prefetcher = data_prefetcher(content_data_loader,device) + return prefetcher + +def denorm(x): + out = (x + 1) / 2 + return out.clamp_(0, 1) \ No newline at end of file diff --git a/data_tools/data_loader_VGGFace2HQ_multigpu.py b/data_tools/data_loader_VGGFace2HQ_multigpu.py index d8e4570..6dcb269 100644 --- a/data_tools/data_loader_VGGFace2HQ_multigpu.py +++ b/data_tools/data_loader_VGGFace2HQ_multigpu.py @@ -5,7 +5,7 @@ # Created Date: Sunday February 6th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 15th February 2022 1:50:19 am +# Last Modified: Wednesday, 6th April 2022 12:53:53 am # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -108,7 +108,7 @@ class VGGFace2HQDataset(data.Dataset): subffix='jpg', random_seed=1234): """Initialize and preprocess the VGGFace2 HQ dataset.""" - self.image_dir = image_dir + self.image_dir = image_dir["images"] self.img_transform = img_transform self.subffix = subffix self.dataset = [] diff --git a/data_tools/data_loader_VGGFace2HQ_multigpu_w_mask.py b/data_tools/data_loader_VGGFace2HQ_multigpu_w_mask.py new file mode 100644 index 0000000..af1ce98 --- /dev/null +++ b/data_tools/data_loader_VGGFace2HQ_multigpu_w_mask.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: data_loader_VGGFace2HQ copy.py +# Created Date: Sunday February 6th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 3rd April 2022 9:48:23 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + + +import os +import glob +import torch +import random +import numpy as np +from PIL import Image +from torch.utils import data +from torchvision import transforms as T +import cv2 +# from StyleResize import StyleResize + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +class data_prefetcher(): + def __init__(self, loader, cur_gpu): + torch.cuda.set_device(cur_gpu) # must add this line to avoid excessive use of GPU 0 by the prefetcher + self.loader = loader + self.dataiter = iter(loader) + self.stream = torch.cuda.Stream(device=cur_gpu) + self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1) + self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1) + self.cur_gpu = 
cur_gpu + # With Amp, it isn't necessary to manually convert data to half. + # if args.fp16: + # self.mean = self.mean.half() + # self.std = self.std.half() + # self.num_images = loader.__len__() + self.preload() + + def preload(self): + # try: + self.src_image1, self.src_image2, self.mask = next(self.dataiter) + # except StopIteration: + # self.dataiter = iter(self.loader) + # self.src_image1, self.src_image2 = next(self.dataiter) + + with torch.cuda.stream(self.stream): + self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True) + self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std) + self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True) + self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std) + self.mask = self.mask.cuda(device= self.cur_gpu, non_blocking=True) + self.mask = self.mask/255.0 + # With Amp, it isn't necessary to manually convert data to half. + # if args.fp16: + # self.next_input = self.next_input.half() + # else: + # self.next_input = self.next_input.float() + # self.next_input = self.next_input.sub_(self.mean).div_(self.std) + def next(self): + torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream) + src_image1 = self.src_image1 + src_image2 = self.src_image2 + mask = self.mask + self.preload() + return src_image1, src_image2, mask + + # def __len__(self): + # """Return the number of images.""" + # return self.num_images + +class VGGFace2HQDataset(data.Dataset): + """Dataset class for the Artworks dataset and content dataset.""" + + def __init__(self, + image_dir, + mask_dir, + img_transform, + subffix='jpg', + random_seed=1234): + """Initialize and preprocess the VGGFace2 HQ dataset.""" + self.image_dir = image_dir + self.mask_dir = mask_dir + self.img_transform = img_transform + self.subffix = subffix + self.dataset = [] + self.random_seed = random_seed + self.preprocess() + self.num_images = len(self.dataset) + + def preprocess(self): + """Preprocess the VGGFace2 HQ dataset.""" + print("processing VGGFace2 HQ dataset images...") + + temp_path = os.path.join(self.image_dir,'*/') + pathes = glob.glob(temp_path) + self.dataset = [] + for dir_item in pathes: + join_path = glob.glob(os.path.join(dir_item,'*.jpg')) + print("processing %s"%dir_item,end='\r') + dir_path = os.path.dirname(join_path[1]) + dir_name = os.path.join(self.mask_dir, os.path.basename(dir_path)) + # print(dir_name) + temp_list = [] + for item in join_path: + img_name = os.path.basename(item) + img_name, _ = os.path.splitext(img_name) + temp_list.append({ + "i":item, + "m":os.path.join(dir_name, img_name + ".png") + }) + self.dataset.append(temp_list) + random.seed(self.random_seed) + random.shuffle(self.dataset) + print('Finished preprocessing the VGGFace2 HQ dataset, total dirs number: %d...'%len(self.dataset)) + + def __getitem__(self, index): + """Return two src domain images and two dst domain images.""" + dir_tmp1 = self.dataset[index] + dir_tmp1_len = len(dir_tmp1) + + filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)] + filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)] + image1 = self.img_transform(Image.open(filename1["i"])) + image2 = self.img_transform(Image.open(filename2["i"])) + mask = torch.from_numpy(cv2.imread(filename1["m"],0)).unsqueeze(0) + return image1, image2, mask + + def __len__(self): + """Return the number of images.""" + return self.num_images + +def GetLoader( dataset_roots, + rank, + num_gpus, + batch_size=16, + **kwargs + ): + """Build and return a data loader.""" + + 
data_root = dataset_roots["images"]
+    mask_root = dataset_roots["masks"]
+    random_seed = kwargs["random_seed"]
+    num_workers = kwargs["dataloader_workers"]
+
+    c_transforms = []
+
+    c_transforms.append(T.ToTensor())
+    c_transforms = T.Compose(c_transforms)
+
+    content_dataset = VGGFace2HQDataset(
+                        data_root,
+                        mask_root,
+                        c_transforms,
+                        "jpg",
+                        random_seed)
+    device = torch.device('cuda', rank)
+    sampler = InfiniteSampler(dataset=content_dataset, rank=rank, num_replicas=num_gpus, seed=random_seed)
+    content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
+                    drop_last=False,shuffle=False,num_workers=num_workers,pin_memory=True, sampler=sampler)
+    # content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
+    #                 drop_last=False,shuffle=True,num_workers=num_workers,pin_memory=True)
+    prefetcher = data_prefetcher(content_data_loader,device)
+    return prefetcher
+
+def denorm(x):
+    out = (x + 1) / 2
+    return out.clamp_(0, 1)
\ No newline at end of file
diff --git a/dataset.check.py b/dataset.check.py
new file mode 100644
index 0000000..e69de45
--- /dev/null
+++ b/dataset.check.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: dataset.check.py
+# Created Date: Sunday April 3rd 2022
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Sunday, 3rd April 2022 2:57:48 am
+# Modified By: Chen Xuanhong
+# Copyright (c) 2022 Shanghai Jiao Tong University
+#############################################################
+
+import os
+import glob
+from utilities.json_config import readConfig, writeConfig
+
+# dataset = "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan"
+# mask_dir= "G:/VGGFace2-HQ/VGGface2_HQ_original_aligned_mask"
+
+savePath = "./vggface2hq_failed.txt"  # declared for a failure log; misses are currently only printed below
+env_config = readConfig('env/env.json')
+env_config = env_config["path"]
+dataset = env_config["dataset_paths"]["vggface2_hq"]["images"]
+mask_dir = env_config["dataset_paths"]["vggface2_hq"]["masks"]
+
+temp_path = os.path.join(dataset,'*/')
+pathes = glob.glob(temp_path)
+for dir_item in pathes:
+    join_path = glob.glob(os.path.join(dir_item,'*.jpg'))
+    print("processing %s"%dir_item,end='\r')
+    dir_path = os.path.dirname(join_path[0])  # [0], not [1]: indexing [1] crashes on single-image dirs
+    dir_name = os.path.join(mask_dir, os.path.basename(dir_path))
+    # print(dir_name)
+    temp_list = []
+    for item in join_path:
+        img_name = os.path.basename(item)
+        img_name, _ = os.path.splitext(img_name)
+        mask_name = os.path.join(dir_name, img_name + ".png")
+        if not os.path.exists(mask_name):
+            print(mask_name)
\ No newline at end of file
diff --git a/dataset_readme.txt b/dataset_readme.txt
new file mode 100644
index 0000000..c2cc6db
--- /dev/null
+++ b/dataset_readme.txt
@@ -0,0 +1,113 @@
+[113 lines of mojibake omitted: this file records "source --> target" folder
+mappings from G:/4K/... photo collections into numbered
+H:/face_data/VGGFace2_HQ/<NN> directories (the same log that face_crop.py's
+crop_task appends to); the original Chinese folder names in the source paths
+are unrecoverable due to encoding loss.]
diff --git a/face_crop.py b/face_crop.py
index 050f313..eafe538 100644
--- a/face_crop.py
+++ b/face_crop.py
@@ -5,7 +5,7 @@
 # Created Date: Tuesday February 1st 2022
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Wednesday, 2nd February 2022 4:13:28 pm
+# Last Modified: Sunday, 24th April 2022 2:01:47 pm
 # Modified By: Chen Xuanhong
 # Copyright (c) 2022 Shanghai Jiao Tong University
 #############################################################
@@ -26,6 +26,8 @@ import tkinter.ttk as ttk
 import subprocess
 from pathlib import Path
 
+import numpy as np
+
 from insightface_func.face_detect_crop_multi import Face_detect_crop
 
 class TextRedirector(object):
@@ -113,6 +115,7 @@ class Application(tk.Frame):
         label_frame.columnconfigure(0, weight=1)
         label_frame.columnconfigure(1, weight=1)
         label_frame.columnconfigure(2, weight=1)
+        label_frame.columnconfigure(3, weight=1)
 
         tk.Label(label_frame, text="Crop Size:",font=font_list,justify="left")\
             .grid(row=0,column=0,sticky=tk.EW)
@@ -123,6 +126,9 @@
         tk.Label(label_frame, text="Target Format:",font=font_list,justify="left")\
             .grid(row=0,column=2,sticky=tk.EW)
 
+        tk.Label(label_frame, text="Blurry Threshold:",font=font_list,justify="left")\
+            .grid(row=0,column=3,sticky=tk.EW)
+
         #################################################################################################
 
         test_frame = tk.Frame(self.master)
@@ -151,8 +157,10 @@
         self.format_com["value"] = ["png","jpg"]
         self.format_com.current(0)
 
-
-
+        self.thredhold = tkinter.StringVar()
+        tk.Entry(test_frame, textvariable= self.thredhold, font=font_list)\
+            .grid(row=0,column=3,sticky=tk.EW)
+        self.thredhold.set("70")
         #################################################################################################
         scale_frame = tk.Frame(self.master)
         scale_frame.pack(fill="both", padx=5,pady=5)
@@ -200,8 +208,10 @@
 
     def select_task(self):
         path = askdirectory()
-        print("Selected source directory: %s"%path)
-        self.img_path.set(path)
+
+        if os.path.isdir(path):
+            print("Selected source directory: %s"%path)
+            self.img_path.set(path)
 
     def Select_Target(self):
         thread_update = threading.Thread(target=self.select_target_task)
@@ -209,8 +219,9 @@
 
     def select_target_task(self):
         path = askdirectory()
-        print("Selected target directory: %s"%path)
-        self.save_path.set(path)
+        if os.path.isdir(path):
+            print("Selected target directory: %s"%path)
+            self.save_path.set(path)
 
     def Crop(self):
         thread_update = threading.Thread(target=self.crop_task)
@@ -218,39 +229,59 @@
     def crop_task(self):
         mode = self.align_com.get()
+        if mode == "VGGFace":
+            mode = "None"
         crop_size = int(self.test_com.get())
 
         path = self.img_path.get()
         tg_path = self.save_path.get()
+        blur_t = self.thredhold.get()  # gates the Laplacian-variance sharpness check below
+        
basepath = os.path.splitext(os.path.basename(path))[0]
+        tg_path = os.path.join("H:/face_data/VGGFace2_HQ",basepath)
+        print("target path: ",tg_path)
+        if not os.path.exists(tg_path):
+            os.makedirs(tg_path)
         tg_format = self.format_com.get()
         min_scale = float(self.min_scale.get())
-        blur_t = 100.0
-        font = cv2.FONT_HERSHEY_SIMPLEX
+        blur_t = float(blur_t)
+        print("Blurry threshold %f"%blur_t)
         self.detect.prepare(ctx_id = 0, det_thresh=0.6,\
             det_size=(640,640),mode = mode,crop_size=crop_size,ratio=min_scale)
+        log_file = "./dataset_readme.txt"
+        with open(log_file,'a+') as logf: # ,encoding='UTF-8'
+            logf.writelines("%s --> %s\n"%(path,tg_path))
         if path and tg_path:
             imgs_list = []
             if os.path.isdir(path):
                 print("Input a dir....")
-                imgs = glob.glob(os.path.join(path,"*"))
-                for item in imgs:
+                # imgs = glob.glob(os.path.join(path,"**"))
+                for item in glob.iglob(os.path.join(path,"**"),recursive=True):
                     imgs_list.append(item)
                 # print(imgs_list)
                 index = 0
                 for img in imgs_list:
                     print(img)
-                    attr_img_ori= cv2.imread(img)
+                    try:
+                        attr_img_ori = cv2.imdecode(np.fromfile(img, dtype=np.uint8),-1)
+                    except:
+                        print("Unreadable file!")
+                        continue
+                    # attr_img_ori= cv2.imread(img)
                     try:
                         attr_img_align_crop, _ = self.detect.get(attr_img_ori)
                         sub_index = 0
+                        if len(attr_img_align_crop) < 1:
+                            print("Small face")
                         for face_i in attr_img_align_crop:
                             imageVar = cv2.Laplacian(face_i, cv2.CV_64F).var()
                             f_path =os.path.join(tg_path, str(index).zfill(6)+"_%d.%s"%(sub_index,tg_format))
+                            # print("save path: ",f_path)
                             if imageVar < blur_t:
                                 print("Image too blurry, skipped!")
                                 continue
                             # face_i = cv2.putText(face_i, '%.1f'%imageVar,(50, 50), font, 0.8, (15, 9, 255), 2)
-                            cv2.imwrite(f_path,face_i)
+                            # cv2.imwrite(f_path,face_i)
+                            cv2.imencode('.'+tg_format,face_i)[1].tofile(f_path)  # encode in the selected format (was hard-coded '.png')
                             sub_index += 1
                         index += 1
                     except:
diff --git a/face_crop_record.py b/face_crop_record.py
new file mode 100644
index 0000000..5e32aaa
--- /dev/null
+++ b/face_crop_record.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: face_crop.py
+# Created Date: Tuesday February 1st 2022
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Friday, 22nd April 2022 8:43:40 pm
+# Modified By: Chen Xuanhong
+# Copyright (c) 2022 Shanghai Jiao Tong University
+#############################################################
+
+
+import os
+import cv2
+import sys
+import glob
+import json
+import tkinter
+from tkinter.filedialog import askdirectory
+
+import threading
+import tkinter as tk
+import tkinter.ttk as ttk
+
+import subprocess
+from pathlib import Path
+
+from insightface_func.face_detect_crop_multi import Face_detect_crop
+
+class TextRedirector(object):
+    def __init__(self, widget, tag="stdout"):
+        self.widget = widget
+        self.tag = tag
+
+    def write(self, str):
+        self.widget.configure(state="normal")
+        self.widget.insert("end", str, (self.tag,))
+        self.widget.configure(state="disabled")
+        self.widget.see(tk.END)
+
+    def flush(self):
+        pass
+
+#############################################################
+# Main Class
+#############################################################
+
+class Application(tk.Frame):
+
+
+    def __init__(self, master=None):
+        tk.Frame.__init__(self, master,bg='black')
+        # self.font_size = 16
+        self.font_list = ("Times New Roman",14)
+        self.padx = 5
+        self.pady = 5
+        self.window_init()
+
+    def __label_text__(self, usr, root):
+        return "User Name: %s\nWorkspace: %s"%(usr, root)
+
+    def window_init(self):
+        cwd = os.getcwd()
+        
self.master.title('Face Crop - %s'%cwd) + # self.master.iconbitmap('./utilities/_logo.ico') + self.master.geometry("{}x{}".format(640, 600)) + + font_list = self.font_list + + ################################################################################################# + list_frame = tk.Frame(self.master) + list_frame.pack(fill="both", padx=5,pady=5) + list_frame.columnconfigure(0, weight=1) + list_frame.columnconfigure(1, weight=1) + list_frame.columnconfigure(2, weight=1) + + self.img_path = tkinter.StringVar() + + tk.Label(list_frame, text="Image/Video Path:",font=font_list,justify="left")\ + .grid(row=0,column=0,sticky=tk.EW) + + tk.Entry(list_frame, textvariable= self.img_path, font=font_list)\ + .grid(row=0,column=1,sticky=tk.EW) + + + tk.Button(list_frame, text = "Select Path", font=font_list, + command = self.Select, bg='#F4A460', fg='#F5F5F5')\ + .grid(row=0,column=2,sticky=tk.EW) + ################################################################################################# + list_frame1 = tk.Frame(self.master) + list_frame1.pack(fill="both", padx=5,pady=5) + list_frame1.columnconfigure(0, weight=1) + list_frame1.columnconfigure(1, weight=1) + list_frame1.columnconfigure(2, weight=1) + + self.save_path = tkinter.StringVar() + + tk.Label(list_frame1, text="Target Path:",font=font_list,justify="left")\ + .grid(row=0,column=0,sticky=tk.EW) + + tk.Entry(list_frame1, textvariable= self.save_path, font=font_list)\ + .grid(row=0,column=1,sticky=tk.EW) + + + tk.Button(list_frame1, text = "Select Path", font=font_list, + command = self.Select_Target, bg='#F4A460', fg='#F5F5F5')\ + .grid(row=0,column=2,sticky=tk.EW) + + ################################################################################################# + label_frame = tk.Frame(self.master) + label_frame.pack(fill="both", padx=5,pady=5) + label_frame.columnconfigure(0, weight=1) + label_frame.columnconfigure(1, weight=1) + label_frame.columnconfigure(2, weight=1) + + tk.Label(label_frame, text="Crop Size:",font=font_list,justify="left")\ + .grid(row=0,column=0,sticky=tk.EW) + + tk.Label(label_frame, text="Align Mode:",font=font_list,justify="left")\ + .grid(row=0,column=1,sticky=tk.EW) + + tk.Label(label_frame, text="Target Format:",font=font_list,justify="left")\ + .grid(row=0,column=2,sticky=tk.EW) + + ################################################################################################# + + test_frame = tk.Frame(self.master) + test_frame.pack(fill="both", padx=5,pady=5) + test_frame.columnconfigure(0, weight=1) + test_frame.columnconfigure(1, weight=1) + test_frame.columnconfigure(2, weight=1) + + self.test_var = tkinter.StringVar() + + self.test_com = ttk.Combobox(test_frame, textvariable=self.test_var) + self.test_com.grid(row=0,column=0,sticky=tk.EW) + self.test_com["value"] = [256,512,768,1024] + self.test_com.current(1) + + self.align_var = tkinter.StringVar() + self.align_com = ttk.Combobox(test_frame, textvariable=self.align_var) + self.align_com.grid(row=0,column=1,sticky=tk.EW) + self.align_com["value"] = ["VGGFace","ffhq"] + self.align_com.current(0) + + self.format_var = tkinter.StringVar() + + self.format_com = ttk.Combobox(test_frame, textvariable=self.format_var) + self.format_com.grid(row=0,column=2,sticky=tk.EW) + self.format_com["value"] = ["png","jpg"] + self.format_com.current(0) + + + + ################################################################################################# + scale_frame = tk.Frame(self.master) + scale_frame.pack(fill="both", padx=5,pady=5) + 
scale_frame.columnconfigure(0, weight=2)
+        scale_frame.columnconfigure(1, weight=1)  # was label_frame: configuring the wrong frame left the slider unweighted
+        # scale_frame.columnconfigure(2, weight=1)
+
+        tk.Label(scale_frame, text="Min Size:",font=font_list,justify="left")\
+            .grid(row=0,column=0,sticky=tk.EW)
+        self.min_scale = tkinter.StringVar()
+        tk.Scale(scale_frame, from_=0.5, to=2.0, length=500, orient=tk.HORIZONTAL, variable= self.min_scale,\
+            font=font_list, resolution=0.1).grid(row=0,column=1,sticky=tk.EW)
+
+        #################################################################################################
+        test_frame1 = tk.Frame(self.master)
+        test_frame1.pack(fill="both", padx=5,pady=5)
+        test_frame1.columnconfigure(0, weight=1)
+        # test_frame1.columnconfigure(1, weight=1)
+
+        test_update_button = tk.Button(test_frame1, text = "Crop",
+                    font=font_list, command = self.Crop, bg='#F4A460', fg='#F5F5F5')
+        test_update_button.grid(row=0,column=0,sticky=tk.EW)
+
+
+
+        #################################################################################################
+
+        text = tk.Text(self.master, wrap="word")
+        text.pack(fill="both",expand="yes", padx=5,pady=5)
+
+
+        sys.stdout = TextRedirector(text, "stdout")
+
+        self.init_algorithm()
+        self.master.protocol("WM_DELETE_WINDOW", self.on_closing)
+
+    def init_algorithm(self):
+        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
+
+
+    # def __scaning_logs__(self):
+    def Select(self):
+        thread_update = threading.Thread(target=self.select_task)
+        thread_update.start()
+
+    def select_task(self):
+        path = askdirectory()
+        print("Selected source directory: %s"%path)
+        self.img_path.set(path)
+
+    def Select_Target(self):
+        thread_update = threading.Thread(target=self.select_target_task)
+        thread_update.start()
+
+    def select_target_task(self):
+        path = askdirectory()
+        print("Selected target directory: %s"%path)
+        self.save_path.set(path)
+
+    def Crop(self):
+        thread_update = threading.Thread(target=self.crop_task)
+        thread_update.start()
+
+    def crop_task(self):
+        mode = self.align_com.get()
+        crop_size = int(self.test_com.get())
+
+        path = self.img_path.get()
+        tg_path = self.save_path.get()
+        if not os.path.exists(tg_path):
+            os.makedirs(tg_path)
+        tg_format = self.format_com.get()
+        min_scale = float(self.min_scale.get())
+        blur_t = 10.0
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        self.detect.prepare(ctx_id = 0, det_thresh=0.6,\
+            det_size=(640,640),mode = mode,crop_size=crop_size,ratio=min_scale)
+        if path and tg_path:
+            imgs_list = []
+            if os.path.isdir(path):
+                print("Input a dir....")
+                imgs = glob.glob(os.path.join(path,"*"))
+                for item in imgs:
+                    imgs_list.append(item)
+                # print(imgs_list)
+                index = 0
+                for img in imgs_list:
+                    print(img)
+                    attr_img_ori= cv2.imread(img)
+                    try:
+                        attr_img_align_crop, _ = self.detect.get(attr_img_ori)
+                        sub_index = 0
+                        if len(attr_img_align_crop) < 1:
+                            print("Small face")
+                        for face_i in attr_img_align_crop:
+                            imageVar = cv2.Laplacian(face_i, cv2.CV_64F).var()
+                            f_path =os.path.join(tg_path, str(index).zfill(6)+"_%d.%s"%(sub_index,tg_format))
+                            if imageVar < blur_t:
+                                print("Image too blurry, skipped!")
+                                continue
+                            # face_i = cv2.putText(face_i, '%.1f'%imageVar,(50, 50), font, 0.8, (15, 9, 255), 2)
+                            cv2.imwrite(f_path,face_i)
+                            sub_index += 1
+                        index += 1
+                    except:
+                        print("Detect no face!")
+                        continue
+            else:
+                print("Input an image....")
+                imgs_list.append(path)
+            print("Process finished!")
+        else:
+            print("Paths are invalid!")
+
+    def on_closing(self):
+
+        # self.__save_config__()
+        self.master.destroy()
+
+
+
+if __name__ == "__main__":
+    app = 
Application() + app.mainloop() \ No newline at end of file diff --git a/face_crop_video.py b/face_crop_video.py index 763ca0f..b358b73 100644 --- a/face_crop_video.py +++ b/face_crop_video.py @@ -5,7 +5,7 @@ # Created Date: Tuesday February 1st 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Wednesday, 2nd February 2022 11:17:04 pm +# Last Modified: Friday, 15th April 2022 10:07:15 am # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -27,7 +27,7 @@ def getParameters(): parser = argparse.ArgumentParser() parser.add_argument('-p', '--save_path', type=str, default="./output/", help="The root path for saving cropped images") - parser.add_argument('-v', '--video', type=str, default="G:\\4K\\05.mp4", + parser.add_argument('-v', '--video', type=str, default="G:\\4K\\Faith.Makes.Great.2021\\40.mp4", help="The path for input video") parser.add_argument('-c', '--crop_size', type=int, default=512, help="expected image resolution") @@ -39,7 +39,7 @@ def getParameters(): choices=['jpg', 'png'],help="target file format") parser.add_argument('-i', '--interval', type=int, default=20, help="number of frames interval") - parser.add_argument('-b', '--blur', type=float, default=10.0, + parser.add_argument('-b', '--blur', type=float, default=20.0, help="blur degree") return parser.parse_args() diff --git a/face_enhancer/experiments/pretrained_models/README.md b/face_enhancer/experiments/pretrained_models/README.md new file mode 100644 index 0000000..3401a5c --- /dev/null +++ b/face_enhancer/experiments/pretrained_models/README.md @@ -0,0 +1,7 @@ +# Pre-trained Models and Other Data + +Download pre-trained models and other data. Put them in this folder. + +1. [Pretrained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth) +1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth) +1. 
[A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth) diff --git a/face_enhancer/gfpgan/__init__.py b/face_enhancer/gfpgan/__init__.py new file mode 100644 index 0000000..94daaee --- /dev/null +++ b/face_enhancer/gfpgan/__init__.py @@ -0,0 +1,7 @@ +# flake8: noqa +from .archs import * +from .data import * +from .models import * +from .utils import * + +# from .version import * diff --git a/face_enhancer/gfpgan/archs/__init__.py b/face_enhancer/gfpgan/archs/__init__.py new file mode 100644 index 0000000..bec5f17 --- /dev/null +++ b/face_enhancer/gfpgan/archs/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import arch modules for registry +# scan all the files that end with '_arch.py' under the archs folder +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames] diff --git a/face_enhancer/gfpgan/archs/arcface_arch.py b/face_enhancer/gfpgan/archs/arcface_arch.py new file mode 100644 index 0000000..e6d3bd9 --- /dev/null +++ b/face_enhancer/gfpgan/archs/arcface_arch.py @@ -0,0 +1,245 @@ +import torch.nn as nn +from basicsr.utils.registry import ARCH_REGISTRY + + +def conv3x3(inplanes, outplanes, stride=1): + """A simple wrapper for 3x3 convolution with padding. + + Args: + inplanes (int): Channel number of inputs. + outplanes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + """ + return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + """Basic residual block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IRBlock(nn.Module): + """Improved residual block (IR Block) used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
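+
+    Example (illustrative):
+        >>> blk = IRBlock(64, 64)
+        >>> out = blk(torch.randn(2, 64, 32, 32))  # shape preserved: (2, 64, 32, 32)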
+ """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 4 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + """The squeeze-and-excitation block (SEBlock) used in the IRBlock. + + Args: + channel (int): Channel number of inputs. + reduction (int): Channel reduction ration. Default: 16. + """ + + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +@ARCH_REGISTRY.register() +class ResNetArcFace(nn.Module): + """ArcFace with ResNet architectures. + + Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. + + Args: + block (str): Block used in the ArcFace architecture. + layers (tuple(int)): Block numbers in each layer. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
+ """ + + def __init__(self, block, layers, use_se=True): + if block == 'IRBlock': + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1d(512) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + x = self.bn5(x) + + return x diff --git a/face_enhancer/gfpgan/archs/gfpgan_bilinear_arch.py b/face_enhancer/gfpgan/archs/gfpgan_bilinear_arch.py new file mode 100644 index 0000000..52e0de8 --- /dev/null +++ b/face_enhancer/gfpgan/archs/gfpgan_bilinear_arch.py @@ -0,0 +1,312 @@ +import math +import random +import torch +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn + +from .gfpganv1_arch import ResUpBlock +from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, + StyleGAN2GeneratorBilinear) + + +class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for + deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 
+ """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + lr_mlp=0.01, + narrow=1, + sft_half=False): + super(StyleGAN2GeneratorBilinearSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + lr_mlp=lr_mlp, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorBilinearSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +@ARCH_REGISTRY.register() +class GFPGANBilinear(nn.Module): + """The GFPGAN architecture: 
Unet + StyleGAN2 decoder with SFT. + + It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for + deployment. It can be easily converted to the clean version: GFPGANv1Clean. + + + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + + num_mlp (int): Layer number of MLP style layers. Default: 8. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANBilinear, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels)) + in_channels = out_channels + + self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = EqualLinear( + channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT( + out_size=out_size, + num_style_feat=num_style_feat, + 
num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + lr_mlp=lr_mlp, + narrow=narrow, + sft_half=sft_half) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) + self.condition_shift.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0))) + + def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True): + """Forward function for GFPGANBilinear. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = self.conv_body_first(x) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + + feat = self.final_conv(feat) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + + return image, out_rgbs diff --git a/face_enhancer/gfpgan/archs/gfpganv1_arch.py b/face_enhancer/gfpgan/archs/gfpganv1_arch.py new file mode 100644 index 0000000..e092b4f --- /dev/null +++ b/face_enhancer/gfpgan/archs/gfpganv1_arch.py @@ -0,0 +1,439 @@ +import math +import random +import torch +from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, + StyleGAN2Generator) +from basicsr.ops.fused_act import FusedLeakyReLU +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + + +class StyleGAN2GeneratorSFT(StyleGAN2Generator): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + Args: + out_size (int): The spatial size of outputs. 
+ num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be + applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1). + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + resample_kernel=(1, 3, 3, 1), + lr_mlp=0.01, + narrow=1, + sft_half=False): + super(StyleGAN2GeneratorSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + resample_kernel=resample_kernel, + lr_mlp=lr_mlp, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. 
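+
+        Note: when truncation < 1, each style code w is pulled toward the mean
+        latent as w' = truncation_latent + truncation * (w - truncation_latent),
+        i.e. the standard StyleGAN truncation trick implemented below.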
+ """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ConvUpLayer(nn.Module): + """Convolutional upsampling layer. It uses bilinear upsampler + Conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + activate (bool): Whether use activateion. Default: True. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + bias=True, + bias_init_val=0, + activate=True): + super(ConvUpLayer, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + # self.scale is used to scale the convolution weights, which is related to the common initializations. 
+
+
+class ConvUpLayer(nn.Module):
+    """Convolutional upsampling layer. It uses bilinear upsampler + Conv.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+        kernel_size (int): Size of the convolving kernel.
+        stride (int): Stride of the convolution. Default: 1.
+        padding (int): Zero-padding added to both sides of the input. Default: 0.
+        bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
+        bias_init_val (float): Bias initialized value. Default: 0.
+        activate (bool): Whether to use activation. Default: True.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 bias=True,
+                 bias_init_val=0,
+                 activate=True):
+        super(ConvUpLayer, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        # self.scale is used to scale the convolution weights, which is related to the common initializations.
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+
+        if bias and not activate:
+            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+        else:
+            self.register_parameter('bias', None)
+
+        # activation
+        if activate:
+            if bias:
+                self.activation = FusedLeakyReLU(out_channels)
+            else:
+                self.activation = ScaledLeakyReLU(0.2)
+        else:
+            self.activation = None
+
+    def forward(self, x):
+        # bilinear upsample
+        out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+        # conv
+        out = F.conv2d(
+            out,
+            self.weight * self.scale,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+        )
+        # activation
+        if self.activation is not None:
+            out = self.activation(out)
+        return out
+
+
+class ResUpBlock(nn.Module):
+    """Residual block with upsampling.
+
+    Args:
+        in_channels (int): Channel number of the input.
+        out_channels (int): Channel number of the output.
+    """
+
+    def __init__(self, in_channels, out_channels):
+        super(ResUpBlock, self).__init__()
+
+        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
+        self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
+        self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.conv2(out)
+        skip = self.skip(x)
+        out = (out + skip) / math.sqrt(2)
+        return out
+
+
+@ARCH_REGISTRY.register()
+class GFPGANv1(nn.Module):
+    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
+        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross product will be
+            applied to extend the 1D resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+        fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+        input_is_latent (bool): Whether input is latent style. Default: False.
+        different_w (bool): Whether to use different latent w for different layers. Default: False.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+ """ + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANv1, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel)) + in_channels = out_channels + + self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = EqualLinear( + channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + resample_kernel=resample_kernel, + lr_mlp=lr_mlp, + narrow=narrow, + sft_half=sft_half) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) + self.condition_shift.append( + nn.Sequential( + EqualConv2d(out_channels, 
out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
+                    ScaledLeakyReLU(0.2),
+                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
+
+    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
+        """Forward function for GFPGANv1.
+
+        Args:
+            x (Tensor): Input images.
+            return_latents (bool): Whether to return style latents. Default: False.
+            return_rgb (bool): Whether to return intermediate RGB images. Default: True.
+            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
+        """
+        conditions = []
+        unet_skips = []
+        out_rgbs = []
+
+        # encoder
+        feat = self.conv_body_first(x)
+        for i in range(self.log_size - 2):
+            feat = self.conv_body_down[i](feat)
+            unet_skips.insert(0, feat)
+
+        feat = self.final_conv(feat)
+
+        # style code
+        style_code = self.final_linear(feat.view(feat.size(0), -1))
+        if self.different_w:
+            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
+
+        # decode
+        for i in range(self.log_size - 2):
+            # add unet skip
+            feat = feat + unet_skips[i]
+            # ResUpLayer
+            feat = self.conv_body_up[i](feat)
+            # generate scale and shift for SFT layers
+            scale = self.condition_scale[i](feat)
+            conditions.append(scale.clone())
+            shift = self.condition_shift[i](feat)
+            conditions.append(shift.clone())
+            # generate rgb images
+            if return_rgb:
+                out_rgbs.append(self.toRGB[i](feat))
+
+        # decoder
+        image, _ = self.stylegan_decoder([style_code],
+                                         conditions,
+                                         return_latents=return_latents,
+                                         input_is_latent=self.input_is_latent,
+                                         randomize_noise=randomize_noise)
+
+        return image, out_rgbs
+
+
+@ARCH_REGISTRY.register()
+class FacialComponentDiscriminator(nn.Module):
+    """Facial component (eyes, mouth, nose) discriminator used in GFPGAN.
+    """
+
+    def __init__(self):
+        super(FacialComponentDiscriminator, self).__init__()
+        # It now uses a VGG-style architecture with fixed model size
+        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
+        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
+
+    def forward(self, x, return_feats=False):
+        """Forward function for FacialComponentDiscriminator.
+
+        Args:
+            x (Tensor): Input images.
+            return_feats (bool): Whether to return intermediate features. Default: False.
+        """
+        feat = self.conv1(x)
+        feat = self.conv3(self.conv2(feat))
+        rlt_feats = []
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        feat = self.conv5(self.conv4(feat))
+        if return_feats:
+            rlt_feats.append(feat.clone())
+        out = self.final_conv(feat)
+
+        if return_feats:
+            return out, rlt_feats
+        else:
+            return out, None
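Assuming basicsr (with its compiled fused-activation/upfirdn CUDA ops) is importable, a forward smoke test of the two networks above might look like the following; the argument values are illustrative, not prescribed by this patch:

import torch

net = GFPGANv1(out_size=512, channel_multiplier=1, fix_decoder=False,
               different_w=True, sft_half=True)
x = torch.randn(1, 3, 512, 512)
image, out_rgbs = net(x, return_rgb=True)
print(image.shape, len(out_rgbs))  # torch.Size([1, 3, 512, 512]) 7

d = FacialComponentDiscriminator()
crop = torch.randn(1, 3, 64, 64)  # e.g. an eye/mouth ROI
score, feats = d(crop, return_feats=True)
print(score.shape)  # torch.Size([1, 1, 16, 16])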
+ """ + feat = self.conv1(x) + feat = self.conv3(self.conv2(feat)) + rlt_feats = [] + if return_feats: + rlt_feats.append(feat.clone()) + feat = self.conv5(self.conv4(feat)) + if return_feats: + rlt_feats.append(feat.clone()) + out = self.final_conv(feat) + + if return_feats: + return out, rlt_feats + else: + return out, None diff --git a/face_enhancer/gfpgan/archs/gfpganv1_clean_arch.py b/face_enhancer/gfpgan/archs/gfpganv1_clean_arch.py new file mode 100644 index 0000000..eb2e15d --- /dev/null +++ b/face_enhancer/gfpgan/archs/gfpganv1_clean_arch.py @@ -0,0 +1,324 @@ +import math +import random +import torch +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + +from .stylegan2_clean_arch import StyleGAN2GeneratorClean + + +class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False): + super(StyleGAN2GeneratorCSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorCSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. 
+ """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ResBlock(nn.Module): + """Residual block with bilinear upsampling/downsampling. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + mode (str): Upsampling/downsampling mode. Options: down | up. Default: down. + """ + + def __init__(self, in_channels, out_channels, mode='down'): + super(ResBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) + if mode == 'down': + self.scale_factor = 0.5 + elif mode == 'up': + self.scale_factor = 2 + + def forward(self, x): + out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) + # upsample/downsample + out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) + out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) + # skip + x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) + skip = self.skip(x) + out = out + skip + return out + + +@ARCH_REGISTRY.register() +class GFPGANv1Clean(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. 
+
+
+@ARCH_REGISTRY.register()
+class GFPGANv1Clean(nn.Module):
+    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
+
+    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+    Args:
+        out_size (int): The spatial size of outputs.
+        num_style_feat (int): Channel number of style features. Default: 512.
+        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
+        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+        fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+        num_mlp (int): Layer number of MLP style layers. Default: 8.
+        input_is_latent (bool): Whether input is latent style. Default: False.
+        different_w (bool): Whether to use different latent w for different layers. Default: False.
+        narrow (float): The narrow ratio for channels. Default: 1.
+        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+    """
+
+    def __init__(
+            self,
+            out_size,
+            num_style_feat=512,
+            channel_multiplier=1,
+            decoder_load_path=None,
+            fix_decoder=True,
+            # for stylegan decoder
+            num_mlp=8,
+            input_is_latent=False,
+            different_w=False,
+            narrow=1,
+            sft_half=False):
+
+        super(GFPGANv1Clean, self).__init__()
+        self.input_is_latent = input_is_latent
+        self.different_w = different_w
+        self.num_style_feat = num_style_feat
+
+        unet_narrow = narrow * 0.5  # by default, use half of the input channels
+        channels = {
+            '4': int(512 * unet_narrow),
+            '8': int(512 * unet_narrow),
+            '16': int(512 * unet_narrow),
+            '32': int(512 * unet_narrow),
+            '64': int(256 * channel_multiplier * unet_narrow),
+            '128': int(128 * channel_multiplier * unet_narrow),
+            '256': int(64 * channel_multiplier * unet_narrow),
+            '512': int(32 * channel_multiplier * unet_narrow),
+            '1024': int(16 * channel_multiplier * unet_narrow)
+        }
+
+        self.log_size = int(math.log(out_size, 2))
+        first_out_size = 2**(int(math.log(out_size, 2)))
+
+        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
+
+        # downsample
+        in_channels = channels[f'{first_out_size}']
+        self.conv_body_down = nn.ModuleList()
+        for i in range(self.log_size, 2, -1):
+            out_channels = channels[f'{2**(i - 1)}']
+            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
+            in_channels = out_channels
+
+        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
+
+        # upsample
+        in_channels = channels['4']
+        self.conv_body_up = nn.ModuleList()
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2**i}']
+            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
+            in_channels = out_channels
+
+        # to RGB
+        self.toRGB = nn.ModuleList()
+        for i in range(3, self.log_size + 1):
+            self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))
+
+        if different_w:
+            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
+        else:
+            linear_out_channel = num_style_feat
+
+        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
+
+        # the decoder: stylegan2 generator with SFT modulations
+        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
+            out_size=out_size,
+            num_style_feat=num_style_feat,
+            num_mlp=num_mlp,
+            channel_multiplier=channel_multiplier,
+            narrow=narrow,
+            sft_half=sft_half)
+
+        # load pre-trained stylegan2 model if necessary
+        if decoder_load_path:
+            self.stylegan_decoder.load_state_dict(
+                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
+        # fix decoder without updating params
+        if fix_decoder:
+            for _, param in self.stylegan_decoder.named_parameters():
+                param.requires_grad = False
+
+        # for SFT modulations (scale and shift)
+        self.condition_scale = nn.ModuleList()
+        self.condition_shift = nn.ModuleList()
+        for i in range(3, self.log_size + 1):
+            out_channels = channels[f'{2**i}']
+            if sft_half:
+                sft_out_channels = out_channels
+            else:
+                sft_out_channels = out_channels * 2
+            self.condition_scale.append(
+                nn.Sequential(
+                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
+                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
+            self.condition_shift.append(
+                nn.Sequential(
+                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
+                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
+
+    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
+        """Forward function for GFPGANv1Clean.
+
+        Args:
+            x (Tensor): Input images.
+            return_latents (bool): Whether to return style latents. Default: False.
+            return_rgb (bool): Whether to return intermediate RGB images. Default: True.
+            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
+        """
+        conditions = []
+        unet_skips = []
+        out_rgbs = []
+
+        # encoder
+        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
+        for i in range(self.log_size - 2):
+            feat = self.conv_body_down[i](feat)
+            unet_skips.insert(0, feat)
+        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
+
+        # style code
+        style_code = self.final_linear(feat.view(feat.size(0), -1))
+        if self.different_w:
+            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
+
+        # decode
+        for i in range(self.log_size - 2):
+            # add unet skip
+            feat = feat + unet_skips[i]
+            # ResUpLayer
+            feat = self.conv_body_up[i](feat)
+            # generate scale and shift for SFT layers
+            scale = self.condition_scale[i](feat)
+            conditions.append(scale.clone())
+            shift = self.condition_shift[i](feat)
+            conditions.append(shift.clone())
+            # generate rgb images
+            if return_rgb:
+                out_rgbs.append(self.toRGB[i](feat))
+
+        # decoder
+        image, _ = self.stylegan_decoder([style_code],
+                                         conditions,
+                                         return_latents=return_latents,
+                                         input_is_latent=self.input_is_latent,
+                                         randomize_noise=randomize_noise)
+
+        return image, out_rgbs
diff --git a/face_enhancer/gfpgan/archs/stylegan2_bilinear_arch.py b/face_enhancer/gfpgan/archs/stylegan2_bilinear_arch.py
new file mode 100644
index 0000000..1342ee3
--- /dev/null
+++ b/face_enhancer/gfpgan/archs/stylegan2_bilinear_arch.py
@@ -0,0 +1,613 @@
+import math
+import random
+import torch
+from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
+from basicsr.utils.registry import ARCH_REGISTRY
+from torch import nn
+from torch.nn import functional as F
+
+
+class NormStyleCode(nn.Module):
+
+    def forward(self, x):
+        """Normalize the style codes.
+
+        Args:
+            x (Tensor): Style codes with shape (b, c).
+
+        Returns:
+            Tensor: Normalized tensor.
+        """
+        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
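NormStyleCode is StyleGAN's PixelNorm applied to style vectors: each code is rescaled to unit RMS. Numerically:

import torch

x = torch.randn(4, 512)
y = x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)
print(y.pow(2).mean(dim=1))  # ~1.0 for every row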
+ """ + + def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None): + super(EqualLinear, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lr_mul = lr_mul + self.activation = activation + if self.activation not in ['fused_lrelu', None]: + raise ValueError(f'Wrong activation value in EqualLinear: {activation}' + "Supported ones are: ['fused_lrelu', None].") + self.scale = (1 / math.sqrt(in_channels)) * lr_mul + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + if self.bias is None: + bias = None + else: + bias = self.bias * self.lr_mul + if self.activation == 'fused_lrelu': + out = F.linear(x, self.weight * self.scale) + out = fused_leaky_relu(out, bias) + else: + out = F.linear(x, self.weight * self.scale, bias=bias) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, bias={self.bias is not None})') + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. + Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + eps (float): A value added to the denominator for numerical stability. + Default: 1e-8. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8, + interpolation_mode='bilinear'): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + self.interpolation_mode = interpolation_mode + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + # modulation inside each modulated conv + self.modulation = EqualLinear( + num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None) + + self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.scale * self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + if self.sample_mode == 'upsample': + x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners) + elif self.sample_mode == 'downsample': + x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, ' + f'demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + """Style conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode='bilinear'): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + interpolation_mode=interpolation_mode) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.activate = FusedLeakyReLU(out_channels) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # activation (with bias) + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'): + super(ToRGB, self).__init__() + self.upsample = upsample + self.interpolation_mode = interpolation_mode + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + self.modulated_conv = ModulatedConv2d( + in_channels, + 3, + kernel_size=1, + num_style_feat=num_style_feat, + demodulate=False, + sample_mode=None, + interpolation_mode=interpolation_mode) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. 
+ + Returns: + Tensor: RGB images. + """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate( + skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2GeneratorBilinear(nn.Module): + """StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + lr_mlp=0.01, + narrow=1, + interpolation_mode='bilinear'): + super(StyleGAN2GeneratorBilinear, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.append( + EqualLinear( + num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp, + activation='fused_lrelu')) + self.style_mlp = nn.Sequential(*style_mlp_layers) + + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv( + channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode=interpolation_mode) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels['4'] + # noise + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample', + interpolation_mode=interpolation_mode)) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode=interpolation_mode)) + self.to_rgbs.append( + 
ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
+            in_channels = out_channels
+
+    def make_noise(self):
+        """Make noise for noise injection."""
+        device = self.constant_input.weight.device
+        noises = [torch.randn(1, 1, 4, 4, device=device)]
+
+        for i in range(3, self.log_size + 1):
+            for _ in range(2):
+                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
+
+        return noises
+
+    def get_latent(self, x):
+        return self.style_mlp(x)
+
+    def mean_latent(self, num_latent):
+        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
+        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+        return latent
+
+    def forward(self,
+                styles,
+                input_is_latent=False,
+                noise=None,
+                randomize_noise=True,
+                truncation=1,
+                truncation_latent=None,
+                inject_index=None,
+                return_latents=False):
+        """Forward function for StyleGAN2Generator.
+
+        Args:
+            styles (list[Tensor]): Sample codes of styles.
+            input_is_latent (bool): Whether input is latent style.
+                Default: False.
+            noise (Tensor | None): Input noise or None. Default: None.
+            randomize_noise (bool): Randomize noise, used when 'noise' is
+                None. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor.
+                Default: None.
+            inject_index (int | None): The injection index for mixing noise.
+                Default: None.
+            return_latents (bool): Whether to return style latents.
+                Default: False.
+        """
+        # style codes -> latents with Style MLP layer
+        if not input_is_latent:
+            styles = [self.style_mlp(s) for s in styles]
+        # noises
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers  # for each style conv layer
+            else:  # use the stored noise
+                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+        # style truncation
+        if truncation < 1:
+            style_truncation = []
+            for style in styles:
+                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+            styles = style_truncation
+        # get style latent with injection
+        if len(styles) == 1:
+            inject_index = self.num_latent
+
+            if styles[0].ndim < 3:
+                # repeat latent code for all the layers
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            else:  # used for encoder with different latent code for each layer
+                latent = styles[0]
+        elif len(styles) == 2:  # mixing noises
+            if inject_index is None:
+                inject_index = random.randint(1, self.num_latent - 1)
+            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+            latent = torch.cat([latent1, latent2], 1)
+
+        # main generation
+        out = self.constant_input(latent.shape[0])
+        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+        skip = self.to_rgb1(out, latent[:, 1])
+
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+                                                        noise[2::2], self.to_rgbs):
+            out = conv1(out, latent[:, i], noise=noise1)
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)
+            i += 2
+
+        image = skip
+
+        if return_latents:
+            return image, latent
+        else:
+            return image, None
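The generator's layer bookkeeping follows directly from out_size; for a 512-px generator the numbers work out as below (pure arithmetic, nothing framework-specific):

import math

out_size = 512
log_size = int(math.log(out_size, 2))  # 9
num_layers = (log_size - 2) * 2 + 1    # 15: one conv at 4x4, then two per resolution
num_latent = log_size * 2 - 2          # 16 latent slots consumed by convs and to_rgbs
# stored-noise resolutions per layer: 4, 8, 8, 16, 16, ..., 512
print([2 ** ((i + 5) // 2) for i in range(num_layers)])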
+ """ + + def __init__(self, negative_slope=0.2): + super(ScaledLeakyReLU, self).__init__() + self.negative_slope = negative_slope + + def forward(self, x): + out = F.leaky_relu(x, negative_slope=self.negative_slope) + return out * math.sqrt(2) + + +class EqualConv2d(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. + Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. + Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0): + super(EqualConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + out = F.conv2d( + x, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size},' + f' stride={self.stride}, padding={self.padding}, ' + f'bias={self.bias is not None})') + + +class ConvLayer(nn.Sequential): + """Conv Layer used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Kernel size. + downsample (bool): Whether downsample by a factor of 2. + Default: False. + bias (bool): Whether with bias. Default: True. + activate (bool): Whether use activateion. Default: True. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + downsample=False, + bias=True, + activate=True, + interpolation_mode='bilinear'): + layers = [] + self.interpolation_mode = interpolation_mode + # downsample + if downsample: + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + + layers.append( + torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners)) + stride = 1 + self.padding = kernel_size // 2 + # conv + layers.append( + EqualConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias + and not activate)) + # activation + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channels)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super(ConvLayer, self).__init__(*layers) + + +class ResBlock(nn.Module): + """Residual block used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. 
+ """ + + def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'): + super(ResBlock, self).__init__() + + self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) + self.conv2 = ConvLayer( + in_channels, + out_channels, + 3, + downsample=True, + interpolation_mode=interpolation_mode, + bias=True, + activate=True) + self.skip = ConvLayer( + in_channels, + out_channels, + 1, + downsample=True, + interpolation_mode=interpolation_mode, + bias=False, + activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out diff --git a/face_enhancer/gfpgan/archs/stylegan2_clean_arch.py b/face_enhancer/gfpgan/archs/stylegan2_clean_arch.py new file mode 100644 index 0000000..9e2ee94 --- /dev/null +++ b/face_enhancer/gfpgan/archs/stylegan2_clean_arch.py @@ -0,0 +1,368 @@ +import math +import random +import torch +from basicsr.archs.arch_util import default_init_weights +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + + +class NormStyleCode(nn.Module): + + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + eps (float): A value added to the denominator for numerical stability. Default: 1e-8. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + + # modulation inside each modulated conv + self.modulation = nn.Linear(num_style_feat, in_channels, bias=True) + # initialization + default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear') + + self.weight = nn.Parameter( + torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) / + math.sqrt(in_channels * kernel_size**2)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + # upsample or downsample if necessary + if self.sample_mode == 'upsample': + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + elif self.sample_mode == 'downsample': + x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + """Style conv used in StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + """ + + def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) + self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) * 2**0.5 # for conversion + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # add bias + out = out + self.bias + # activation + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB (image space) from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True): + super(ToRGB, self).__init__() + self.upsample = upsample + self.modulated_conv = ModulatedConv2d( + in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + + Returns: + Tensor: RGB images. 
+ """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2GeneratorClean(nn.Module): + """Clean version of StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1): + super(StyleGAN2GeneratorClean, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.extend( + [nn.Linear(num_style_feat, num_style_feat, bias=True), + nn.LeakyReLU(negative_slope=0.2, inplace=True)]) + self.style_mlp = nn.Sequential(*style_mlp_layers) + # initialization + default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu') + + # channel list + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv( + channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels['4'] + # noise + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample')) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None)) + self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [torch.randn(1, 1, 4, 4, 
device=device)]
+
+        for i in range(3, self.log_size + 1):
+            for _ in range(2):
+                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
+
+        return noises
+
+    def get_latent(self, x):
+        return self.style_mlp(x)
+
+    def mean_latent(self, num_latent):
+        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
+        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+        return latent
+
+    def forward(self,
+                styles,
+                input_is_latent=False,
+                noise=None,
+                randomize_noise=True,
+                truncation=1,
+                truncation_latent=None,
+                inject_index=None,
+                return_latents=False):
+        """Forward function for StyleGAN2GeneratorClean.
+
+        Args:
+            styles (list[Tensor]): Sample codes of styles.
+            input_is_latent (bool): Whether input is latent style. Default: False.
+            noise (Tensor | None): Input noise or None. Default: None.
+            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
+            truncation (float): The truncation ratio. Default: 1.
+            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+            inject_index (int | None): The injection index for mixing noise. Default: None.
+            return_latents (bool): Whether to return style latents. Default: False.
+        """
+        # style codes -> latents with Style MLP layer
+        if not input_is_latent:
+            styles = [self.style_mlp(s) for s in styles]
+        # noises
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers  # for each style conv layer
+            else:  # use the stored noise
+                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+        # style truncation
+        if truncation < 1:
+            style_truncation = []
+            for style in styles:
+                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+            styles = style_truncation
+        # get style latents with injection
+        if len(styles) == 1:
+            inject_index = self.num_latent
+
+            if styles[0].ndim < 3:
+                # repeat latent code for all the layers
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            else:  # used for encoder with different latent code for each layer
+                latent = styles[0]
+        elif len(styles) == 2:  # mixing noises
+            if inject_index is None:
+                inject_index = random.randint(1, self.num_latent - 1)
+            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+            latent = torch.cat([latent1, latent2], 1)
+
+        # main generation
+        out = self.constant_input(latent.shape[0])
+        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+        skip = self.to_rgb1(out, latent[:, 1])
+
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+                                                        noise[2::2], self.to_rgbs):
+            out = conv1(out, latent[:, i], noise=noise1)
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
+            i += 2
+
+        image = skip
+
+        if return_latents:
+            return image, latent
+        else:
+            return image, None
diff --git a/face_enhancer/gfpgan/data/__init__.py b/face_enhancer/gfpgan/data/__init__.py
new file mode 100644
index 0000000..69fd9f9
--- /dev/null
+++ b/face_enhancer/gfpgan/data/__init__.py
@@ -0,0 +1,10 @@
+import importlib
+from basicsr.utils import scandir
+from os import path as osp
+
+# automatically scan and import dataset modules for registry
+# scan all the files that end with '_dataset.py' under the data folder
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
diff --git a/face_enhancer/gfpgan/data/ffhq_degradation_dataset.py b/face_enhancer/gfpgan/data/ffhq_degradation_dataset.py
new file mode 100644
index 0000000..64e5755
--- /dev/null
+++ b/face_enhancer/gfpgan/data/ffhq_degradation_dataset.py
@@ -0,0 +1,230 @@
+import cv2
+import math
+import numpy as np
+import os.path as osp
+import torch
+import torch.utils.data as data
+from basicsr.data import degradations as degradations
+from basicsr.data.data_util import paths_from_folder
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
+                                               normalize)
+
+
+@DATASET_REGISTRY.register()
+class FFHQDegradationDataset(data.Dataset):
+    """FFHQ dataset for GFPGAN.
+
+    It reads high-resolution images, and then generates low-quality (LQ) images on-the-fly.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            io_backend (dict): IO backend type and other kwargs.
+            mean (list | tuple): Image mean.
+            std (list | tuple): Image std.
+            use_hflip (bool): Whether to horizontally flip.
+        Please see more options in the code.
+    """
+
+    def __init__(self, opt):
+        super(FFHQDegradationDataset, self).__init__()
+        self.opt = opt
+        # file client (io backend)
+        self.file_client = None
+        self.io_backend_opt = opt['io_backend']
+
+        self.gt_folder = opt['dataroot_gt']
+        self.mean = opt['mean']
+        self.std = opt['std']
+        self.out_size = opt['out_size']
+
+        self.crop_components = opt.get('crop_components', False)  # facial components
+        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether to enlarge eye regions
+
+        if self.crop_components:
+            # load component list from a pre-processed pth file
+            self.components_list = torch.load(opt.get('component_path'))
+
+        # file client (lmdb io backend)
+        if self.io_backend_opt['type'] == 'lmdb':
+            self.io_backend_opt['db_paths'] = self.gt_folder
+            if not self.gt_folder.endswith('.lmdb'):
+                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+                self.paths = [line.split('.')[0] for line in fin]
+        else:
+            # disk backend: scan file list from a folder
+            self.paths = paths_from_folder(self.gt_folder)
+
+        # degradation configurations
+        self.blur_kernel_size = opt['blur_kernel_size']
+        self.kernel_list = opt['kernel_list']
+        self.kernel_prob = opt['kernel_prob']
+        self.blur_sigma = opt['blur_sigma']
+        self.downsample_range = opt['downsample_range']
+        self.noise_range = opt['noise_range']
+        self.jpeg_range = opt['jpeg_range']
+
+        # color jitter
+        self.color_jitter_prob = opt.get('color_jitter_prob')
+        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
+        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
+        # to gray
+        self.gray_prob = opt.get('gray_prob')
+
+        logger = get_root_logger()
+        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
+        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
+        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
+        logger.info(f'JPEG
compression: [{", ".join(map(str, self.jpeg_range))}]') + + if self.color_jitter_prob is not None: + logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') + if self.gray_prob is not None: + logger.info(f'Use random gray. Prob: {self.gray_prob}') + self.color_jitter_shift /= 255. + + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and brightness is not None: + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = adjust_brightness(img, brightness_factor) + + if fn_id == 1 and contrast is not None: + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + + if fn_id == 2 and saturation is not None: + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + + if fn_id == 3 and hue is not None: + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + def get_component_coordinates(self, index, status): + """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file""" + components_bbox = self.components_list[f'{index:08d}'] + if status[0]: # hflip + # exchange right and left eye + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + # modify the width coordinate + components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0] + components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0] + + # get coordinates + locations = [] + for part in ['left_eye', 'right_eye', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = torch.from_numpy(loc).float() + locations.append(loc) + return locations + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 
+ gt_path = self.paths[index] + img_bytes = self.file_client.get(gt_path) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + h, w, _ = img_gt.shape + + # get facial component coordinates + if self.crop_components: + locations = self.get_component_coordinates(index, status) + loc_left_eye, loc_right_eye, loc_mouth = locations + + # ------------------------ generate lq image ------------------------ # + # blur + kernel = degradations.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + noise_range=None) + img_lq = cv2.filter2D(img_gt, -1, kernel) + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR) + # noise + if self.noise_range is not None: + img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range) + # jpeg compression + if self.jpeg_range is not None: + img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range) + + # resize to original size + img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_lq = self.color_jitter(img_lq, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY) + img_lq = np.tile(img_lq[:, :, None], [1, 1, 3]) + if self.opt.get('gt_gray'): # whether convert GT to gray images + img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) + img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue) + + # round and clip + img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255. 
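+            # the round/clip above snaps the LQ image to the 8-bit grid, emulating the quantization it would undergo if saved to disk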
+ +            # normalize + normalize(img_gt, self.mean, self.std, inplace=True) + normalize(img_lq, self.mean, self.std, inplace=True) + + if self.crop_components: + return_dict = { + 'lq': img_lq, + 'gt': img_gt, + 'gt_path': gt_path, + 'loc_left_eye': loc_left_eye, + 'loc_right_eye': loc_right_eye, + 'loc_mouth': loc_mouth + } + return return_dict + else: + return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/face_enhancer/gfpgan/models/__init__.py b/face_enhancer/gfpgan/models/__init__.py new file mode 100644 index 0000000..6afad57 --- /dev/null +++ b/face_enhancer/gfpgan/models/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import model modules for registry +# scan all the files that end with '_model.py' under the model folder +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames] diff --git a/face_enhancer/gfpgan/models/gfpgan_model.py b/face_enhancer/gfpgan/models/gfpgan_model.py new file mode 100644 index 0000000..684fc60 --- /dev/null +++ b/face_enhancer/gfpgan/models/gfpgan_model.py @@ -0,0 +1,579 @@ +import math +import os.path as osp +import torch +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.losses.losses import r1_penalty +from basicsr.metrics import calculate_metric +from basicsr.models.base_model import BaseModel +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +from collections import OrderedDict +from torch.nn import functional as F +from torchvision.ops import roi_align +from tqdm import tqdm + + +@MODEL_REGISTRY.register() +class GFPGANModel(BaseModel): + """The GFPGAN model for Towards Real-World Blind Face Restoration with Generative Facial Prior""" + + def __init__(self, opt): + super(GFPGANModel, self).__init__(opt) + self.idx = 0 # it is used for saving data for checking + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + self.log_size = int(math.log(self.opt['network_g']['out_size'], 2)) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + train_opt = self.opt['train'] + + # ----------- define net_d ----------- # + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + # ----------- define net_g with Exponential Moving Average (EMA) ----------- # + # net_g_ema only used for testing on one GPU and saving. 
There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + + self.net_g.train() + self.net_d.train() + self.net_g_ema.eval() + + # ----------- facial component networks ----------- # + if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt): + self.use_facial_disc = True + else: + self.use_facial_disc = False + + if self.use_facial_disc: + # left eye + self.net_d_left_eye = build_network(self.opt['network_d_left_eye']) + self.net_d_left_eye = self.model_to_device(self.net_d_left_eye) + self.print_network(self.net_d_left_eye) + load_path = self.opt['path'].get('pretrain_network_d_left_eye') + if load_path is not None: + self.load_network(self.net_d_left_eye, load_path, True, 'params') + # right eye + self.net_d_right_eye = build_network(self.opt['network_d_right_eye']) + self.net_d_right_eye = self.model_to_device(self.net_d_right_eye) + self.print_network(self.net_d_right_eye) + load_path = self.opt['path'].get('pretrain_network_d_right_eye') + if load_path is not None: + self.load_network(self.net_d_right_eye, load_path, True, 'params') + # mouth + self.net_d_mouth = build_network(self.opt['network_d_mouth']) + self.net_d_mouth = self.model_to_device(self.net_d_mouth) + self.print_network(self.net_d_mouth) + load_path = self.opt['path'].get('pretrain_network_d_mouth') + if load_path is not None: + self.load_network(self.net_d_mouth, load_path, True, 'params') + + self.net_d_left_eye.train() + self.net_d_right_eye.train() + self.net_d_mouth.train() + + # ----------- define facial component gan loss ----------- # + self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device) + + # ----------- define losses ----------- # + # pixel loss + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + # perceptual loss + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + # L1 loss is used in pyramid loss, component style loss and identity loss + self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device) + + # gan loss (wgan) + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + # ----------- define identity loss ----------- # + if 'network_identity' in self.opt: + self.use_identity = True + else: + self.use_identity = False + + if self.use_identity: + # define identity network + self.network_identity = build_network(self.opt['network_identity']) + self.network_identity = self.model_to_device(self.network_identity) + self.print_network(self.network_identity) + load_path = self.opt['path'].get('pretrain_network_identity') + if load_path is not None: + self.load_network(self.network_identity, load_path, True, None) + self.network_identity.eval() + for param in self.network_identity.parameters(): + param.requires_grad = False + + # regularization weights + self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + self.net_d_reg_every = train_opt['net_d_reg_every'] + + # set 
up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + + # ----------- optimizer g ----------- # + net_g_reg_ratio = 1 + normal_params = [] + for _, param in self.net_g.named_parameters(): + normal_params.append(param) + optim_params_g = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }] + optim_type = train_opt['optim_g'].pop('type') + lr = train_opt['optim_g']['lr'] * net_g_reg_ratio + betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio) + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas) + self.optimizers.append(self.optimizer_g) + + # ----------- optimizer d ----------- # + net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) + normal_params = [] + for _, param in self.net_d.named_parameters(): + normal_params.append(param) + optim_params_d = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }] + optim_type = train_opt['optim_d'].pop('type') + lr = train_opt['optim_d']['lr'] * net_d_reg_ratio + betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio) + self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas) + self.optimizers.append(self.optimizer_d) + + # ----------- optimizers for facial component networks ----------- # + if self.use_facial_disc: + # setup optimizers for facial component discriminators + optim_type = train_opt['optim_component'].pop('type') + lr = train_opt['optim_component']['lr'] + # left eye + self.optimizer_d_left_eye = self.get_optimizer( + optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_left_eye) + # right eye + self.optimizer_d_right_eye = self.get_optimizer( + optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_right_eye) + # mouth + self.optimizer_d_mouth = self.get_optimizer( + optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_mouth) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + if 'loc_left_eye' in data: + # get facial component locations, shape (batch, 4) + self.loc_left_eyes = data['loc_left_eye'] + self.loc_right_eyes = data['loc_right_eye'] + self.loc_mouths = data['loc_mouth'] + + # uncomment to check data + # import torchvision + # if self.opt['rank'] == 0: + # import os + # os.makedirs('tmp/gt', exist_ok=True) + # os.makedirs('tmp/lq', exist_ok=True) + # print(self.idx) + # torchvision.utils.save_image( + # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + # torchvision.utils.save_image( + # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + # self.idx = self.idx + 1 + + def construct_img_pyramid(self): + """Construct image pyramid for intermediate restoration loss""" + pyramid_gt = [self.gt] + down_img = self.gt + for _ in range(0, self.log_size - 3): + down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False) + pyramid_gt.insert(0, down_img) + return pyramid_gt + + def get_roi_regions(self, eye_out_size=80, mouth_out_size=120): + face_ratio = int(self.opt['network_g']['out_size'] / 512) + eye_out_size *= face_ratio + mouth_out_size *= face_ratio + + rois_eyes = [] + rois_mouths = [] + for b in range(self.loc_left_eyes.size(0)): # 
loop for batch size + # left eye and right eye + img_inds = self.loc_left_eyes.new_full((2, 1), b) + bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4) + rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5) + rois_eyes.append(rois) + # mouth + img_inds = self.loc_left_eyes.new_full((1, 1), b) + rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5) + rois_mouths.append(rois) + + rois_eyes = torch.cat(rois_eyes, 0).to(self.device) + rois_mouths = torch.cat(rois_mouths, 0).to(self.device) + + # real images + all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio + self.left_eyes_gt = all_eyes[0::2, :, :, :] + self.right_eyes_gt = all_eyes[1::2, :, :, :] + self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio + # output + all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio + self.left_eyes = all_eyes[0::2, :, :, :] + self.right_eyes = all_eyes[1::2, :, :, :] + self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + def gray_resize_for_identity(self, out, size=128): + out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) + out_gray = out_gray.unsqueeze(1) + out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) + return out_gray + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + self.optimizer_g.zero_grad() + + # do not update facial component net_d + if self.use_facial_disc: + for p in self.net_d_left_eye.parameters(): + p.requires_grad = False + for p in self.net_d_right_eye.parameters(): + p.requires_grad = False + for p in self.net_d_mouth.parameters(): + p.requires_grad = False + + # image pyramid loss weight + pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0) + if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')): + pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error + if pyramid_loss_weight > 0: + self.output, out_rgbs = self.net_g(self.lq, return_rgb=True) + pyramid_gt = self.construct_img_pyramid() + else: + self.output, out_rgbs = self.net_g(self.lq, return_rgb=False) + + # get roi-align regions + if self.use_facial_disc: + self.get_roi_regions(eye_out_size=80, mouth_out_size=120) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # image pyramid loss + if pyramid_loss_weight > 0: + for i in range(0, self.log_size - 2): + l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight + l_g_total += l_pyramid + loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid + + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if 
l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + # facial component loss + if self.use_facial_disc: + # left eye + fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_left_eye'] = l_g_gan + # right eye + fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_right_eye'] = l_g_gan + # mouth + fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True) + l_g_gan = self.cri_component(fake_mouth, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_mouth'] = l_g_gan + + if self.opt['train'].get('comp_style_weight', 0) > 0: + # get gt feat + _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True) + _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True) + _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True) + + def _comp_style(feat, feat_gt, criterion): + return criterion(self._gram_mat(feat[0]), self._gram_mat( + feat_gt[0].detach())) * 0.5 + criterion( + self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach())) + + # facial component style loss + comp_style_loss = 0 + comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1) + comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1) + comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1) + comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight'] + l_g_total += comp_style_loss + loss_dict['l_g_comp_style_loss'] = comp_style_loss + + # identity loss + if self.use_identity: + identity_weight = self.opt['train']['identity_weight'] + # get gray images and resize + out_gray = self.gray_resize_for_identity(self.output) + gt_gray = self.gray_resize_for_identity(self.gt) + + identity_gt = self.network_identity(gt_gray).detach() + identity_out = self.network_identity(out_gray) + l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight + l_g_total += l_identity + loss_dict['l_identity'] = l_identity + + l_g_total.backward() + self.optimizer_g.step() + + # EMA + self.model_ema(decay=0.5**(32 / (10 * 1000))) + + # ----------- optimize net_d ----------- # + for p in self.net_d.parameters(): + p.requires_grad = True + self.optimizer_d.zero_grad() + if self.use_facial_disc: + for p in self.net_d_left_eye.parameters(): + p.requires_grad = True + for p in self.net_d_right_eye.parameters(): + p.requires_grad = True + for p in self.net_d_mouth.parameters(): + p.requires_grad = True + self.optimizer_d_left_eye.zero_grad() + self.optimizer_d_right_eye.zero_grad() + self.optimizer_d_mouth.zero_grad() + + fake_d_pred = self.net_d(self.output.detach()) + real_d_pred = self.net_d(self.gt) + l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d'] = l_d + # In WGAN, real_score should be positive and fake_score should be negative + loss_dict['real_score'] = real_d_pred.detach().mean() + loss_dict['fake_score'] = fake_d_pred.detach().mean() + l_d.backward() + + # regularization 
loss + if current_iter % self.net_d_reg_every == 0: + self.gt.requires_grad = True + real_pred = self.net_d(self.gt) + l_d_r1 = r1_penalty(real_pred, self.gt) + l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0]) + loss_dict['l_d_r1'] = l_d_r1.detach().mean() + l_d_r1.backward() + + self.optimizer_d.step() + + # optimize facial component discriminators + if self.use_facial_disc: + # left eye + fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach()) + real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt) + l_d_left_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_left_eye'] = l_d_left_eye + l_d_left_eye.backward() + # right eye + fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach()) + real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt) + l_d_right_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_right_eye'] = l_d_right_eye + l_d_right_eye.backward() + # mouth + fake_d_pred, _ = self.net_d_mouth(self.mouths.detach()) + real_d_pred, _ = self.net_d_mouth(self.mouths_gt) + l_d_mouth = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_mouth'] = l_d_mouth + l_d_mouth.backward() + + self.optimizer_d_left_eye.step() + self.optimizer_d_right_eye.step() + self.optimizer_d_mouth.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _ = self.net_g_ema(self.lq) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _ = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', False) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1)) + metric_data['img'] = sr_img + if hasattr(self, 'gt'): + gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1)) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], 
dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + self.metric_results[name] += calculate_metric(metric_data, opt_) + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def save(self, epoch, current_iter): + # save net_g and net_d + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + self.save_network(self.net_d, 'net_d', current_iter) + # save component discriminators + if self.use_facial_disc: + self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter) + self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter) + self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter) + # save training state + self.save_training_state(epoch, current_iter) diff --git a/face_enhancer/gfpgan/train.py b/face_enhancer/gfpgan/train.py new file mode 100644 index 0000000..fe5f1f9 --- /dev/null +++ b/face_enhancer/gfpgan/train.py @@ -0,0 +1,11 @@ +# flake8: noqa +import os.path as osp +from basicsr.train import train_pipeline + +import gfpgan.archs +import gfpgan.data +import gfpgan.models + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/face_enhancer/gfpgan/utils.py b/face_enhancer/gfpgan/utils.py new file mode 100644 index 0000000..1cc104d --- /dev/null +++ b/face_enhancer/gfpgan/utils.py @@ -0,0 +1,143 @@ +import cv2 +import os +import torch +from basicsr.utils import img2tensor, tensor2img +from basicsr.utils.download_util import load_file_from_url +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from torchvision.transforms.functional import normalize + +from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear +from gfpgan.archs.gfpganv1_arch import GFPGANv1 +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class GFPGANer(): + """Helper for restoration with GFPGAN. + + It will detect and crop faces, and then resize the faces to 512x512. + GFPGAN is used to restore the resized faces. + The background is upsampled with the bg_upsampler. + Finally, the faces will be pasted back to the upsampled background image. 
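+ + A typical call (an illustrative sketch, not part of the original code): build a GFPGANer with a local or downloaded model_path, then call enhance on a BGR image read by cv2; it returns the cropped faces, the restored faces, and (when paste_back is True) the final pasted-back image.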
+ + Args: + model_path (str): The path to the GFPGAN model. It can also be a URL, in which case the weights are downloaded automatically first. + upscale (float): The upscale factor of the final output. Default: 2. + arch (str): The GFPGAN architecture. Options: clean | original | bilinear. Default: clean. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + bg_upsampler (nn.Module): The upsampler for the background. Default: None. + """ + + def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None): + self.upscale = upscale + self.bg_upsampler = bg_upsampler + + # initialize model + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # initialize the GFP-GAN + if arch == 'clean': + self.gfpgan = GFPGANv1Clean( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'bilinear': + self.gfpgan = GFPGANBilinear( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'original': + self.gfpgan = GFPGANv1( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=True, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + # initialize face helper + self.face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + device=self.device) + + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None) + loadnet = torch.load(model_path) + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + self.gfpgan.load_state_dict(loadnet[keyname], strict=True) + self.gfpgan.eval() + self.gfpgan = self.gfpgan.to(self.device) + + @torch.no_grad() + def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True): + self.face_helper.clean_all() + + if has_aligned: # the inputs are already aligned + img = cv2.resize(img, (512, 512)) + self.face_helper.cropped_faces = [img] + else: + self.face_helper.read_image(img) + # get face landmarks for each face + self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5) + # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels + # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations. 
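+            # align_warp_face below warps each detected face onto the canonical 512x512 face template using an affine transform estimated from the 5 landmarks (handled inside facexlib's FaceRestoreHelper)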
+ # align and warp each face + self.face_helper.align_warp_face() + + # face restoration + for cropped_face in self.face_helper.cropped_faces: + # prepare data + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) + + try: + output = self.gfpgan(cropped_face_t, return_rgb=False)[0] + # convert to image + restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) + except RuntimeError as error: + print(f'\tFailed inference for GFPGAN: {error}.') + restored_face = cropped_face + + restored_face = restored_face.astype('uint8') + self.face_helper.add_restored_face(restored_face) + + if not has_aligned and paste_back: + # upsample the background + if self.bg_upsampler is not None: + # Now only support RealESRGAN for upsampling background + bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0] + else: + bg_img = None + + self.face_helper.get_inverse_affine(None) + # paste each restored face to the input image + restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img) + return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img + else: + return self.face_helper.cropped_faces, self.face_helper.restored_faces, None diff --git a/face_enhancer/gfpgan/version.py b/face_enhancer/gfpgan/version.py new file mode 100644 index 0000000..3a5a8d7 --- /dev/null +++ b/face_enhancer/gfpgan/version.py @@ -0,0 +1,5 @@ +# GENERATED VERSION FILE +# TIME: Wed Mar 30 13:34:44 2022 +__version__ = '1.3.2' +__gitsha__ = 'unknown' +version_info = (1, 3, 2) diff --git a/face_enhancer/gfpgan/weights/README.md b/face_enhancer/gfpgan/weights/README.md new file mode 100644 index 0000000..4d7b7e6 --- /dev/null +++ b/face_enhancer/gfpgan/weights/README.md @@ -0,0 +1,3 @@ +# Weights + +Put the downloaded weights to this folder. diff --git a/face_enhancer/scripts/convert_gfpganv_to_clean.py b/face_enhancer/scripts/convert_gfpganv_to_clean.py new file mode 100644 index 0000000..8fdccb6 --- /dev/null +++ b/face_enhancer/scripts/convert_gfpganv_to_clean.py @@ -0,0 +1,164 @@ +import argparse +import math +import torch + +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean + + +def modify_checkpoint(checkpoint_bilinear, checkpoint_clean): + for ori_k, ori_v in checkpoint_bilinear.items(): + if 'stylegan_decoder' in ori_k: + if 'style_mlp' in ori_k: # style_mlp_layers + lr_mul = 0.01 + prefix, name, idx, var = ori_k.split('.') + idx = (int(idx) * 2) - 1 + crt_k = f'{prefix}.{name}.{idx}.{var}' + if var == 'weight': + _, c_in = ori_v.size() + scale = (1 / math.sqrt(c_in)) * lr_mul + crt_v = ori_v * scale * 2**0.5 + else: + crt_v = ori_v * lr_mul * 2**0.5 + checkpoint_clean[crt_k] = crt_v + elif 'modulation' in ori_k: # modulation in StyleConv + lr_mul = 1 + crt_k = ori_k + var = ori_k.split('.')[-1] + if var == 'weight': + _, c_in = ori_v.size() + scale = (1 / math.sqrt(c_in)) * lr_mul + crt_v = ori_v * scale + else: + crt_v = ori_v * lr_mul + checkpoint_clean[crt_k] = crt_v + elif 'style_conv' in ori_k: + # StyleConv in style_conv1 and style_convs + if 'activate' in ori_k: # FusedLeakyReLU + # eg. style_conv1.activate.bias + # eg. 
style_convs.13.activate.bias + split_rlt = ori_k.split('.') + if len(split_rlt) == 4: + prefix, name, _, var = split_rlt + crt_k = f'{prefix}.{name}.{var}' + elif len(split_rlt) == 5: + prefix, name, idx, _, var = split_rlt + crt_k = f'{prefix}.{name}.{idx}.{var}' + crt_v = ori_v * 2**0.5 # 2**0.5 used in FusedLeakyReLU + c = crt_v.size(0) + checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1) + elif 'modulated_conv' in ori_k: + # eg. style_conv1.modulated_conv.weight + # eg. style_convs.13.modulated_conv.weight + _, c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v * scale + elif 'weight' in ori_k: + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v * 2**0.5 + elif 'to_rgb' in ori_k: # StyleConv in to_rgb1 and to_rgbs + if 'modulated_conv' in ori_k: + # eg. to_rgb1.modulated_conv.weight + # eg. to_rgbs.5.modulated_conv.weight + _, c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v * scale + else: + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v + else: + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v + # end of 'stylegan_decoder' + elif 'conv_body_first' in ori_k or 'final_conv' in ori_k: + # key name + name, _, var = ori_k.split('.') + crt_k = f'{name}.{var}' + # weight and bias + if var == 'weight': + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale * 2**0.5 + else: + checkpoint_clean[crt_k] = ori_v * 2**0.5 + elif 'conv_body' in ori_k: + if 'conv_body_up' in ori_k: + ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight') + ori_k = ori_k.replace('skip.weight', 'skip.1.weight') + name1, idx1, name2, _, var = ori_k.split('.') + crt_k = f'{name1}.{idx1}.{name2}.{var}' + if name2 == 'skip': + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale / 2**0.5 + else: + if var == 'weight': + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale + else: + checkpoint_clean[crt_k] = ori_v + if 'conv1' in ori_k: + checkpoint_clean[crt_k] *= 2**0.5 + elif 'toRGB' in ori_k: + crt_k = ori_k + if 'weight' in ori_k: + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale + else: + checkpoint_clean[crt_k] = ori_v + elif 'final_linear' in ori_k: + crt_k = ori_k + if 'weight' in ori_k: + _, c_in = ori_v.size() + scale = 1 / math.sqrt(c_in) + checkpoint_clean[crt_k] = ori_v * scale + else: + checkpoint_clean[crt_k] = ori_v + elif 'condition' in ori_k: + crt_k = ori_k + if '0.weight' in ori_k: + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale * 2**0.5 + elif '0.bias' in ori_k: + checkpoint_clean[crt_k] = ori_v * 2**0.5 + elif '2.weight' in ori_k: + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale + elif '2.bias' in ori_k: + checkpoint_clean[crt_k] = ori_v + + return checkpoint_clean + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--ori_path', type=str, help='Path to the original model') + parser.add_argument('--narrow', type=float, default=1) + parser.add_argument('--channel_multiplier', type=float, default=2) + parser.add_argument('--save_path', type=str) + args = parser.parse_args() + + ori_ckpt = 
torch.load(args.ori_path)['params_ema'] + + net = GFPGANv1Clean( + 512, + num_style_feat=512, + channel_multiplier=args.channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + # for stylegan decoder + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=args.narrow, + sft_half=True) + crt_ckpt = net.state_dict() + + crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt) + print(f'Save to {args.save_path}.') + torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False) diff --git a/face_enhancer/scripts/parse_landmark.py b/face_enhancer/scripts/parse_landmark.py new file mode 100644 index 0000000..74e2ff9 --- /dev/null +++ b/face_enhancer/scripts/parse_landmark.py @@ -0,0 +1,85 @@ +import cv2 +import json +import numpy as np +import os +import torch +from basicsr.utils import FileClient, imfrombytes +from collections import OrderedDict + +# ---------------------------- This script is used to parse facial landmarks ------------------------------------- # +# Configurations +save_img = False +scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others +enlarge_ratio = 1.4 # only for eyes +json_path = 'ffhq-dataset-v2.json' +face_path = 'datasets/ffhq/ffhq_512.lmdb' +save_path = './FFHQ_eye_mouth_landmarks_512.pth' + +print('Load JSON metadata...') +# use the official json file in FFHQ dataset +with open(json_path, 'rb') as f: + json_data = json.load(f, object_pairs_hook=OrderedDict) + +print('Open LMDB file...') +# read ffhq images +file_client = FileClient('lmdb', db_paths=face_path) +with open(os.path.join(face_path, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + +save_dict = {} + +for item_idx, item in enumerate(json_data.values()): + print(f'\r{item_idx} / {len(json_data)}, {item["image"]["file_path"]} ', end='', flush=True) + + # parse landmarks + lm = np.array(item['image']['face_landmarks']) + lm = lm * scale + + item_dict = {} + # get image + if save_img: + img_bytes = file_client.get(paths[item_idx]) + img = imfrombytes(img_bytes, float32=True) + + # get landmarks for each component + map_left_eye = list(range(36, 42)) + map_right_eye = list(range(42, 48)) + map_mouth = list(range(48, 68)) + + # eye_left + mean_left_eye = np.mean(lm[map_left_eye], 0) # (x, y) + half_len_left_eye = np.max((np.max(np.max(lm[map_left_eye], 0) - np.min(lm[map_left_eye], 0)) / 2, 16)) + item_dict['left_eye'] = [mean_left_eye[0], mean_left_eye[1], half_len_left_eye] + # mean_left_eye[0] = 512 - mean_left_eye[0] # for testing flip + half_len_left_eye *= enlarge_ratio + loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1, mean_left_eye + half_len_left_eye)).astype(int) + if save_img: + eye_left_img = img[loc_left_eye[1]:loc_left_eye[3], loc_left_eye[0]:loc_left_eye[2], :] + cv2.imwrite(f'tmp/{item_idx:08d}_eye_left.png', eye_left_img * 255) + + # eye_right + mean_right_eye = np.mean(lm[map_right_eye], 0) + half_len_right_eye = np.max((np.max(np.max(lm[map_right_eye], 0) - np.min(lm[map_right_eye], 0)) / 2, 16)) + item_dict['right_eye'] = [mean_right_eye[0], mean_right_eye[1], half_len_right_eye] + # mean_right_eye[0] = 512 - mean_right_eye[0] # # for testing flip + half_len_right_eye *= enlarge_ratio + loc_right_eye = np.hstack( + (mean_right_eye - half_len_right_eye + 1, mean_right_eye + half_len_right_eye)).astype(int) + if save_img: + eye_right_img = img[loc_right_eye[1]:loc_right_eye[3], loc_right_eye[0]:loc_right_eye[2], :] + cv2.imwrite(f'tmp/{item_idx:08d}_eye_right.png', eye_right_img * 255) + + # mouth + 
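+    # same center/half-length computation as for the eyes, except that the mouth box is not scaled by enlarge_ratio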
mean_mouth = np.mean(lm[map_mouth], 0) + half_len_mouth = np.max((np.max(np.max(lm[map_mouth], 0) - np.min(lm[map_mouth], 0)) / 2, 16)) + item_dict['mouth'] = [mean_mouth[0], mean_mouth[1], half_len_mouth] + # mean_mouth[0] = 512 - mean_mouth[0] # for testing flip + loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1, mean_mouth + half_len_mouth)).astype(int) + if save_img: + mouth_img = img[loc_mouth[1]:loc_mouth[3], loc_mouth[0]:loc_mouth[2], :] + cv2.imwrite(f'tmp/{item_idx:08d}_mouth.png', mouth_img * 255) + + save_dict[f'{item_idx:08d}'] = item_dict + +print('Save...') +torch.save(save_dict, save_path) diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/.gitignore b/face_parse/PSFRGAN-master/PSFRGAN-master/.gitignore new file mode 100644 index 0000000..0a07a97 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/.gitignore @@ -0,0 +1,110 @@ +check_points/ +pretrain_models* +test_dir_enhance_results/ +test_dir_align_results/ +test_unalign_results/ +tmp* +local* +*.pth + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/LICENSE b/face_parse/PSFRGAN-master/PSFRGAN-master/LICENSE new file mode 100644 index 0000000..964d1ea --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/LICENSE @@ -0,0 +1,445 @@ +PSFR-GAN (c) by Chaofeng Chen + +PSFR-GAN is licensed under a +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. + +You should have received a copy of the license along with this +work. If not, see <http://creativecommons.org/licenses/by-nc-sa/4.0/>. + +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. 
+ +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. 
Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. 
Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. 
+ +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. 
No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/README.md b/face_parse/PSFRGAN-master/PSFRGAN-master/README.md new file mode 100644 index 0000000..dc8e528 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/README.md @@ -0,0 +1,123 @@ +# PSFR-GAN in PyTorch + +[Progressive Semantic-Aware Style Transformation for Blind Face Restoration](https://arxiv.org/abs/2009.08709) +[Chaofeng Chen](https://chaofengc.github.io), [Xiaoming Li](https://csxmli2016.github.io/), [Lingbo Yang](https://lotayou.github.io), [Xianhui Lin](https://dblp.org/pid/147/7708.html), [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang/), [Kwan-Yee K. Wong](https://i.cs.hku.hk/~kykwong/) + +![](test_dir/test_hzgg.jpg) +![](test_hzgg_results/hq_final.jpg) + +### Changelog +- **2021.04.26**: Add pytorch vgg19 model to GoogleDrive and remove `--distributed` option which causes training error. +- **2021.03.22**: Update new model at 15 epoch (52.5k iterations). +- **2021.03.19**: Add train codes for PSFRGAN and FPN. 
+
+## Prerequisites and Installation
+- Ubuntu 18.04
+- CUDA 10.1
+- Clone this repository
+    ```
+    git clone https://github.com/chaofengc/PSFR-GAN.git
+    cd PSFR-GAN
+    ```
+- Python 3.7, install the required packages with `pip3 install -r requirements.txt`
+
+## Quick Test
+
+### Download Pretrained Models and Dataset
+Download the pretrained models from one of the following links and put them in `./pretrain_models`
+- [GoogleDrive](https://drive.google.com/drive/folders/1Ubejhxd2xd4fxGc_M_LWl3Ux6CgQd9rP?usp=sharing)
+- [BaiduNetDisk](https://pan.baidu.com/s/1cru3uUASEfGX6p6L0_7gWQ), extract code: `gj2r`
+
+### Test single image
+Run the following script to enhance the face(s) in a single input
+```
+python test_enhance_single_unalign.py --test_img_path ./test_dir/test_hzgg.jpg --results_dir test_hzgg_results --gpus 1
+```
+
+This script does the following:
+- Crop and align all the faces from the input image; they are stored at `results_dir/LQ_faces`
+- Parse these faces and then enhance them; results are stored at `results_dir/ParseMaps` and `results_dir/HQ`
+- Paste the enhanced faces back to the original image `results_dir/hq_final.jpg`
+- You can use `--gpus` to specify how many GPUs to use; `<=0` means running on CPU. The program will use the GPU with the most available memory. Set `CUDA_VISIBLE_DEVICES` to specify a GPU if you do not want automatic GPU selection.
+
+### Test image folder
+To test multiple images, first crop out and align all the faces using the following script.
+```
+python align_and_crop_dir.py --src_dir test_dir --results_dir test_dir_align_results
+```
+
+For images that contain multiple faces (*e.g.* `multiface_test.jpg`), the aligned faces will be stored as `multiface_test_{face_index}.jpg`.
+Then parse the aligned faces and enhance them with
+```
+python test_enhance_dir_align.py --src_dir test_dir_align_results --results_dir test_dir_enhance_results
+```
+Results will be saved to three folders respectively: `results_dir/lq`, `results_dir/parse`, `results_dir/hq`.
+
+### Additional test script
+
+For your convenience, we also provide a script to test multiple unaligned images and paste the enhanced results back. **Note: the paste-back operation can be quite slow for large images containing many faces (dlib takes time to detect faces in a large image).**
+```
+python test_enhance_dir_unalign.py --src_dir test_dir --results_dir test_unalign_results
+```
+This script does essentially the same thing as `test_enhance_single_unalign.py` for each image in `src_dir`.
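+
+If you prefer to drive these test scripts from Python (for example, to sweep several input folders in one run), a small wrapper such as the sketch below is enough. This is a convenience sketch, not part of the repository: the folder names are placeholders, and only the command-line flags documented above are used.
+```
+import subprocess
+from pathlib import Path
+
+# Hypothetical list of folders containing unaligned input images.
+for src in ['test_dir', 'more_photos']:
+    results = Path(src + '_enhance_results')
+    results.mkdir(exist_ok=True)
+    # Same flags as the test_enhance_dir_unalign.py command shown above.
+    subprocess.run(['python', 'test_enhance_dir_unalign.py',
+                    '--src_dir', src,
+                    '--results_dir', str(results)], check=True)
+```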
+
+## Train the Model
+
+### Data Preparation
+
+- Download [FFHQ](https://github.com/NVlabs/ffhq-dataset) and put the images in `../datasets/FFHQ/imgs1024`
+- Download the parsing masks (`512x512`) [HERE](https://drive.google.com/file/d/1eQwO8hKcaluyCnxuZAp0eJVOdgMi30uA/view?usp=sharing) generated by the pretrained FPN and put them in `../datasets/FFHQ/masks512`.
+
+*Note: you may change `../datasets/FFHQ` to your own path. But images and masks must be stored under `your_own_path/imgs1024` and `your_own_path/masks512` respectively.*
+
+### Train Script for PSFRGAN
+
+Here is an example train script for PSFRGAN:
+
+```
+python train.py --gpus 2 --model enhance --name PSFRGAN_v001 \
+    --g_lr 0.0001 --d_lr 0.0004 --beta1 0.5 \
+    --gan_mode 'hinge' --lambda_pix 10 --lambda_fm 10 --lambda_ss 1000 \
+    --Dinput_nc 22 --D_num 3 --n_layers_D 4 \
+    --batch_size 2 --dataset ffhq --dataroot ../datasets/FFHQ \
+    --visual_freq 100 --print_freq 10 #--continue_train
+```
+- Please change the `--name` option for different experiments. Tensorboard records with the same name will be moved to `check_points/log_archive`, and the weight directory only keeps the weight history of the latest experiment with the same name.
+- `--gpus` specifies the number of GPUs used for training. The script will use the GPUs with the most available memory first. To specify GPU indices, use `export CUDA_VISIBLE_DEVICES=your_gpu_ids` before running the script.
+- Uncomment `--continue_train` to resume training. *The current code does not resume the optimizer state.*
+- Training needs at least **8GB** of GPU memory with **batch_size=1**.
+
+### Scripts for FPN
+
+You may also train your own FPN and generate masks for the HQ images yourself with the following steps:
+
+- Download the [CelebAHQ-Mask](https://github.com/switchablenorms/CelebAMask-HQ) dataset. Generate `CelebAMask-HQ-mask` and `CelebAMask-HQ-mask-color` with the provided scripts in `CelebAMask-HQ/face_parsing/Data_preprocessing/`.
+- Train FPN with the following command
+```
+python train.py --gpus 1 --model parse --name FPN_v001 \
+    --lr 0.0002 --batch_size 8 \
+    --dataset celebahqmask --dataroot ../datasets/CelebAMask-HQ \
+    --visual_freq 100 --print_freq 10 #--continue_train
+```
+- Generate parsing masks with your own FPN using the following command (a note on how the soft outputs become hard masks follows this list):
+```
+python generate_masks.py --save_masks_dir ../datasets/FFHQ/masks512 --batch_size 8 --parse_net_weight path/to/your/own/FPN
+```
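+
+For reference, every hard parsing mask in this repository is obtained from the network's soft output by the same argmax-style comparison that appears in the test and mask-generation scripts. A self-contained sketch, with random logits standing in for real FPN output:
+```
+import torch
+
+# Fake parsing logits: batch of 2, 19 face-parsing classes (reg_num=19 in the losses), 512x512.
+parse_map = torch.randn(2, 19, 512, 512)
+
+# One-hot mask: 1.0 exactly where a pixel's score is maximal across classes.
+parse_map_sm = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float()
+
+# Equivalent label image (one class index per pixel), convenient for saving to disk.
+labels = parse_map.argmax(dim=1)
+```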
+
+## Citation
+```
+@inproceedings{ChenPSFRGAN,
+  author = {Chen, Chaofeng and Li, Xiaoming and Yang, Lingbo and Lin, Xianhui and Zhang, Lei and Wong, Kwan-Yee K.},
+  title = {Progressive Semantic-Aware Style Transformation for Blind Face Restoration},
+  booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+  year = {2021}
+}
+```
+
+## License
+
+This work is licensed under a [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by-nc-sa/4.0/).
+
+## Acknowledgement
+
+This work is inspired by [SPADE](https://github.com/NVlabs/SPADE), and is closely related to [DFDNet](https://github.com/csxmli2016/DFDNet) and [HiFaceGAN](https://github.com/Lotayou/Face-Renovation). Our codes largely benefit from [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/align_and_crop_dir.py b/face_parse/PSFRGAN-master/PSFRGAN-master/align_and_crop_dir.py
new file mode 100644
index 0000000..dbb41b4
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/align_and_crop_dir.py
@@ -0,0 +1,86 @@
+import dlib
+import os
+import cv2
+import numpy as np
+from tqdm import tqdm
+from skimage import transform as trans
+from skimage import io
+import argparse
+
+
+def get_points(img, detector, shape_predictor, size_threshold=999):
+    dets = detector(img, 1)
+    if len(dets) == 0:
+        return None
+
+    all_points = []
+    for det in dets:
+        if isinstance(detector, dlib.cnn_face_detection_model_v1):
+            rec = det.rect  # for cnn detector
+        else:
+            rec = det
+        if rec.width() > size_threshold or rec.height() > size_threshold:
+            break
+        shape = shape_predictor(img, rec)
+        single_points = []
+        for i in range(5):
+            single_points.append([shape.part(i).x, shape.part(i).y])
+        all_points.append(np.array(single_points))
+    if len(all_points) <= 0:
+        return None
+    else:
+        return all_points
+
+def align_and_save(img, save_path, src_points, template_path, template_scale=1):
+    out_size = (512, 512)
+    reference = np.load(template_path) / template_scale
+
+    ext = os.path.splitext(save_path)
+    for idx, spoint in enumerate(src_points):
+        tform = trans.SimilarityTransform()
+        tform.estimate(spoint, reference)
+        M = tform.params[0:2,:]
+
+        crop_img = cv2.warpAffine(img, M, out_size)
+        if len(src_points) > 1:
+            save_path = ext[0] + '_{}'.format(idx) + ext[1]
+        dlib.save_image(crop_img.astype(np.uint8), save_path)
+        print('Saving image', save_path)
+
+def align_and_save_dir(src_dir, save_dir, template_path='./pretrain_models/FFHQ_template.npy', template_scale=2, use_cnn_detector=True):
+    out_size = (512, 512)
+    if use_cnn_detector:
+        detector = dlib.cnn_face_detection_model_v1('./pretrain_models/mmod_human_face_detector.dat')
+    else:
+        detector = dlib.get_frontal_face_detector()
+    sp = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat')
+
+    for name in os.listdir(src_dir):
+        img_path = os.path.join(src_dir, name)
+        img = dlib.load_rgb_image(img_path)
+
+        points = get_points(img, detector, sp)
+        if points is not None:
+            save_path = os.path.join(save_dir, name)
+            align_and_save(img, save_path, points, template_path, template_scale)
+        else:
+            print('No face detected in', img_path)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--src_dir', type=str, help='source directory containing images to crop and align.')
+    parser.add_argument('--results_dir', type=str, help='results directory to save the aligned faces.')
+    parser.add_argument('--not_use_cnn_detector', action='store_true', help='do not use cnn face detector in dlib.')
+    args = parser.parse_args()
+
+    src_dir = args.src_dir
+    assert os.path.isdir(src_dir), 'Source path should be a directory containing images'
+    save_dir = args.results_dir
+    if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True)
+    align_and_save_dir(src_dir, save_dir, use_cnn_detector=not args.not_use_cnn_detector)
+
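+
+# Editor's note: a minimal sketch of the alignment math used above. The
+# landmark and template coordinates below are invented for illustration (the
+# real template is loaded from pretrain_models/FFHQ_template.npy), and this
+# helper is not called anywhere in the script.
+def _demo_similarity_alignment():
+    # Hypothetical 5 landmarks (eyes, nose tip, mouth corners) found in a photo.
+    src_points = np.array([[190., 242.], [312., 240.], [251., 316.],
+                           [208., 378.], [296., 376.]])
+    # Hypothetical template positions inside the 512x512 output crop.
+    reference = np.array([[192., 240.], [320., 240.], [256., 314.],
+                          [201., 371.], [313., 371.]])
+    tform = trans.SimilarityTransform()
+    tform.estimate(src_points, reference)  # least-squares similarity fit
+    M = tform.params[0:2, :]               # 2x3 matrix accepted by cv2.warpAffine
+    img = np.zeros((600, 600, 3), dtype=np.uint8)  # stand-in for a loaded photo
+    return cv2.warpAffine(img, M, (512, 512))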
+
+
+
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/data/__init__.py b/face_parse/PSFRGAN-master/PSFRGAN-master/data/__init__.py
new file mode 100644
index 0000000..9d3dfcc
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/data/__init__.py
@@ -0,0 +1,94 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+    -- <__len__>: return the size of dataset.
+    -- <__getitem__>: get a data point from data loader.
+    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import importlib
+import torch.utils.data
+from data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+    """Import the module "data/[dataset_name]_dataset.py".
+
+    In the file, the class called DatasetNameDataset() will
+    be instantiated. It has to be a subclass of BaseDataset,
+    and it is case-insensitive.
+    """
+    dataset_filename = "data." + dataset_name + "_dataset"
+    datasetlib = importlib.import_module(dataset_filename)
+
+    dataset = None
+    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+    for name, cls in datasetlib.__dict__.items():
+        if name.lower() == target_dataset_name.lower() \
+           and issubclass(cls, BaseDataset):
+            dataset = cls
+
+    if dataset is None:
+        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+    return dataset
+
+
+def get_option_setter(dataset_name):
+    """Return the static method <modify_commandline_options> of the dataset class."""
+    dataset_class = find_dataset_using_name(dataset_name)
+    return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt):
+    """Create a dataset given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+    This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from data import create_dataset
+        >>> dataset = create_dataset(opt)
+    """
+    data_loader = CustomDatasetDataLoader(opt)
+    dataset = data_loader.load_data()
+    return dataset
+
+
+class CustomDatasetDataLoader():
+    """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+    def __init__(self, opt):
+        """Initialize this class
+
+        Step 1: create a dataset instance given the name [dataset_mode]
+        Step 2: create a multi-threaded data loader.
+ """ + self.opt = opt + dataset_class = find_dataset_using_name(opt.dataset_name) + self.dataset = dataset_class(opt) + print("dataset [%s] was created" % type(self.dataset).__name__) + drop_last = True if opt.isTrain else False + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batch_size, + shuffle=not opt.serial_batches, + num_workers=int(opt.num_threads), drop_last=drop_last) + + def load_data(self): + return self + + def __len__(self): + """Return the number of data in the dataset""" + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + """Return a batch of data""" + for i, data in enumerate(self.dataloader): + if i * self.opt.batch_size >= self.opt.max_dataset_size: + break + yield data diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/data/base_dataset.py b/face_parse/PSFRGAN-master/PSFRGAN-master/data/base_dataset.py new file mode 100644 index 0000000..6b34f5f --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/data/base_dataset.py @@ -0,0 +1,162 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" +import random +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABC, abstractmethod + +import imgaug as ia +import imgaug.augmenters as iaa + +class BaseDataset(data.Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- : (optionally) add dataset-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + self.root = opt.dataroot + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
+ """ + pass + + +def get_params(opt, size): + w, h = size + new_h = h + new_w = w + if opt.preprocess == 'resize_and_crop': + new_h = new_w = opt.load_size + elif opt.preprocess == 'scale_width_and_crop': + new_w = opt.load_size + new_h = opt.load_size * h // w + + x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) + y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) + + flip = random.random() > 0.5 + + return {'crop_pos': (x, y), 'flip': flip} + +def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): + transform_list = [] + if grayscale: + # transform_list.append(transforms.Grayscale(1)) + from util import util + transform_list.append(util.RGBtoY) + if 'resize' in opt.preprocess: + osize = [opt.load_size, opt.load_size] + transform_list.append(transforms.Resize(osize, method)) + elif 'scale_width' in opt.preprocess: + transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) + + if 'crop' in opt.preprocess: + if params is None: + transform_list.append(transforms.RandomCrop(opt.crop_size)) + else: + if 'crop_size' in params: + transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], params['crop_size']))) + else: + transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) + + if opt.preprocess == 'none': + transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) + + if not opt.no_flip: + if params is None: + transform_list.append(transforms.RandomHorizontalFlip()) + elif params['flip']: + transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) + + if convert: + transform_list += [transforms.ToTensor()] + if grayscale: + transform_list += [transforms.Normalize((0.5,), (0.5,))] + else: + transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +def __make_power_2(img, base, method=Image.BICUBIC): + ow, oh = img.size + h = int(round(oh / base) * base) + w = int(round(ow / base) * base) + if (h == oh) and (w == ow): + return img + + __print_size_warning(ow, oh, w, h) + return img.resize((w, h), method) + + +def __scale_width(img, target_width, method=Image.BICUBIC): + ow, oh = img.size + if (ow == target_width): + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), method) + + +def __crop(img, pos, size): + ow, oh = img.size + x1, y1 = pos + tw = th = size + if (ow > tw or oh > th): + return img.crop((x1, y1, x1 + tw, y1 + th)) + return img + + +def __flip(img, flip): + if flip: + return img.transpose(Image.FLIP_LEFT_RIGHT) + return img + + +def __print_size_warning(ow, oh, w, h): + """Print warning information about image size(only print once)""" + if not hasattr(__print_size_warning, 'has_printed'): + print("The image size needs to be a multiple of 4. " + "The loaded image size was (%d, %d), so it was adjusted to " + "(%d, %d). 
This adjustment will be done to all images " + "whose sizes are not multiples of 4" % (ow, oh, w, h)) + __print_size_warning.has_printed = True diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/data/celebahqmask_dataset.py b/face_parse/PSFRGAN-master/PSFRGAN-master/data/celebahqmask_dataset.py new file mode 100644 index 0000000..3d48bd8 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/data/celebahqmask_dataset.py @@ -0,0 +1,60 @@ +import os +import random +import numpy as np +from PIL import Image +import imgaug as ia +import imgaug.augmenters as iaa + +from data.image_folder import make_dataset + +import torch +from torch.utils.data import Dataset +from torchvision.transforms import transforms + +from data.base_dataset import BaseDataset +from utils.utils import onehot_parse_map + +from data.ffhq_dataset import complex_imgaug, random_gray + +class CelebAHQMaskDataset(BaseDataset): + + def __init__(self, opt): + BaseDataset.__init__(self, opt) + self.img_size = opt.Pimg_size + self.lr_size = opt.Gin_size + self.hr_size = opt.Gout_size + self.shuffle = True if opt.isTrain else False + + self.img_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'CelebA-HQ-img'))) + self.mask_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'CelebAMask-HQ-mask'))) + + self.to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + def __len__(self,): + return len(self.img_dataset) + + def __getitem__(self, idx): + sample = {} + img_path = self.img_dataset[idx] + mask_path = self.mask_dataset[idx] + hr_img = Image.open(img_path).convert('RGB') + mask_img = Image.open(mask_path) + + hr_img = hr_img.resize((self.hr_size, self.hr_size)) + hr_img = random_gray(hr_img, p=0.3) + scale_size = np.random.randint(32, 256) + lr_img = complex_imgaug(hr_img, self.img_size, scale_size) + + mask_img = mask_img.resize((self.hr_size, self.hr_size)) + mask_label = torch.tensor(np.array(mask_img)).long() + + hr_tensor = self.to_tensor(hr_img) + lr_tensor = self.to_tensor(lr_img) + + return {'HR': hr_tensor, 'LR': lr_tensor, 'HR_paths': img_path, 'Mask': mask_label} + + + diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/data/ffhq_dataset.py b/face_parse/PSFRGAN-master/PSFRGAN-master/data/ffhq_dataset.py new file mode 100644 index 0000000..f987223 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/data/ffhq_dataset.py @@ -0,0 +1,88 @@ +import os +import random +import numpy as np +from PIL import Image +import imgaug as ia +import imgaug.augmenters as iaa + +from data.image_folder import make_dataset + +import torch +from torch.utils.data import Dataset +from torchvision.transforms import transforms + +from data.base_dataset import BaseDataset +from utils.utils import onehot_parse_map + +class FFHQDataset(BaseDataset): + + def __init__(self, opt): + BaseDataset.__init__(self, opt) + self.img_size = opt.Pimg_size + self.lr_size = opt.Gin_size + self.hr_size = opt.Gout_size + self.shuffle = True if opt.isTrain else False + + self.img_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'imgs1024'))) + self.mask_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'masks512'))) + + self.to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + self.random_crop = transforms.RandomCrop(self.hr_size) + + def __len__(self,): + return len(self.img_dataset) + + def __getitem__(self, idx): + sample = {} + img_path = self.img_dataset[idx] + mask_path = 
self.mask_dataset[idx] + hr_img = Image.open(img_path).convert('RGB') + mask_img = Image.open(mask_path).convert('RGB') + + hr_img = hr_img.resize((self.hr_size, self.hr_size)) + hr_img = random_gray(hr_img, p=0.3) + scale_size = np.random.randint(32, 256) + lr_img = complex_imgaug(hr_img, self.img_size, scale_size) + + mask_img = mask_img.resize((self.hr_size, self.hr_size)) + mask_label = onehot_parse_map(mask_img) + mask_label = torch.tensor(mask_label).float() + + hr_tensor = self.to_tensor(hr_img) + lr_tensor = self.to_tensor(lr_img) + + return {'HR': hr_tensor, 'LR': lr_tensor, 'HR_paths': img_path, 'Mask': mask_label} + + +def complex_imgaug(x, org_size, scale_size): + """input single RGB PIL Image instance""" + x = np.array(x) + x = x[np.newaxis, :, :, :] + aug_seq = iaa.Sequential([ + iaa.Sometimes(0.5, iaa.OneOf([ + iaa.GaussianBlur((3, 15)), + iaa.AverageBlur(k=(3, 15)), + iaa.MedianBlur(k=(3, 15)), + iaa.MotionBlur((5, 25)) + ])), + iaa.Resize(scale_size, interpolation=ia.ALL), + iaa.Sometimes(0.2, iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.1*255), per_channel=0.5)), + iaa.Sometimes(0.7, iaa.JpegCompression(compression=(10, 65))), + iaa.Resize(org_size), + ]) + + aug_img = aug_seq(images=x) + return aug_img[0] + + +def random_gray(x, p=0.5): + """input single RGB PIL Image instance""" + x = np.array(x) + x = x[np.newaxis, :, :, :] + aug = iaa.Sometimes(p, iaa.Grayscale(alpha=1.0)) + aug_img = aug(images=x) + return aug_img[0] + diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/data/image_folder.py b/face_parse/PSFRGAN-master/PSFRGAN-master/data/image_folder.py new file mode 100644 index 0000000..d0b4b30 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/data/image_folder.py @@ -0,0 +1,67 @@ +"""A modified image folder class + +We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) +so that this class can load images from both current directory and its subdirectories. 
+""" + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', + '.tif', '.TIF', '.tiff', '.TIFF', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf")): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/data/single_dataset.py b/face_parse/PSFRGAN-master/PSFRGAN-master/data/single_dataset.py new file mode 100644 index 0000000..2f6cb2f --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/data/single_dataset.py @@ -0,0 +1,43 @@ +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image +import numpy as np + +class SingleDataset(BaseDataset): + """This dataset class can load a set of images specified by the path --dataroot /path/to/data. + + It can be used for generating CycleGAN results only for one side with the model option '-model test'. + """ + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt) + self.A_paths = sorted(make_dataset(opt.src_dir, opt.max_dataset_size)) + input_nc = self.opt.output_nc + self.transform = get_transform(opt, grayscale=(input_nc == 1)) + self.opt = opt + + def __getitem__(self, index): + """Return a data point and its metadata information. 
+ + Parameters: + index - - a random integer for data indexing + + Returns a dictionary that contains A and A_paths + A(tensor) - - an image in one domain + A_paths(str) - - the path of the image + """ + A_path = self.A_paths[index] + A_img = Image.open(A_path).convert('RGB') + A_img = A_img.resize((512, 512), Image.BICUBIC) + + A = self.transform(A_img) + return {'LR': A, 'LR_paths': A_path} + + def __len__(self): + """Return the total number of images in the dataset.""" + return len(self.A_paths) diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/generate_mask.py b/face_parse/PSFRGAN-master/PSFRGAN-master/generate_mask.py new file mode 100644 index 0000000..5bd3e52 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/generate_mask.py @@ -0,0 +1,92 @@ +import os +from options.test_options import TestOptions +from data import create_dataset +from models import create_model +from utils import utils +from PIL import Image +from tqdm import tqdm +import torch +import time +import numpy as np +import cv2 +import glob +from torchvision.transforms import transforms + +if __name__ == '__main__': + opt = TestOptions() + opt = opt.parse() # get test options + opt.num_threads = 0 # test code only supports num_threads = 1 + opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. + opt.no_flip = True + + dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options + model = create_model(opt) # create a model given opt.model and other options + model.load_pretrain_models() + + netP = model.netP + model.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + image_dir = "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan/" + output_dir = "G:/VGGFace2-HQ/VGGface2_HQ_original_aligned_mask" + + temp_path = os.path.join(image_dir,'*/') + pathes = glob.glob(temp_path) + dataset = [] + for dir_item in pathes: + join_path = glob.glob(os.path.join(dir_item,'*.jpg')) + print("processing %s"%dir_item,end='\r') + temp_list = [] + for item in join_path: + temp_list.append(item) + dataset.append(temp_list) + + # ------------------------ restore ------------------------ + for i_dir in dataset: + path = os.path.dirname(i_dir[0]) + dir_name = os.path.join(output_dir, os.path.basename(path)) + if not os.path.exists(dir_name): + os.makedirs(dir_name) + + for img_path in i_dir: + hr_img = Image.open(img_path).convert('RGB') + inp = to_tensor(hr_img).unsqueeze(0) + with torch.no_grad(): + parse_map, _ = netP(inp) + parse_map_sm = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float() + ref_parse_img = utils.color_parse_map(parse_map_sm) + img_name = os.path.basename(img_path) + basename, ext = os.path.splitext(img_name) + save_face_name = f'{basename}.png' + # print(save_face_name) + save_path = os.path.join(dir_name, save_face_name) + # os.makedirs(opt.save_masks_dir, exist_ok=True) + img = cv2.cvtColor(ref_parse_img[0],cv2.COLOR_RGB2GRAY) + cv2.imwrite(save_path,img) + + # for i, data in tqdm(enumerate(dataset), total=len(dataset)//opt.batch_size): + # inp = data['LR'] + # with torch.no_grad(): + # parse_map, _ = netP(inp) + # parse_map_sm = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float() + # img_path = data['LR_paths'] # get image paths + # ref_parse_img = utils.color_parse_map(parse_map_sm) + # for i in range(len(img_path)): + # img_name = os.path.basename(img_path[i]) + # basename, ext = os.path.splitext(img_name) + # 
save_face_name = f'{basename}.png'
+    #         # print(save_face_name)
+    #         save_path = os.path.join(opt.save_masks_dir, save_face_name)
+    #         os.makedirs(opt.save_masks_dir, exist_ok=True)
+    #         img = cv2.cvtColor(ref_parse_img[i],cv2.COLOR_RGB2GRAY)
+    #         cv2.imwrite(save_path,img)
+    #         save_img = Image.fromarray(ref_parse_img[i])
+    #         save_img.save(save_path)
+
+
+
+
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/models/__init__.py b/face_parse/PSFRGAN-master/PSFRGAN-master/models/__init__.py
new file mode 100644
index 0000000..fc01113
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+    -- <set_input>: unpack data from dataset and apply preprocessing.
+    -- <forward>: produce intermediate results.
+    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+    -- self.loss_names (str list): specify the training losses that you want to plot and save.
+    -- self.model_names (str list): define networks used in our training.
+    -- self.visual_names (str list): specify the images that you want to display and save.
+    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+    """Import the module "models/[model_name]_model.py".
+
+    In the file, the class called DatasetNameModel() will
+    be instantiated. It has to be a subclass of BaseModel,
+    and it is case-insensitive.
+    """
+    model_filename = "models." + model_name + "_model"
+    modellib = importlib.import_module(model_filename)
+    model = None
+    target_model_name = model_name.replace('_', '') + 'model'
+    for name, cls in modellib.__dict__.items():
+        if name.lower() == target_model_name.lower() \
+           and issubclass(cls, BaseModel):
+            model = cls
+
+    if model is None:
+        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+        exit(0)
+
+    return model
+
+
+def get_option_setter(model_name):
+    """Return the static method <modify_commandline_options> of the model class."""
+    model_class = find_model_using_name(model_name)
+    return model_class.modify_commandline_options
+
+
+def create_model(opt):
+    """Create a model given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+    This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from models import create_model
+        >>> model = create_model(opt)
+    """
+    model = find_model_using_name(opt.model)
+    instance = model(opt)
+    print("model [%s] was created" % type(instance).__name__)
+    return instance
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/models/base_model.py b/face_parse/PSFRGAN-master/PSFRGAN-master/models/base_model.py
new file mode 100644
index 0000000..7cda51b
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/models/base_model.py
@@ -0,0 +1,248 @@
+import os
+import torch
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+class BaseModel(ABC):
+    """This class is an abstract base class (ABC) for models.
+    To create a subclass, you need to implement the following five functions:
+        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+        -- <set_input>: unpack data from dataset and apply preprocessing.
+        -- <forward>: produce intermediate results.
+        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+    """
+
+    def __init__(self, opt):
+        """Initialize the BaseModel class.
+
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+        When creating your custom class, you need to implement your own initialization.
+        In this function, you should first call <BaseModel.__init__(self, opt)>
+        Then, you need to define four lists:
+            -- self.loss_names (str list): specify the training losses that you want to plot and save.
+            -- self.model_names (str list): define networks used in our training.
+            -- self.visual_names (str list): specify the images that you want to display and save.
+            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+        """
+        self.opt = opt
+        self.gpu_ids = opt.gpu_ids
+        self.isTrain = opt.isTrain
+        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
+        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
+
+        self.loss_names = []
+        self.model_names = []
+        self.visual_names = []
+        self.optimizers = []
+        self.image_paths = []
+        self.metric = 0  # used for learning rate policy 'plateau'
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new model-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input (dict): includes the data itself and its metadata information.
+ """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + if not self.isTrain or opt.continue_train: + load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch + self.load_networks(load_suffix) + self.print_networks(opt.verbose) + + def eval(self): + """Make models eval mode during test time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self): + """ Return image paths that are used to load current data""" + return self.image_paths + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_lr(self,): + lrs = {} + for idx, p in enumerate(self.optimizers): + lrs['LR{}'.format(idx)] = p.param_groups[0]['lr'] + return lrs + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name) + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret['Loss_' + name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch, info=None): + """Save all the networks to the disk. 
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        for name in self.model_names:
+            if isinstance(name, str):
+                save_filename = '%s_net_%s.pth' % (epoch, name)
+                save_path = os.path.join(self.save_dir, save_filename)
+                net = getattr(self, 'net' + name)
+
+                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+                    torch.save(net.module.cpu().state_dict(), save_path)
+                    print('Model saved in:', save_filename)
+                    net.cuda(self.gpu_ids[0])
+                else:
+                    torch.save(net.cpu().state_dict(), save_path)
+        if info is not None:
+            torch.save(info, os.path.join(self.save_dir, '%s.info' % epoch))
+
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+        key = keys[i]
+        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'running_mean' or key == 'running_var'):
+                if getattr(module, key) is None:
+                    state_dict.pop('.'.join(keys))
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'num_batches_tracked'):
+                state_dict.pop('.'.join(keys))
+        else:
+            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+    def load_networks(self, epoch):
+        """Load all the networks from the disk.
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        for name in self.load_model_names:
+            if isinstance(name, str):
+                load_filename = '%s_net_%s.pth' % (epoch, name)
+                load_path = os.path.join(self.save_dir, load_filename)
+                net = getattr(self, 'net' + name)
+                if isinstance(net, torch.nn.DataParallel) or isinstance(net, torch.nn.parallel.DistributedDataParallel):
+                    net = net.module
+                print('loading the model from %s' % load_path)
+                # if you are using PyTorch newer than 0.4 (e.g., built from
+                # GitHub source), you can remove str() on self.device
+                map_location = str(self.device)
+
+                state_dict = torch.load(load_path, map_location=map_location)
+
+                # patch InstanceNorm checkpoints prior to 0.4
+                # for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+                #     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                # net.load_state_dict(state_dict)
+                if not self.opt.no_strict_load:
+                    net.load_state_dict(state_dict)
+                # Load partial weights
+                else:
+                    model_dict = net.state_dict()
+                    pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
+                    model_dict.update(pretrained_dict)
+                    net.load_state_dict(model_dict, strict=False)
+
+        info_path = os.path.join(self.save_dir, '%s.info' % epoch)
+        if os.path.exists(info_path):
+            info_dict = torch.load(info_path)
+            for k, v in info_dict.items():
+                setattr(self.opt, k, v)
+
+    def print_networks(self, verbose):
+        """Print the total number of parameters in the network and (if verbose) network architecture
+
+        Parameters:
+            verbose (bool) -- if verbose: print the network architecture
+        """
+        print('---------- Networks initialized -------------')
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                num_params = 0
+                for param in net.parameters():
+                    num_params += param.numel()
+                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+        print('-----------------------------------------------')
+
+    def set_requires_grad(self, nets, requires_grad=False):
+        """Set requires_grad=False for all the networks to avoid unnecessary computations
+        Parameters:
+            nets (network list)   -- a list of networks
+            requires_grad (bool)  -- whether the networks require gradients or not
+        """
+        if not isinstance(nets, list):
+            nets = [nets]
+        for net in nets:
+            if net is not None:
+                for param in net.parameters():
+                    param.requires_grad = requires_grad
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/models/blocks.py b/face_parse/PSFRGAN-master/PSFRGAN-master/models/blocks.py
new file mode 100644
index 0000000..3e02fd5
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/models/blocks.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter
+from torch.nn import functional as F
+import numpy as np
+
+
+class NormLayer(nn.Module):
+    """Normalization Layers.
+    ------------
+    # Arguments
+        - channels: input channels, for batch norm and instance norm.
+        - input_size: input shape without batch size, for layer norm.
+    """
+    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
+        super(NormLayer, self).__init__()
+        norm_type = norm_type.lower()
+        self.norm_type = norm_type
+        self.channels = channels
+        if norm_type == 'bn':
+            self.norm = nn.BatchNorm2d(channels, affine=True)
+        elif norm_type == 'in':
+            self.norm = nn.InstanceNorm2d(channels, affine=False)
+        elif norm_type == 'gn':
+            self.norm = nn.GroupNorm(32, channels, affine=True)
+        elif norm_type == 'pixel':
+            self.norm = lambda x: F.normalize(x, p=2, dim=1)
+        elif norm_type == 'layer':
+            self.norm = nn.LayerNorm(normalize_shape)
+        elif norm_type == 'none':
+            self.norm = lambda x: x*1.0
+        else:
+            assert 1==0, 'Norm type {} not supported.'.format(norm_type)
+
+    def forward(self, x, ref=None):
+        return self.norm(x)
+
+
+class ReluLayer(nn.Module):
+    """Relu Layer.
+    ------------
+    # Arguments
+        - relu type: type of relu layer, candidates are
+            - ReLU
+            - LeakyReLU: default relu slope 0.2
+            - PRelu
+            - SELU
+            - none: direct pass
+    """
+    def __init__(self, channels, relu_type='relu'):
+        super(ReluLayer, self).__init__()
+        relu_type = relu_type.lower()
+        if relu_type == 'relu':
+            self.func = nn.ReLU(True)
+        elif relu_type == 'leakyrelu':
+            self.func = nn.LeakyReLU(0.2, inplace=True)
+        elif relu_type == 'prelu':
+            self.func = nn.PReLU(channels)
+        elif relu_type == 'selu':
+            self.func = nn.SELU(True)
+        elif relu_type == 'none':
+            self.func = lambda x: x*1.0
+        else:
+            assert 1==0, 'Relu type {} not supported.'.format(relu_type)
+
+    def forward(self, x):
+        return self.func(x)
+
+
+class ConvLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=3, scale='none', norm_type='none', relu_type='none', use_pad=True, bias=True):
+        super(ConvLayer, self).__init__()
+        self.use_pad = use_pad
+        self.norm_type = norm_type
+        self.in_channels = in_channels
+        if norm_type in ['bn']:
+            bias = False
+
+        stride = 2 if scale == 'down' else 1
+        self.scale = scale
+
+        self.scale_func = lambda x: x
+        if scale == 'up':
+            self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+
+        self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.)/2)))
+        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
+
+        self.avgpool = nn.AvgPool2d(2, 2)
+        self.relu = ReluLayer(out_channels, relu_type)
+        self.norm = NormLayer(out_channels, norm_type=norm_type)
+
+    def forward(self, x):
+        out = self.scale_func(x)
+        if self.use_pad:
+            out = self.reflection_pad(out)
+        out = self.conv2d(out)
+        if self.scale == 'down_avg':
+            out = self.avgpool(out)
+        out = self.norm(out)
+        out = self.relu(out)
+        return out
+
+
+class ResidualBlock(nn.Module):
+    """
+    Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
+    """
+    def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
+        super(ResidualBlock, self).__init__()
+
+        if scale == 'none' and c_in == c_out:
+            self.shortcut_func = lambda x: x
+        else:
+            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
+
+        scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
+        scale_conf = scale_config_dict[scale]
+
+        self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
+        self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
+
+    def forward(self, x):
+        identity = self.shortcut_func(x)
+
+        res = self.conv1(x)
+        res = self.conv2(res)
+        return identity + res
+
+
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/models/enhance_model.py b/face_parse/PSFRGAN-master/PSFRGAN-master/models/enhance_model.py
new file mode 100644
index 0000000..d83dd7e
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/models/enhance_model.py
@@ -0,0 +1,157 @@
+import os
+import numpy as np
+import collections
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from models import loss
+from models import networks
+from .base_model import BaseModel
+from utils import utils
+
+class EnhanceModel(BaseModel):
+
+    def modify_commandline_options(parser, is_train):
+        if is_train:
+            parser.add_argument('--parse_net_weight', type=str, default='./pretrain_models/parse_multi_iter_90000.pth', help='parse model path')
+            parser.add_argument('--lambda_pix', type=float, default=10.0, help='weight for the pixel-wise L1 loss')
+            parser.add_argument('--lambda_pcp', type=float, default=0.0, help='weight for vgg perceptual loss')
+            parser.add_argument('--lambda_fm', type=float, default=10.0, help='weight for the feature matching loss')
+            parser.add_argument('--lambda_g', type=float, default=1.0, help='weight for the generator adversarial loss')
+            parser.add_argument('--lambda_ss', type=float, default=1000., help='weight for the semantic-aware style loss')
+        return parser
+
+    def __init__(self, opt):
+        BaseModel.__init__(self, opt)
+
+        self.netP = networks.define_P(opt, weight_path=opt.parse_net_weight)
+        self.netG = networks.define_G(opt, use_norm='spectral_norm')
+
+        if self.isTrain:
+            self.netD = networks.define_D(opt, opt.Dinput_nc, use_norm='spectral_norm')
+            self.vgg_model = loss.PCPFeat(weight_path='./pretrain_models/vgg19-dcbb9e9d.pth').to(opt.device)
+            if len(opt.gpu_ids) > 0:
+                self.vgg_model = torch.nn.DataParallel(self.vgg_model, opt.gpu_ids, output_device=opt.device)
+
+        self.model_names = ['G']
+        self.loss_names = ['Pix', 'PCP', 'G', 'FM', 'D', 'SS']  # pixel, perceptual, generator, feature matching, discriminator, and semantic style losses
+        self.visual_names = ['img_LR', 'img_HR', 'img_SR', 'ref_Parse', 'hr_mask']
+        self.fm_weights = [1**x for x in range(opt.D_num)]
+
+        if self.isTrain:
+            self.model_names = ['G', 'D']
+            self.load_model_names = ['G', 'D']
+
+            self.criterionParse = torch.nn.CrossEntropyLoss().to(opt.device)
+            self.criterionFM = loss.FMLoss().to(opt.device)
+            self.criterionGAN = loss.GANLoss(opt.gan_mode).to(opt.device)
+            self.criterionPCP = loss.PCPLoss(opt)
+            self.criterionPix = nn.L1Loss()
+            self.criterionRS = loss.RegionStyleLoss()
+
+            self.optimizer_G = optim.Adam([p for p in self.netG.parameters() if p.requires_grad], lr=opt.g_lr, betas=(opt.beta1, 0.999))
+            self.optimizer_D = optim.Adam([p for p in self.netD.parameters() if p.requires_grad], lr=opt.d_lr,
betas=(opt.beta1, 0.999)) + self.optimizers = [self.optimizer_G, self.optimizer_D] + + def eval(self): + self.netG.eval() + self.netP.eval() + + def load_pretrain_models(self,): + self.netP.eval() + print('Loading pretrained LQ face parsing network from', self.opt.parse_net_weight) + if len(self.opt.gpu_ids) > 0: + self.netP.module.load_state_dict(torch.load(self.opt.parse_net_weight)) + else: + self.netP.load_state_dict(torch.load(self.opt.parse_net_weight)) + self.netG.eval() + print('Loading pretrained PSFRGAN from', self.opt.psfr_net_weight) + if len(self.opt.gpu_ids) > 0: + self.netG.module.load_state_dict(torch.load(self.opt.psfr_net_weight), strict=False) + else: + self.netG.load_state_dict(torch.load(self.opt.psfr_net_weight), strict=False) + + def set_input(self, input, cur_iters=None): + self.cur_iters = cur_iters + self.img_LR = input['LR'].to(self.opt.device) + self.img_HR = input['HR'].to(self.opt.device) + self.hr_mask = input['Mask'].to(self.opt.device) + if self.opt.debug: + print('SRNet input shape:', self.img_LR.shape, self.img_HR.shape) + + def forward(self): + with torch.no_grad(): + ref_mask, _ = self.netP(self.img_LR) + self.ref_mask_onehot = (ref_mask == ref_mask.max(dim=1, keepdim=True)[0]).float().detach() + + if self.opt.debug: + print('SRNet reference mask shape:', self.ref_mask_onehot.shape) + self.img_SR = self.netG(self.img_LR, self.ref_mask_onehot) + + self.real_D_results = self.netD(torch.cat((self.img_HR, self.hr_mask), dim=1), return_feat=True) + self.fake_D_results = self.netD(torch.cat((self.img_SR.detach(), self.hr_mask), dim=1), return_feat=False) + self.fake_G_results = self.netD(torch.cat((self.img_SR, self.hr_mask), dim=1), return_feat=True) + + self.img_SR_feats = self.vgg_model(self.img_SR) + self.img_HR_feats = self.vgg_model(self.img_HR) + + def backward_G(self): + # Pix Loss + self.loss_Pix = self.criterionPix(self.img_SR, self.img_HR) * self.opt.lambda_pix + + # semantic style loss + self.loss_SS = self.criterionRS(self.img_SR_feats, self.img_HR_feats, self.hr_mask) * self.opt.lambda_ss + + # perceptual loss + self.loss_PCP = self.criterionPCP(self.img_SR_feats, self.img_HR_feats) * self.opt.lambda_pcp + + # Feature matching loss + tmp_loss = 0 + for i, w in zip(range(self.opt.D_num), self.fm_weights): + tmp_loss = tmp_loss + self.criterionFM(self.fake_G_results[i][1], self.real_D_results[i][1]) * w + self.loss_FM = tmp_loss * self.opt.lambda_fm / self.opt.D_num + + # Generator loss + tmp_loss = 0 + for i in range(self.opt.D_num): + tmp_loss = tmp_loss + self.criterionGAN(self.fake_G_results[i][0], True, for_discriminator=False) + self.loss_G = tmp_loss * self.opt.lambda_g / self.opt.D_num + + total_loss = self.loss_Pix + self.loss_PCP + self.loss_FM + self.loss_G + self.loss_SS + total_loss.backward() + + def backward_D(self, ): + self.loss_D = 0 + for i in range(self.opt.D_num): + self.loss_D += 0.5 * (self.criterionGAN(self.fake_D_results[i], False) + self.criterionGAN(self.real_D_results[i][0], True)) + self.loss_D /= self.opt.D_num + self.loss_D.backward() + + def optimize_parameters(self, ): + # ---- Update G ------------ + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + + # ---- Update D ------------ + self.optimizer_D.zero_grad() + self.backward_D() + self.optimizer_D.step() + + def get_current_visuals(self, size=512): + out = [] + visual_imgs = [] + out.append(utils.tensor_to_numpy(self.img_LR)) + out.append(utils.tensor_to_numpy(self.img_SR)) + out.append(utils.tensor_to_numpy(self.img_HR)) + + out_imgs = 
[utils.batch_numpy_to_image(x, size) for x in out] + + visual_imgs += out_imgs + visual_imgs.append(utils.color_parse_map(self.ref_mask_onehot, size)) + visual_imgs.append(utils.color_parse_map(self.hr_mask, size)) + + return visual_imgs + diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/models/loss.py b/face_parse/PSFRGAN-master/PSFRGAN-master/models/loss.py new file mode 100644 index 0000000..421f0de --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/models/loss.py @@ -0,0 +1,224 @@ +import torch +from torchvision import models +from utils import utils +from torch import nn + + +def tv_loss(x): + """ + Total Variation Loss. + """ + return torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) + ) + torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) + + +class PCPFeat(torch.nn.Module): + """ + Features used to calculate Perceptual Loss based on ResNet50 features. + Input: (B, C, H, W), RGB, [0, 1] + """ + def __init__(self, weight_path, model='vgg'): + super(PCPFeat, self).__init__() + if model == 'vgg': + self.model = models.vgg19(pretrained=False) + self.build_vgg_layers() + elif model == 'resnet': + self.model = models.resnet50(pretrained=False) + self.build_resnet_layers() + + self.model.load_state_dict(torch.load(weight_path)) + self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + def build_resnet_layers(self): + self.layer1 = torch.nn.Sequential( + self.model.conv1, + self.model.bn1, + self.model.relu, + self.model.maxpool, + self.model.layer1 + ) + self.layer2 = self.model.layer2 + self.layer3 = self.model.layer3 + self.layer4 = self.model.layer4 + self.features = torch.nn.ModuleList( + [self.layer1, self.layer2, self.layer3, self.layer4] + ) + + def build_vgg_layers(self): + vgg_pretrained_features = self.model.features + self.features = [] + feature_layers = [0, 3, 8, 17, 26, 35] + for i in range(len(feature_layers)-1): + module_layers = torch.nn.Sequential() + for j in range(feature_layers[i], feature_layers[i+1]): + module_layers.add_module(str(j), vgg_pretrained_features[j]) + self.features.append(module_layers) + self.features = torch.nn.ModuleList(self.features) + + def preprocess(self, x): + x = (x + 1) / 2 + mean = torch.Tensor([0.485, 0.456, 0.406]).to(x) + std = torch.Tensor([0.229, 0.224, 0.225]).to(x) + mean = mean.view(1, 3, 1, 1) + std = std.view(1, 3, 1, 1) + x = (x - mean) / std + if x.shape[3] < 224: + x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False) + return x + + def forward(self, x): + x = self.preprocess(x) + + features = [] + for m in self.features: + x = m(x) + features.append(x) + return features + + +class PCPLoss(torch.nn.Module): + """Perceptual Loss. + """ + def __init__(self, + opt, + layer=5, + model='vgg', + ): + super(PCPLoss, self).__init__() + + self.mse = torch.nn.MSELoss() + self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] + + def forward(self, x_feats, y_feats): + loss = 0 + for xf, yf, w in zip(x_feats, y_feats, self.weights): + loss = loss + self.mse(xf, yf.detach()) * w + return loss + + +class FMLoss(nn.Module): + def __init__(self): + super().__init__() + self.mse = torch.nn.MSELoss() + + def forward(self, x_feats, y_feats): + loss = 0 + for xf, yf in zip(x_feats, y_feats): + loss = loss + self.mse(xf, yf.detach()) + return loss + + +class GANLoss(nn.Module): + def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): + """ Initialize the GANLoss class. + Parameters: + gan_mode (str) - - the type of GAN objective. 
It currently supports vanilla, lsgan, and wgangp.
+            target_real_label (bool) - - label for a real image
+            target_fake_label (bool) - - label of a fake image
+        Note: Do not use sigmoid as the last layer of Discriminator.
+        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
+        """
+        super(GANLoss, self).__init__()
+        self.register_buffer('real_label', torch.tensor(target_real_label))
+        self.register_buffer('fake_label', torch.tensor(target_fake_label))
+        self.gan_mode = gan_mode
+        if gan_mode == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif gan_mode == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        elif gan_mode == 'hinge':
+            pass
+        elif gan_mode in ['wgangp']:
+            self.loss = None
+        else:
+            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+    def get_target_tensor(self, prediction, target_is_real):
+        if target_is_real:
+            target_tensor = self.real_label
+        else:
+            target_tensor = self.fake_label
+        return target_tensor.expand_as(prediction)
+
+    def __call__(self, prediction, target_is_real, for_discriminator=True):
+        """Calculate loss given Discriminator's output and ground truth labels.
+        Parameters:
+            prediction (tensor) - - typically the prediction output from a discriminator
+            target_is_real (bool) - - if the ground truth label is for real images or fake images
+        Returns:
+            the calculated loss.
+        """
+        if self.gan_mode in ['lsgan', 'vanilla']:
+            target_tensor = self.get_target_tensor(prediction, target_is_real)
+            loss = self.loss(prediction, target_tensor)
+        elif self.gan_mode == 'hinge':
+            if for_discriminator:
+                if target_is_real:
+                    loss = nn.ReLU()(1 - prediction).mean()
+                else:
+                    loss = nn.ReLU()(1 + prediction).mean()
+            else:
+                assert target_is_real, "The generator's hinge loss must be aiming for real"
+                loss = - prediction.mean()
+            return loss
+
+        elif self.gan_mode == 'wgangp':
+            if target_is_real:
+                loss = -prediction.mean()
+            else:
+                loss = prediction.mean()
+        return loss
+
+
+class RegionStyleLoss(nn.Module):
+    def __init__(self, reg_num=19, eps=1e-8):
+        super().__init__()
+        self.reg_num = reg_num
+        self.eps = eps
+        self.mse = nn.MSELoss()
+
+    def __masked_gram_matrix(self, x, m):
+        b, c, h, w = x.shape
+        m = m.view(b, -1, h*w)
+        x = x.view(b, -1, h*w)
+        total_elements = m.sum(2) + self.eps
+
+        x = x * m
+        G = torch.bmm(x, x.transpose(1, 2))
+        return G / (c * total_elements.view(b, 1, 1))
+
+    def __layer_gram_matrix(self, x, mask):
+        b, c, h, w = x.shape
+        all_gm = []
+        for i in range(self.reg_num):
+            sub_mask = mask[:, i].unsqueeze(1)
+            gram_matrix = self.__masked_gram_matrix(x, sub_mask)
+            all_gm.append(gram_matrix)
+        return torch.stack(all_gm, dim=1)
+
+    def forward(self, x_feats, y_feats, mask):
+        loss = 0
+        for xf, yf in zip(x_feats[2:], y_feats[2:]):
+            tmp_mask = torch.nn.functional.interpolate(mask, xf.shape[2:])
+            xf_gm = self.__layer_gram_matrix(xf, tmp_mask)
+            yf_gm = self.__layer_gram_matrix(yf, tmp_mask)
+            tmp_loss = self.mse(xf_gm, yf_gm.detach())
+            loss = loss + tmp_loss
+        return loss
+
+
+if __name__ == '__main__':
+    x = [
+        torch.randn(2, 64, 512, 512),
+        torch.randn(2, 128, 256, 256),
+        torch.randn(2, 256, 128, 128),
+        torch.randn(2, 512, 64, 64),
+        torch.randn(2, 512, 32, 32),
+    ]
+
+    y = torch.randint(10, (2, 19, 512, 512)).float()
+    loss = RegionStyleLoss()
+    print(loss(x, x, y))
+
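For orientation, a minimal sketch of how this GANLoss is meant to be driven from a training step, in the hinge mode the train options below default to; the score tensors here are random stand-ins for discriminator outputs, not part of this patch:

    import torch
    from models.loss import GANLoss

    crit = GANLoss('hinge')
    real_score = torch.randn(4, 1, 30, 30)  # stand-in for D(real)
    fake_score = torch.randn(4, 1, 30, 30)  # stand-in for D(fake)

    # discriminator side: hinge pushes real scores above +1 and fake scores below -1
    loss_D = 0.5 * (crit(real_score, True) + crit(fake_score, False))
    # generator side: for_discriminator=False reduces to -D(fake).mean()
    loss_G = crit(fake_score, True, for_discriminator=False)
    print(loss_D.item(), loss_G.item())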
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/models/networks.py b/face_parse/PSFRGAN-master/PSFRGAN-master/models/networks.py
new file mode 100644
index 0000000..40bcaed
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/models/networks.py
@@ -0,0 +1,254 @@
+from models.blocks import *
+import torch
+from torch import nn
+from torch.nn import init
+from torch.optim import lr_scheduler
+from utils import utils
+import numpy as np
+
+from models import psfrnet
+import torch.nn.utils as tutils
+from models.loss import PCPFeat
+
+
+def apply_norm(net, weight_norm_type):
+    for m in net.modules():
+        if isinstance(m, nn.Conv2d):
+            if weight_norm_type.lower() == 'spectral_norm':
+                tutils.spectral_norm(m)
+            elif weight_norm_type.lower() == 'weight_norm':
+                tutils.weight_norm(m)
+            else:
+                pass
+
+
+def init_weights(net, init_type='normal', init_gain=0.02):
+    """Initialize network weights.
+    Parameters:
+        net (network)     -- network to be initialized
+        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+    work better for some applications. Feel free to try yourself.
+    """
+    def init_func(m):  # define the initialization function
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+            if init_type == 'normal':
+                init.normal_(m.weight.data, 0.0, init_gain)
+            elif init_type == 'xavier':
+                init.xavier_normal_(m.weight.data, gain=init_gain)
+            elif init_type == 'kaiming':
+                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+            elif init_type == 'orthogonal':
+                init.orthogonal_(m.weight.data, gain=init_gain)
+            else:
+                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+            if hasattr(m, 'bias') and m.bias is not None:
+                init.constant_(m.bias.data, 0.0)
+        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+            init.normal_(m.weight.data, 1.0, init_gain)
+            init.constant_(m.bias.data, 0.0)
+
+    print('initialize network with %s' % init_type)
+    net.apply(init_func)  # apply the initialization function
+
+
+def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
+    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
+    Parameters:
+        net (network)      -- the network to be initialized
+        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+        gain (float)       -- scaling factor for normal, xavier and orthogonal.
+        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+    Return an initialized network.
+    """
+    if len(gpu_ids) > 0:
+        assert(torch.cuda.is_available())
+        net.to(gpu_ids[0])
+        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
+    init_weights(net, init_type, init_gain=init_gain)
+    return net
+
+
+def get_scheduler(optimizer, opt):
+    """Return a learning rate scheduler
+    Parameters:
+        optimizer          -- the optimizer of the network
+        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+    For 'linear', we keep the same learning rate for the first epochs
+    and linearly decay the rate to zero over the next epochs.
+    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+    See https://pytorch.org/docs/stable/optim.html for more details.
+ """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def define_P(opt, in_size=512, out_size=512, min_feat_size=32, relu_type='LeakyReLU', isTrain=True, weight_path=None): + net = ParseNet(in_size, out_size, min_feat_size, 64, 19, norm_type=opt.Pnorm, relu_type=relu_type, ch_range=[32, 256]) + if not isTrain: + net.eval() + if weight_path is not None: + net.load_state_dict(torch.load(weight_path)) + if len(opt.gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(opt.device) + net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.device) + return net + + +def define_G(opt, isTrain=True, use_norm='none', relu_type='LeakyReLU'): + net = psfrnet.PSFRGenerator(3, 3, in_size=opt.Gin_size, out_size=opt.Gout_size, relu_type=relu_type, parse_ch=19, norm_type=opt.Gnorm) + apply_norm(net, use_norm) + if not isTrain: + net.eval() + if len(opt.gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(opt.device) + net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.device) + # init_weights(net, init_type='normal', init_gain=0.02) + return net + + +def define_D(opt, in_channel=3, isTrain=True, use_norm='none'): + net = MultiScaleDiscriminator(in_channel, opt.ndf, opt.n_layers_D, opt.Dnorm, num_D=opt.D_num) + apply_norm(net, use_norm) + if not isTrain: + net.eval() + if len(opt.gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(opt.device) + net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.device) + init_weights(net, init_type='normal', init_gain=0.02) + return net + + +class ParseNet(nn.Module): + def __init__(self, + in_size=128, + out_size=128, + min_feat_size=32, + base_ch=64, + parsing_ch=19, + res_depth=10, + relu_type='prelu', + norm_type='bn', + ch_range=[32, 512], + ): + super().__init__() + self.res_depth = res_depth + act_args = {'norm_type': norm_type, 'relu_type': relu_type} + min_ch, max_ch = ch_range + + ch_clip = lambda x: max(min_ch, min(x, max_ch)) + min_feat_size = min(in_size, min_feat_size) + + down_steps = int(np.log2(in_size//min_feat_size)) + up_steps = int(np.log2(out_size//min_feat_size)) + + # =============== define encoder-body-decoder ==================== + self.encoder = [] + self.encoder.append(ConvLayer(3, base_ch, 3, 1)) + head_ch = base_ch + for i in range(down_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) + self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) + head_ch = head_ch * 2 + + self.body = [] + for i in range(res_depth): + self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) + + self.decoder = [] + for i in range(up_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) + self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) + head_ch = head_ch // 2 + + self.encoder = nn.Sequential(*self.encoder) + self.body = nn.Sequential(*self.body) + 
self.decoder = nn.Sequential(*self.decoder) + self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) + self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) + + def forward(self, x): + feat = self.encoder(x) + x = feat + self.body(feat) + x = self.decoder(x) + out_img = self.out_img_conv(x) + out_mask = self.out_mask_conv(x) + return out_mask, out_img + + +class MultiScaleDiscriminator(nn.Module): + def __init__(self, input_ch, base_ch=64, n_layers=3, norm_type='none', relu_type='LeakyReLU', num_D=4): + super().__init__() + + self.D_pool = nn.ModuleList() + for i in range(num_D): + netD = NLayerDiscriminator(input_ch, base_ch, depth=n_layers, norm_type=norm_type, relu_type=relu_type) + self.D_pool.append(netD) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + + def forward(self, input, return_feat=False): + results = [] + for netd in self.D_pool: + output = netd(input, return_feat) + results.append(output) + # Downsample input + input = self.downsample(input) + return results + + +class NLayerDiscriminator(nn.Module): + def __init__(self, + input_ch = 3, + base_ch = 64, + max_ch = 1024, + depth = 4, + norm_type = 'none', + relu_type = 'LeakyReLU', + ): + super().__init__() + + nargs = {'norm_type': norm_type, 'relu_type': relu_type} + self.norm_type = norm_type + self.input_ch = input_ch + + self.model = [] + self.model.append(ConvLayer(input_ch, base_ch, norm_type='none', relu_type=relu_type)) + for i in range(depth): + cin = min(base_ch * 2**(i), max_ch) + cout = min(base_ch * 2**(i+1), max_ch) + self.model.append(ConvLayer(cin, cout, scale='down_avg', **nargs)) + self.model = nn.Sequential(*self.model) + self.score_out = ConvLayer(cout, 1, use_pad=False) + + def forward(self, x, return_feat=False): + ret_feats = [] + for idx, m in enumerate(self.model): + x = m(x) + ret_feats.append(x) + x = self.score_out(x) + if return_feat: + return x, ret_feats + else: + return x + + + diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/models/parse_model.py b/face_parse/PSFRGAN-master/PSFRGAN-master/models/parse_model.py new file mode 100644 index 0000000..dea5d62 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/models/parse_model.py @@ -0,0 +1,80 @@ +import torch +from .base_model import BaseModel +from . import networks +from utils import utils + +class ParseModel(BaseModel): + def modify_commandline_options(parser, is_train): + if is_train: + parser.add_argument('--parse_map', type=float, default=1.0, help='weight for parsing map') + parser.add_argument('--parse_sr', type=float, default=1.0, help='weight for sr') + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. 
+ - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + self.loss_names = ['P', 'SR'] + self.visual_names = ['img_LR', 'img_HR', 'gt_Parse', 'img_SR', 'pred_Parse'] + + self.model_names = ['P'] + self.netP = networks.define_P(opt) + + if self.isTrain: # only defined during training time + self.criterionParse = torch.nn.CrossEntropyLoss() + self.criterionSR = torch.nn.L1Loss() + self.optimizer = torch.optim.Adam(self.netP.parameters(), lr=opt.lr, betas=(0.9, 0.999)) + self.optimizers = [self.optimizer] + + def set_input(self, input, cur_iters=None): + self.img_LR = input['LR'].to(self.opt.device) + self.img_HR = input['HR'].to(self.opt.device) + self.gt_Parse = input['Mask'].to(self.opt.device) + if self.opt.debug: + print('ParseNet input shape:', self.img_LR.shape, self.img_HR.shape, self.gt_Parse.shape) + + def load_pretrain_models(self,): + self.netP.eval() + print('Loading pretrained LQ face parsing network from', self.opt.parse_net_weight) + self.netP.load_state_dict(torch.load(self.opt.parse_net_weight)) + + def forward(self): + self.pred_Parse, self.img_SR = self.netP(self.img_LR) + if self.opt.debug: + print('ParseNet output shape', self.pred_Parse.shape, self.img_SR.shape) + + def backward(self): + self.loss_P = self.criterionParse(self.pred_Parse, self.gt_Parse) * self.opt.parse_map + self.loss_SR = self.criterionSR(self.img_SR, self.img_HR) * self.opt.parse_sr + + loss = self.loss_P + self.loss_SR + loss.backward() + + def optimize_parameters(self): + self.optimizer.zero_grad() # clear network G's existing gradients + self.backward() # calculate gradients for network G + self.optimizer.step() + + def get_current_visuals(self, size=512): + out = [] + visual_imgs = [] + out.append(utils.tensor_to_numpy(self.img_LR)) + out.append(utils.tensor_to_numpy(self.img_SR)) + out.append(utils.tensor_to_numpy(self.img_HR)) + out_imgs = [utils.batch_numpy_to_image(x, size) for x in out] + + visual_imgs.append(out_imgs[0]) + visual_imgs.append(out_imgs[1]) + visual_imgs.append(utils.color_parse_map(self.pred_Parse)) + visual_imgs.append(utils.color_parse_map(self.gt_Parse.unsqueeze(1))) + visual_imgs.append(out_imgs[2]) + + return visual_imgs + diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/models/psfrnet.py b/face_parse/PSFRGAN-master/PSFRGAN-master/models/psfrnet.py new file mode 100644 index 0000000..7ea2d03 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/models/psfrnet.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +from torch.nn import init +import numpy as np +from models.blocks import * + + +class SPADENorm(nn.Module): + def __init__(self, norm_nc, ref_nc, norm_type='spade', ksz=3): + super().__init__() + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + mid_c = 64 + + self.norm_type = norm_type + if norm_type == 'spade': + self.conv1 = nn.Sequential( + nn.Conv2d(ref_nc, mid_c, ksz, 1, ksz//2), + nn.LeakyReLU(0.2, True), + ) + self.gamma_conv = nn.Conv2d(mid_c, norm_nc, ksz, 1, ksz//2) + self.beta_conv = nn.Conv2d(mid_c, norm_nc, ksz, 1, ksz//2) + + def get_gamma_beta(self, x, conv, gamma_conv, beta_conv): + act = conv(x) + gamma = gamma_conv(act) + beta = beta_conv(act) + return gamma, beta + + def forward(self, x, ref): + normalized_input = self.param_free_norm(x) + if x.shape[-1] != ref.shape[-1]: + ref = nn.functional.interpolate(ref, x.shape[2:], mode='bicubic', 
align_corners=False)
+        if self.norm_type == 'spade':
+            gamma, beta = self.get_gamma_beta(ref, self.conv1, self.gamma_conv, self.beta_conv)
+            return normalized_input * gamma + beta
+        elif self.norm_type == 'in':
+            return normalized_input
+
+
+class SPADEResBlock(nn.Module):
+    def __init__(self, fin, fout, ref_nc, relu_type, norm_type='spade'):
+        super().__init__()
+
+        fmiddle = min(fin, fout)
+        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
+        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
+
+        # define normalization layers
+        self.norm_0 = SPADENorm(fmiddle, ref_nc, norm_type)
+        self.norm_1 = SPADENorm(fmiddle, ref_nc, norm_type)
+        self.relu = ReluLayer(fmiddle, relu_type)
+
+    def forward(self, x, ref):
+        res = self.conv_0(self.relu(self.norm_0(x, ref)))
+        res = self.conv_1(self.relu(self.norm_1(res, ref)))
+        out = x + res
+
+        return out
+
+
+class PSFRGenerator(nn.Module):
+    def __init__(self, input_nc, output_nc, in_size=512, out_size=512, min_feat_size=16, ngf=64, n_blocks=9, parse_ch=19, relu_type='relu',
+                 ch_range=[32, 1024], norm_type='spade'):
+        super().__init__()
+
+        min_ch, max_ch = ch_range
+        ch_clip = lambda x: max(min_ch, min(x, max_ch))
+        get_ch = lambda size: ch_clip(1024*16//size)
+
+        self.const_input = nn.Parameter(torch.randn(1, get_ch(min_feat_size), min_feat_size, min_feat_size))
+        up_steps = int(np.log2(out_size//min_feat_size))
+        self.up_steps = up_steps
+
+        ref_ch = 19+3
+
+        head_ch = get_ch(min_feat_size)
+        head = [
+            nn.Conv2d(head_ch, head_ch, kernel_size=3, padding=1),
+            SPADEResBlock(head_ch, head_ch, ref_ch, relu_type, norm_type),
+        ]
+
+        body = []
+        for i in range(up_steps):
+            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
+            body += [
+                nn.Sequential(
+                    nn.Upsample(scale_factor=2),
+                    nn.Conv2d(cin, cout, kernel_size=3, padding=1),
+                    SPADEResBlock(cout, cout, ref_ch, relu_type, norm_type)
+                )
+            ]
+            head_ch = head_ch // 2
+
+        self.img_out = nn.Conv2d(ch_clip(head_ch), output_nc, kernel_size=3, padding=1)
+
+        self.head = nn.Sequential(*head)
+        self.body = nn.Sequential(*body)
+        self.upsample = nn.Upsample(scale_factor=2)
+
+    def forward_spade(self, net, x, ref):
+        for m in net:
+            x = self.forward_spade_m(m, x, ref)
+        return x
+
+    def forward_spade_m(self, m, x, ref):
+        if isinstance(m, SPADENorm) or isinstance(m, SPADEResBlock):
+            x = m(x, ref)
+        else:
+            x = m(x)
+        return x
+
+    def forward(self, x, ref):
+        b, c, h, w = x.shape
+        const_input = self.const_input.repeat(b, 1, 1, 1)
+        ref_input = torch.cat((x, ref), dim=1)
+
+        feat = self.forward_spade(self.head, const_input, ref_input)
+
+        for idx, m in enumerate(self.body):
+            feat = self.forward_spade(m, feat, ref_input)
+
+        out_img = self.img_out(feat)
+
+        return out_img
+
+
+if __name__ == '__main__':
+    x = torch.randn(2, 16, 567, 234)
+    x = nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+    print(x.shape)
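To see the shapes the generator expects, a minimal forward-pass sketch, assuming the repo's models.blocks is importable; the 19-channel parse map is a random stand-in for netP output:

    import torch
    from models import psfrnet

    net = psfrnet.PSFRGenerator(3, 3, in_size=512, out_size=512)
    lq = torch.randn(1, 3, 512, 512)      # degraded input face
    parse = torch.randn(1, 19, 512, 512)  # stand-in for the one-hot parsing map
    with torch.no_grad():
        sr = net(lq, parse)               # SPADE blocks modulate on cat(lq, parse), hence ref_ch = 19+3
    print(sr.shape)                       # torch.Size([1, 3, 512, 512])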
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/options/__init__.py b/face_parse/PSFRGAN-master/PSFRGAN-master/options/__init__.py
new file mode 100644
index 0000000..e7eedeb
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/options/__init__.py
@@ -0,0 +1 @@
+"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/options/base_options.py b/face_parse/PSFRGAN-master/PSFRGAN-master/options/base_options.py
new file mode 100644
index 0000000..e0c3c1f
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/options/base_options.py
@@ -0,0 +1,165 @@
+import argparse
+import os
+import numpy as np
+import random
+import torch
+import models
+import data
+from utils import utils
+
+
+class BaseOptions():
+    """This class defines options used during both training and test time.
+
+    It also implements several helper functions such as parsing, printing, and saving the options.
+    It also gathers additional options defined in functions in both dataset class and model class.
+    """
+
+    def __init__(self):
+        """Reset the class; indicates the class hasn't been initialized"""
+        self.initialized = False
+
+    def initialize(self, parser):
+        """Define the common options that are used in both training and test."""
+        # basic parameters
+        parser.add_argument('--dataroot', required=False, help='path to images')
+        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
+        parser.add_argument('--gpus', type=int, default=1, help='how many gpus to use')
+        parser.add_argument('--seed', type=int, default=123, help='Random seed for training')
+        parser.add_argument('--checkpoints_dir', type=str, default='./check_points', help='models are saved here')
+        # model parameters
+        parser.add_argument('--model', type=str, default='enhance', help='chooses which model to train [parse|enhance]')
+        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
+        parser.add_argument('--Dinput_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
+        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
+        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
+        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
+        parser.add_argument('--n_layers_D', type=int, default=4, help='downsampling layers in discriminator')
+        parser.add_argument('--D_num', type=int, default=3, help='numbers of discriminators')
+
+        parser.add_argument('--Pnorm', type=str, default='bn', help='parsing net norm [in | bn | none]')
+        parser.add_argument('--Gnorm', type=str, default='spade', help='generator norm [in | bn | none]')
+        parser.add_argument('--Dnorm', type=str, default='in', help='discriminator norm [in | bn | none]')
+        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
+        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
+        # dataset parameters
+        parser.add_argument('--dataset_name', type=str, default='single', help='dataset name')
+        parser.add_argument('--Pimg_size', type=int, default=512, help='image size for face parse net')
+        parser.add_argument('--Gin_size', type=int, default=512, help='input image size for the generator')
+        parser.add_argument('--Gout_size', type=int, default=512, help='output image size for the generator')
+        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+        parser.add_argument('--num_threads', default=8, type=int, help='# threads for loading data')
+        parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
+        parser.add_argument('--load_size', type=int, default=512, help='scale images to this size')
+        parser.add_argument('--crop_size', type=int, default=256, help='then crop 
to this size') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') + parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') + # additional parameters + parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + parser.add_argument('--debug', action='store_true', help='if specified, set to debug mode') + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + # modify dataset-related parser options + dataset_name = opt.dataset_name + dataset_option_setter = data.get_option_setter(dataset_name) + parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + return parser.parse_args() + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). 
+    It will save options into a text file / [checkpoints_dir] / opt.txt
+    """
+        message = ''
+        message += '----------------- Options ---------------\n'
+        for k, v in sorted(vars(opt).items()):
+            comment = ''
+            default = self.parser.get_default(k)
+            if v != default:
+                comment = '\t[default: %s]' % str(default)
+            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+        message += '----------------- End -------------------'
+        print(message)
+
+        # save to the disk
+        opt.expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+        utils.mkdirs(opt.expr_dir)
+        file_name = os.path.join(opt.expr_dir, '{}_opt.txt'.format(opt.phase))
+        with open(file_name, 'wt') as opt_file:
+            opt_file.write(message)
+            opt_file.write('\n')
+
+        opt.log_dir = os.path.join(opt.checkpoints_dir, 'log_dir')
+        utils.mkdirs(opt.log_dir)
+        opt.log_archive = os.path.join(opt.checkpoints_dir, 'log_archive')
+        utils.mkdirs(opt.log_archive)
+
+    def parse(self):
+        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+        opt = self.gather_options()
+        opt.isTrain = self.isTrain  # train or test
+
+        if opt.debug:
+            opt.name = 'debug'
+            opt.save_iter_freq = 1
+            opt.save_latest_freq = 1
+            opt.visual_freq = 1
+            opt.print_freq = 1
+
+        # Find available GPUs automatically
+        if opt.gpus > 0:
+            opt.gpu_ids = utils.get_gpu_memory_map()[1][:opt.gpus]
+            if not isinstance(opt.gpu_ids, list):
+                opt.gpu_ids = [opt.gpu_ids]
+            torch.cuda.set_device(opt.gpu_ids[0])
+            opt.device = torch.device('cuda:{}'.format(opt.gpu_ids[0 % opt.gpus]))
+            opt.data_device = torch.device('cuda:{}'.format(opt.gpu_ids[1 % opt.gpus]))
+        else:
+            opt.gpu_ids = []
+            opt.device = torch.device('cpu')
+
+        # set random seeds to ensure reproducibility
+        np.random.seed(opt.seed)
+        random.seed(opt.seed)
+        torch.manual_seed(opt.seed)
+        torch.cuda.manual_seed_all(opt.seed)
+
+        # process opt.suffix
+        if opt.suffix:
+            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+            opt.name = opt.name + suffix
+
+        self.print_options(opt)
+
+        self.opt = opt
+        return self.opt
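As a concrete illustration of the --suffix expansion in parse() above (all values hypothetical):

    from argparse import Namespace

    opt = Namespace(name='psfr', model='enhance', load_size=512, suffix='{model}_size{load_size}')
    if opt.suffix:
        opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))
    print(opt.name)  # psfr_enhance_size512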
+ """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--src_dir', type=str, default='G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan/n000002', help='source directory containing test images') + parser.add_argument('--save_masks_dir', type=str, default='../datasets/FFHQ/masks512', help='path to save parsing masks for FFHQ') + parser.add_argument('--test_img_path', type=str, default='', help='path for single image test') + parser.add_argument('--test_upscale', type=float, default=1, help='upsample scale for single image test') + parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') + parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') + parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + # Dropout and Batchnorm has different behavioir during training and test. + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') + parser.add_argument('--parse_net_weight', type=str, default='./pretrain_models/parse_multi_iter_90000.pth', help='parse model path') + parser.add_argument('--psfr_net_weight', type=str, default='./pretrain_models/psfrgan_epoch15_net_G.pth', help='parse model path') + # rewrite devalue values + # To avoid cropping, the load_size should be the same as crop_size + parser.set_defaults(load_size=parser.get_default('crop_size')) + self.isTrain = False + + return parser diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/options/train_options.py b/face_parse/PSFRGAN-master/PSFRGAN-master/options/train_options.py new file mode 100644 index 0000000..7e4fa0a --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/options/train_options.py @@ -0,0 +1,41 @@ +from .base_options import BaseOptions + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. 
+ """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # visdom and HTML visualization parameters + parser.add_argument('--visual_freq', type=int, default=400, help='frequency of show training images in tensorboard') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + # network saving and loading parameters + parser.add_argument('--save_iter_freq', type=int, default=5000, help='frequency of saving the models') + parser.add_argument('--save_latest_freq', type=int, default=500, help='save latest freq') + parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--no_strict_load', action='store_true', help='set strict load to false') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + # training parameters + parser.add_argument('--resume_epoch', type=int, default=0, help='training resume epoch') + parser.add_argument('--resume_iter', type=int, default=0, help='training resume iter') + parser.add_argument('--total_epochs', type=int, default=50, help='# of epochs to train') + parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') + parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero') + parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') + parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') + parser.add_argument('--g_lr', type=float, default=0.0001, help='generator learning rate') + parser.add_argument('--d_lr', type=float, default=0.0004, help='discriminator learning rate') + parser.add_argument('--gan_mode', type=str, default='hinge', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
+        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
+        parser.add_argument('--lr_decay_gamma', type=float, default=1, help='gamma multiplier applied every lr_decay_iters iterations')
+
+        self.isTrain = True
+
+        return parser
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/requirements.txt b/face_parse/PSFRGAN-master/PSFRGAN-master/requirements.txt
new file mode 100644
index 0000000..5ffeefb
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/requirements.txt
@@ -0,0 +1,11 @@
+torch==1.5.1
+torchvision==0.6.1
+tensorflow>=1.15.4
+tensorboard==1.15.0
+tensorboardX==2.1
+opencv-python
+dlib
+scikit-image==0.17.2
+scipy==1.4.1
+tqdm
+imgaug
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/test_enhance_dir_align.py b/face_parse/PSFRGAN-master/PSFRGAN-master/test_enhance_dir_align.py
new file mode 100644
index 0000000..fbccb83
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/test_enhance_dir_align.py
@@ -0,0 +1,62 @@
+import os
+from options.test_options import TestOptions
+from data import create_dataset
+from models import create_model
+from utils import utils
+from PIL import Image
+from tqdm import tqdm
+import torch
+import time
+import numpy as np
+
+if __name__ == '__main__':
+    opt = TestOptions().parse()  # get test options
+    opt.num_threads = 0   # test code only supports num_threads = 0
+    opt.batch_size = 4    # process the test images in batches of 4
+    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
+    opt.no_flip = True
+
+    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
+    model = create_model(opt)      # create a model given opt.model and other options
+    model.load_pretrain_models()
+
+    save_dir = opt.results_dir
+    os.makedirs(save_dir, exist_ok=True)
+
+    print('creating result directory', save_dir)
+    netP = model.netP
+    netG = model.netG
+    model.eval()
+    max_size = 9999
+    os.makedirs(os.path.join(save_dir, 'sr'), exist_ok=True)
+    for i, data in tqdm(enumerate(dataset), total=len(dataset)//opt.batch_size):
+        inp = data['LR']
+        with torch.no_grad():
+            parse_map, _ = netP(inp)
+            parse_map_sm = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float()
+            output_SR = netG(inp, parse_map_sm)
+        img_path = data['LR_paths']  # get image paths
+        inp_img = utils.batch_tensor_to_img(inp)
+        output_sr_img = utils.batch_tensor_to_img(output_SR)
+        ref_parse_img = utils.color_parse_map(parse_map_sm)
+        for j in tqdm(range(len(img_path))):
+            save_path = os.path.join(save_dir, 'lq', os.path.basename(img_path[j]))
+            os.makedirs(os.path.join(save_dir, 'lq'), exist_ok=True)
+            save_img = Image.fromarray(inp_img[j])
+            save_img.save(save_path)
+
+            save_path = os.path.join(save_dir, 'hq', os.path.basename(img_path[j]))
+            os.makedirs(os.path.join(save_dir, 'hq'), exist_ok=True)
+            save_img = Image.fromarray(output_sr_img[j])
+            save_img.save(save_path)
+
+            save_path = os.path.join(save_dir, 'parse', os.path.basename(img_path[j]))
+            os.makedirs(os.path.join(save_dir, 'parse'), exist_ok=True)
+            save_img = Image.fromarray(ref_parse_img[j])
+            save_img.save(save_path)
+
+        if i > max_size: break
+
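The `(parse_map == parse_map.max(dim=1, keepdim=True)[0]).float()` idiom above converts parsing logits into a one-hot mask; a small self-check of its equivalence with an explicit argmax (hypothetical tensors, exact ties excepted):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 19, 8, 8)  # stand-in parsing logits
    onehot = (logits == logits.max(dim=1, keepdim=True)[0]).float()
    via_argmax = F.one_hot(logits.argmax(dim=1), num_classes=19).permute(0, 3, 1, 2).float()
    assert torch.equal(onehot, via_argmax)  # holds unless two channels tie exactly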
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/test_enhance_dir_unalign.py b/face_parse/PSFRGAN-master/PSFRGAN-master/test_enhance_dir_unalign.py
new file mode 100644
index 0000000..3156fe9
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/test_enhance_dir_unalign.py
@@ -0,0 +1,57 @@
+'''
+This script enhances images with unaligned faces in a folder and pastes them back in their original places.
+'''
+import dlib
+import os
+import cv2
+import numpy as np
+from tqdm import tqdm
+from skimage import transform as trans
+from skimage import io
+
+import torch
+from utils import utils
+from options.test_options import TestOptions
+from models import create_model
+
+from test_enhance_single_unalign import *
+
+
+if __name__ == '__main__':
+    opt = TestOptions().parse()
+    # face_detector = dlib.get_frontal_face_detector()
+    face_detector = dlib.cnn_face_detection_model_v1('./pretrain_models/mmod_human_face_detector.dat')
+    lmk_predictor = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat')
+    template_path = './pretrain_models/FFHQ_template.npy'
+    enhance_model = def_models(opt)
+
+    for img_name in os.listdir(opt.src_dir):
+        img_path = os.path.join(opt.src_dir, img_name)
+        save_current_dir = os.path.join(opt.results_dir, os.path.splitext(img_name)[0])
+        os.makedirs(save_current_dir, exist_ok=True)
+        print('======> Loading image', img_path)
+        img = dlib.load_rgb_image(img_path)
+        aligned_faces, tform_params = detect_and_align_faces(img, face_detector, lmk_predictor, template_path)
+        # Save aligned LQ faces
+        save_lq_dir = os.path.join(save_current_dir, 'LQ_faces')
+        os.makedirs(save_lq_dir, exist_ok=True)
+        print('======> Saving aligned LQ faces to', save_lq_dir)
+        save_imgs(aligned_faces, save_lq_dir)
+
+        hq_faces, lq_parse_maps = enhance_faces(aligned_faces, enhance_model)
+        # Save LQ parsing maps and enhanced faces
+        save_parse_dir = os.path.join(save_current_dir, 'ParseMaps')
+        save_hq_dir = os.path.join(save_current_dir, 'HQ')
+        os.makedirs(save_parse_dir, exist_ok=True)
+        os.makedirs(save_hq_dir, exist_ok=True)
+        print('======> Save parsing map and the enhanced faces.')
+        save_imgs(lq_parse_maps, save_parse_dir)
+        save_imgs(hq_faces, save_hq_dir)
+
+        print('======> Paste the enhanced faces back to the original image.')
+        hq_img = past_faces_back(img, hq_faces, tform_params, upscale=opt.test_upscale)
+        final_save_path = os.path.join(save_current_dir, 'hq_final.jpg')
+        print('======> Save final result to', final_save_path)
+        io.imsave(final_save_path, hq_img)
+
+
diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/test_enhance_single_unalign.py b/face_parse/PSFRGAN-master/PSFRGAN-master/test_enhance_single_unalign.py
new file mode 100644
index 0000000..208aac3
--- /dev/null
+++ b/face_parse/PSFRGAN-master/PSFRGAN-master/test_enhance_single_unalign.py
@@ -0,0 +1,126 @@
+'''
+This script enhances all faces in one image with PSFR-GAN and pastes them back in their original places.
+''' +import dlib +import os +import cv2 +import numpy as np +from tqdm import tqdm +from skimage import transform as trans +from skimage import io + +import torch +from utils import utils +from options.test_options import TestOptions +from models import create_model + + +def detect_and_align_faces(img, face_detector, lmk_predictor, template_path, template_scale=2, size_threshold=999): + align_out_size = (512, 512) + ref_points = np.load(template_path) / template_scale + + # Detect landmark points + face_dets = face_detector(img, 1) + assert len(face_dets) > 0, 'No faces detected' + + aligned_faces = [] + tform_params = [] + for det in face_dets: + if isinstance(face_detector, dlib.cnn_face_detection_model_v1): + rec = det.rect # for cnn detector + else: + rec = det + if rec.width() > size_threshold or rec.height() > size_threshold: + print('Face is too large') + break + landmark_points = lmk_predictor(img, rec) + single_points = [] + for i in range(5): + single_points.append([landmark_points.part(i).x, landmark_points.part(i).y]) + single_points = np.array(single_points) + tform = trans.SimilarityTransform() + tform.estimate(single_points, ref_points) + tmp_face = trans.warp(img, tform.inverse, output_shape=align_out_size, order=3) + aligned_faces.append(tmp_face*255) + tform_params.append(tform) + return [aligned_faces, tform_params] + + +def def_models(opt): + model = create_model(opt) + model.load_pretrain_models() + model.netP.to(opt.device) + model.netG.to(opt.device) + return model + + +def enhance_faces(LQ_faces, model): + hq_faces = [] + lq_parse_maps = [] + for lq_face in tqdm(LQ_faces): + with torch.no_grad(): + lq_tensor = torch.tensor(lq_face.transpose(2, 0, 1)) / 255. * 2 - 1 + lq_tensor = lq_tensor.unsqueeze(0).float().to(model.device) + parse_map, _ = model.netP(lq_tensor) + parse_map_onehot = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float() + output_SR = model.netG(lq_tensor, parse_map_onehot) + hq_faces.append(utils.tensor_to_img(output_SR)) + lq_parse_maps.append(utils.color_parse_map(parse_map_onehot)[0]) + return hq_faces, lq_parse_maps + + +def past_faces_back(img, hq_faces, tform_params, upscale=1): + h, w = img.shape[:2] + img = cv2.resize(img, (int(w*upscale), int(h*upscale)), interpolation=cv2.INTER_CUBIC) + for hq_img, tform in tqdm(zip(hq_faces, tform_params), total=len(hq_faces)): + tform.params[0:2,0:2] /= upscale + back_img = trans.warp(hq_img/255., tform, output_shape=[int(h*upscale), int(w*upscale)], order=3) * 255 + + # blur mask to avoid border artifacts + mask = (back_img == 0) + mask = cv2.blur(mask.astype(np.float32), (5,5)) + mask = (mask > 0) + img = img * mask + (1 - mask) * back_img + return img.astype(np.uint8) + + +def save_imgs(img_list, save_dir): + for idx, img in enumerate(img_list): + save_path = os.path.join(save_dir, '{:03d}.jpg'.format(idx)) + io.imsave(save_path, img.astype(np.uint8)) + +if __name__ == '__main__': + opt = TestOptions().parse() + # face_detector = dlib.get_frontal_face_detector() + face_detector = dlib.cnn_face_detection_model_v1('./pretrain_models/mmod_human_face_detector.dat') + lmk_predictor = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat') + template_path = './pretrain_models/FFHQ_template.npy' + + print('======> Loading images, crop and align faces.') + img_path = opt.test_img_path + img = dlib.load_rgb_image(img_path) + aligned_faces, tform_params = detect_and_align_faces(img, face_detector, lmk_predictor, template_path) + # Save aligned LQ faces + save_lq_dir = 
os.path.join(opt.results_dir, 'LQ_faces') + os.makedirs(save_lq_dir, exist_ok=True) + print('======> Saving aligned LQ faces to', save_lq_dir) + save_imgs(aligned_faces, save_lq_dir) + + enhance_model = def_models(opt) + hq_faces, lq_parse_maps = enhance_faces(aligned_faces, enhance_model) + # Save LQ parsing maps and enhanced faces + save_parse_dir = os.path.join(opt.results_dir, 'ParseMaps') + save_hq_dir = os.path.join(opt.results_dir, 'HQ') + os.makedirs(save_parse_dir, exist_ok=True) + os.makedirs(save_hq_dir, exist_ok=True) + print('======> Save parsing map and the enhanced faces.') + save_imgs(lq_parse_maps, save_parse_dir) + save_imgs(hq_faces, save_hq_dir) + + print('======> Paste the enhanced faces back to the original image.') + hq_img = past_faces_back(img, hq_faces, tform_params, upscale=opt.test_upscale) + final_save_path = os.path.join(opt.results_dir, 'hq_final.jpg') + print('======> Save final result to', final_save_path) + io.imsave(final_save_path, hq_img) + + diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/train.py b/face_parse/PSFRGAN-master/PSFRGAN-master/train.py new file mode 100644 index 0000000..855b07d --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/train.py @@ -0,0 +1,79 @@ +from utils.timer import Timer +from utils.logger import Logger +from utils import utils + +from options.train_options import TrainOptions +from data import create_dataset +from models import create_model + +import torch +import os +import torch.multiprocessing as mp + +def train(opt): + + dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options + dataset_size = len(dataset) # get the number of images in the dataset. + print('The number of training images = %d' % dataset_size) + + model = create_model(opt) + model.setup(opt) + + logger = Logger(opt) + timer = Timer() + + single_epoch_iters = (dataset_size // opt.batch_size) + total_iters = opt.total_epochs * single_epoch_iters + cur_iters = opt.resume_iter + opt.resume_epoch * single_epoch_iters + start_iter = opt.resume_iter + print('Start training from epoch: {:05d}; iter: {:07d}'.format(opt.resume_epoch, opt.resume_iter)) + for epoch in range(opt.resume_epoch, opt.total_epochs + 1): + for i, data in enumerate(dataset, start=start_iter): + cur_iters += 1 + logger.set_current_iter(cur_iters) + # =================== load data =============== + model.set_input(data, cur_iters) + timer.update_time('DataTime') + + # =================== model train =============== + model.forward(), timer.update_time('Forward') + model.optimize_parameters() + loss = model.get_current_losses() + loss.update(model.get_lr()) + logger.record_losses(loss) + timer.update_time('Backward') + + # =================== save model and visualize =============== + if cur_iters % opt.print_freq == 0: + print('Model log directory: {}'.format(opt.expr_dir)) + epoch_progress = '{:03d}|{:05d}/{:05d}'.format(epoch, i, single_epoch_iters) + logger.printIterSummary(epoch_progress, cur_iters, total_iters, timer) + + if cur_iters % opt.visual_freq == 0: + visual_imgs = model.get_current_visuals() + logger.record_images(visual_imgs) + + if cur_iters % opt.save_iter_freq == 0: + print('saving current model (epoch %d, iters %d)' % (epoch, cur_iters)) + save_suffix = 'iter_%d' % cur_iters + info = {'resume_epoch': epoch, 'resume_iter': i+1} + model.save_networks(save_suffix, info) + + if cur_iters % opt.save_latest_freq == 0: + print('saving the latest model (epoch %d, iters %d)' % (epoch, cur_iters)) + info = {'resume_epoch': epoch, 
'resume_iter': i+1} + model.save_networks('latest', info) + + if i >= single_epoch_iters - 1: + start_iter = 0 + break + + # model.update_learning_rate() + if opt.debug: break + if opt.debug and epoch >= 0: break + logger.close() + +if __name__ == '__main__': + opt = TrainOptions().parse() + train(opt) + diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/utils/logger.py b/face_parse/PSFRGAN-master/PSFRGAN-master/utils/logger.py new file mode 100644 index 0000000..afbdfde --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/utils/logger.py @@ -0,0 +1,91 @@ +import os +from collections import OrderedDict +import numpy as np +from .utils import mkdirs +from tensorboardX import SummaryWriter +from datetime import datetime +import socket +import shutil + +class Logger(): + def __init__(self, opts): + time_stamp = '_{}'.format(datetime.now().strftime('%Y-%m-%d_%H:%M')) + self.opts = opts + self.log_dir = os.path.join(opts.log_dir, opts.name+time_stamp) + self.phase_keys = ['train', 'val', 'test'] + self.iter_log = [] + self.epoch_log = OrderedDict() + self.set_mode(opts.phase) + + # check if exist previous log belong to the same experiment name + exist_log = None + for log_name in os.listdir(opts.log_dir): + if opts.name in log_name: + exist_log = log_name + if exist_log is not None: + old_dir = os.path.join(opts.log_dir, exist_log) + archive_dir = os.path.join(opts.log_archive, exist_log) + shutil.move(old_dir, archive_dir) + + self.mk_log_file() + + self.writer = SummaryWriter(self.log_dir) + + def mk_log_file(self): + mkdirs(self.log_dir) + self.txt_files = OrderedDict() + for i in self.phase_keys: + self.txt_files[i] = os.path.join(self.log_dir, 'log_{}'.format(i)) + + def set_mode(self, mode): + self.mode = mode + self.epoch_log[mode] = [] + + def set_current_iter(self, cur_iter): + self.cur_iter = cur_iter + + def record_losses(self, items): + """ + iteration log: [iter][{key: value}] + """ + self.iter_log.append(items) + for k, v in items.items(): + if 'loss' in k.lower(): + self.writer.add_scalar('loss/{}'.format(k), v, self.cur_iter) + + def record_scalar(self, items): + """ + Add scalar records. 
item, {key: value} + """ + for i in items.keys(): + self.writer.add_scalar('{}'.format(i), items[i], self.cur_iter) + + def record_images(self, visuals, nrow=6, tag='ckpt_image'): + imgs = [] + max_len = visuals[0].shape[0] + for i in range(nrow): + if i >= max_len: continue + tmp_imgs = [x[i] for x in visuals] + imgs.append(np.hstack(tmp_imgs)) + imgs = np.vstack(imgs).astype(np.uint8) + self.writer.add_image(tag, imgs, self.cur_iter, dataformats='HWC') + + def record_text(self, tag, text): + self.writer.add_text(tag, text) + + def printIterSummary(self, epoch, cur_iters, total_it, timer): + msg = '{}\nIter: [{}]{:03d}/{:03d}\t\t'.format( + timer.to_string(total_it - cur_iters), epoch, cur_iters, total_it) + for k, v in self.iter_log[-1].items(): + msg += '{}: {:.6f}\t'.format(k, v) + print(msg + '\n') + with open(self.txt_files[self.mode], 'a+') as f: + f.write(msg + '\n') + + def close(self): + self.writer.export_scalars_to_json(os.path.join(self.log_dir, 'all_scalars.json')) + self.writer.close() + + + + diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/utils/timer.py b/face_parse/PSFRGAN-master/PSFRGAN-master/utils/timer.py new file mode 100644 index 0000000..1ec4f74 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/utils/timer.py @@ -0,0 +1,34 @@ +import time +import datetime +from collections import OrderedDict + +class Timer(): + def __init__(self): + self.reset_timer() + self.start = time.time() + + def reset_timer(self): + self.before = time.time() + self.timer = OrderedDict() + + def restart(self): + self.before = time.time() + + def update_time(self, key): + self.timer[key] = time.time() - self.before + self.before = time.time() + + def to_string(self, iters_left, short=False): + iter_total = sum(self.timer.values()) + msg = "{:%Y-%m-%d %H:%M:%S}\tElapse: {}\tTimeLeft: {}\t".format( + datetime.datetime.now(), + datetime.timedelta(seconds=round(time.time() - self.start)), + datetime.timedelta(seconds=round(iter_total*iters_left)) + ) + if short: + msg += '{}: {:.2f}s'.format('|'.join(self.timer.keys()), iter_total) + else: + msg += '\tIterTotal: {:.2f}s\t{}: {} '.format(iter_total, + '|'.join(self.timer.keys()), ' '.join('{:.2f}s'.format(x) for x in self.timer.values())) + return msg + diff --git a/face_parse/PSFRGAN-master/PSFRGAN-master/utils/utils.py b/face_parse/PSFRGAN-master/PSFRGAN-master/utils/utils.py new file mode 100644 index 0000000..ed49cc9 --- /dev/null +++ b/face_parse/PSFRGAN-master/PSFRGAN-master/utils/utils.py @@ -0,0 +1,169 @@ +import torch +import numpy as np +import cv2 as cv +from skimage import io +from PIL import Image +import os +import subprocess + +# MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] +MASK_COLORMAP = [[0, 0, 0], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [0,0, 0], [0, 0, 0], [255, 255, 255], [255, 255, 255], [255, 255, 255], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]] +label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth'] + +def array_to_heatmap(x): + x = (x - x.min()) / (x.max() - x.min()) * 255 + x = x.astype(np.uint8) + return cv.applyColorMap(x.astype(np.uint8), 
cv.COLORMAP_RAINBOW) + +def img_to_tensor(img_path, device, size=None, mode='rgb'): + """ + Read image from img_path, and convert to (C, H, W) tensor in range [-1, 1] + """ + img = Image.open(img_path).convert('RGB') + img = np.array(img) + if mode=='bgr': + img = img[..., ::-1] + if size: + img = cv.resize(img, size) + img = img / 255 * 2 - 1 + img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device) + return img_tensor.float() + +def tensor_to_img(tensor, save_path=None, size=None, mode='RGB', normal=[-1, 1]): + """ + mode: RGB or L (gray image) + Input: tensor with shape (C, H, W) + Output: PIL Image + """ + if isinstance(size, int): + size = (size, size) + img_array = tensor.squeeze().data.cpu().numpy() + if mode == 'RGB': + img_array = img_array.transpose(1, 2, 0) + + if size is not None: + img_array = cv.resize(img_array, size, interpolation=cv.INTER_LINEAR) + + if len(normal): + img_array = (img_array - normal[0]) / (normal[1] - normal[0]) * 255 + img_array = img_array.clip(0, 255) + + img_array = img_array.astype(np.uint8) + if save_path: + img = Image.fromarray(img_array, mode) + img.save(save_path) + + return img_array + +def tensor_to_numpy(tensor): + return tensor.data.cpu().numpy() + +def batch_numpy_to_image(array, size=None): + """ + Input: numpy array (B, C, H, W) in [-1, 1] + """ + if isinstance(size, int): + size = (size, size) + + out_imgs = [] + array = np.clip((array + 1)/2 * 255, 0, 255) + array = np.transpose(array, (0, 2, 3, 1)) + for i in range(array.shape[0]): + if size is not None: + tmp_array = cv.resize(array[i], size) + else: + tmp_array = array[i] + out_imgs.append(tmp_array) + return np.array(out_imgs).astype(np.uint8) + +def batch_tensor_to_img(tensor, size=None): + """ + Input: (B, C, H, W) + Return: RGB image, [0, 255] + """ + arrays = tensor_to_numpy(tensor) + out_imgs = batch_numpy_to_image(arrays, size) + return out_imgs + +def color_parse_map(tensor, size=None): + """ + input: tensor or batch tensor + return: colorized parsing maps + """ + if len(tensor.shape) < 4: + tensor = tensor.unsqueeze(0) + if tensor.shape[1] > 1: + tensor = tensor.argmax(dim=1) + + tensor = tensor.squeeze(1).data.cpu().numpy() + color_maps = [] + for t in tensor: + tmp_img = np.zeros(tensor.shape[1:] + (3,)) + for idx, color in enumerate(MASK_COLORMAP): + tmp_img[t == idx] = color + if size is not None: + tmp_img = cv.resize(tmp_img, (size, size)) + color_maps.append(tmp_img.astype(np.uint8)) + return color_maps + +def onehot_parse_map(img): + """ + input: RGB color parse map + output: one hot encoding of parse map + """ + n_label = len(MASK_COLORMAP) + img = np.array(img, dtype=np.uint8) + h, w = img.shape[:2] + onehot_label = np.zeros((n_label, h, w)) + colormap = np.array(MASK_COLORMAP).reshape(n_label, 1, 1, 3) + colormap = np.tile(colormap, (1, h, w, 1)) + for idx, color in enumerate(MASK_COLORMAP): + tmp_label = colormap[idx] == img + onehot_label[idx] = tmp_label[..., 0] * tmp_label[..., 1] * tmp_label[..., 2] + return onehot_label + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + if not os.path.exists(path): + os.makedirs(path) + else: + if not os.path.exists(paths): + os.makedirs(paths) + + +def get_gpu_memory_map(): + """Get the current gpu usage within visible cuda devices. + + Returns + ------- + Memory Map: dict + Keys are device ids as integers. + Values are memory usage as integers in MB. + Device Ids: gpu ids sorted in descending order according to the available memory. 
+ """ + result = subprocess.check_output( + [ + 'nvidia-smi', '--query-gpu=memory.used', + '--format=csv,nounits,noheader' + ]).decode('utf-8') + # Convert lines into a dictionary + gpu_memory = np.array([int(x) for x in result.strip().split('\n')]) + if 'CUDA_VISIBLE_DEVICES' in os.environ: + visible_devices = sorted([int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]) + else: + visible_devices = range(len(gpu_memory)) + gpu_memory_map = dict(zip(range(len(visible_devices)), gpu_memory[visible_devices])) + return gpu_memory_map, sorted(gpu_memory_map, key=gpu_memory_map.get) + + +if __name__ == '__main__': + hm = torch.randn(32, 68, 128, 128).cuda() + flip(hm, 2) + x = torch.ones(32, 68) + y = torch.ones(32, 68) + print(get_gpu_memory_map()) + + + diff --git a/filter.py b/filter.py new file mode 100644 index 0000000..d927ce0 --- /dev/null +++ b/filter.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: filter.py +# Created Date: Wednesday April 13th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 13th April 2022 3:49:23 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import cv2 +import torch.nn as nn +import torch +import torch.nn.functional as F + +import numpy as np +from PIL import Image +from torchvision import transforms + +class HighPass(nn.Module): + def __init__(self, w_hpf, device): + super(HighPass, self).__init__() + self.filter = torch.tensor([[-1, -1, -1], + [-1, 8., -1], + [-1, -1, -1]]).to(device) / w_hpf + + def forward(self, x): + filter = self.filter.unsqueeze(0).unsqueeze(1).repeat(x.size(1), 1, 1, 1) + return F.conv2d(x, filter, padding=1, groups=x.size(1)) + +if __name__ == "__main__": + transformer_Arcface = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1) + imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1) + + img = "G:/swap_data/ID/2.jpg" + attr = cv2.imread(img) + attr = Image.fromarray(cv2.cvtColor(attr,cv2.COLOR_BGR2RGB)) + attr = transformer_Arcface(attr).unsqueeze(0) + results = HighPass(0.5,torch.device("cpu"))(attr) + + results = results * imagenet_std + imagenet_mean + results = results.cpu().permute(0,2,3,1)[0,...] 
+ results = results.numpy() + results = np.clip(results,0.0,1.0) * 255 + results = cv2.cvtColor(results,cv2.COLOR_RGB2BGR) + cv2.imwrite("filter_results2.png",results) diff --git a/flops.py b/flops.py index b415e3c..e1f013a 100644 --- a/flops.py +++ b/flops.py @@ -5,7 +5,7 @@ # Created Date: Sunday February 13th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Friday, 4th March 2022 1:53:53 am +# Last Modified: Monday, 18th April 2022 10:52:57 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -22,7 +22,7 @@ from thop import clever_format if __name__ == '__main__': # # script = "Generator_modulation_up" - script = "Generator_Invobn_config3" + script = "Generator_2mask" # script = "Generator_ori_modulation_config" # script = "Generator_ori_config" class_name = "Generator" @@ -30,12 +30,13 @@ if __name__ == '__main__': model_config={ "id_dim": 512, "g_kernel_size": 3, - "in_channel":16, - "res_num": 9, + "in_channel":64, + "res_num": 3, # "up_mode": "nearest", "up_mode": "bilinear", "aggregator": "eca_invo", - "res_mode": "conv" + "res_mode": "conv", + "norm": "bn" } diff --git a/id_cos.py b/id_cos.py new file mode 100644 index 0000000..7a363de --- /dev/null +++ b/id_cos.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: id_cos.py +# Created Date: Friday March 25th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Tuesday, 29th March 2022 11:58:30 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# +import cv2 +from PIL import Image + +import torch +import torch.nn.functional as F +from torchvision import transforms +from insightface_func.face_detect_crop_single import Face_detect_crop + +from arcface_torch.backbones.iresnet import iresnet100 + +if __name__ == "__main__": + imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1) + imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1) + transformer_Arcface = transforms.Compose([ + transforms.ToTensor(), + # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + arcface_ckpt = "./arcface_ckpt/arcface_checkpoint.tar" + arcface1 = torch.load(arcface_ckpt, map_location=torch.device("cpu")) + arcface = arcface1['model'].module + arcface.eval() + + root1 = "G:/VGGFace2-HQ/VGGface2_ffhq_align_256_9_28_512_bygfpgan/n000002/" + root2 = "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan/n000002/" + + # arcface_ckpt = "./arcface_torch/checkpoints/backbone.pth" # backbone.pth glint360k_cosface_r100_fp16_backbone.pth + # arcface = iresnet100(pretrained=False, fp16=False) + # arcface.load_state_dict(torch.load(arcface_ckpt, map_location='cpu')) + # arcface.eval() + + # id1 = "G:/swap_data/ID/hinton.jpg" + # id2 = "G:/hififace-master/hififace-master/assets/inference_samples/hififace/img-172.jpg" + id1 = root2 + "0003_01.jpg" + id2 = root2 + "0036_01.jpg" + + mode = "none" + cos_loss = torch.nn.CosineSimilarity() + # detect = Face_detect_crop(name='antelope', root='./insightface_func/models') + # detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode) + id_img = cv2.imread(id1) + # id_img_align_crop, _ = detect.get(id_img,256) + # cv2.imwrite("id1_crop.png",id_img_align_crop[0]) + # id_img_align_crop_pil = 
Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
+    id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img,cv2.COLOR_BGR2RGB))
+    id_img = transformer_Arcface(id_img_align_crop_pil)
+    id_img = id_img.unsqueeze(0)
+    id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
+    # id_img = (id_img-0.5)*2.0
+    latend_id = arcface(id_img)
+    latend_id = F.normalize(latend_id, p=2, dim=1)
+
+    id_img2 = cv2.imread(id2)
+    # id_img_align_crop2, _ = detect.get(id_img2,256)
+    # cv2.imwrite("id2_crop.png",id_img_align_crop2[0])
+    # id_img_align_crop_pil2 = Image.fromarray(cv2.cvtColor(id_img_align_crop2[0],cv2.COLOR_BGR2RGB))
+    id_img_align_crop_pil2 = Image.fromarray(cv2.cvtColor(id_img2,cv2.COLOR_BGR2RGB))
+    id_img2 = transformer_Arcface(id_img_align_crop_pil2)
+    id_img2 = id_img2.unsqueeze(0)
+    id_img2 = F.interpolate(id_img2,size=(112,112), mode='bicubic')
+    # id_img2 = (id_img2-0.5)*2.0
+    latend_id2 = arcface(id_img2)
+    latend_id2 = F.normalize(latend_id2, p=2, dim=1)
+
+    cos_dis = 1 - cos_loss(latend_id, latend_id2)
+    print("cosine distance:", cos_dis.item())
\ No newline at end of file
diff --git a/metrics/equivariance.py b/metrics/equivariance.py
new file mode 100644
index 0000000..d5559ac
--- /dev/null
+++ b/metrics/equivariance.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
+"Alias-Free Generative Adversarial Networks"."""
+
+import copy
+import numpy as np
+import torch
+import torch.fft
+from torch_utils.ops import upfirdn2d
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+# Utilities.
+
+def sinc(x):
+    y = (x * np.pi).abs()
+    z = torch.sin(y) / y.clamp(1e-30, float('inf'))
+    return torch.where(y < 1e-30, torch.ones_like(x), z)
+
+def lanczos_window(x, a):
+    x = x.abs() / a
+    return torch.where(x < 1, sinc(x), torch.zeros_like(x))
+
+def rotation_matrix(angle):
+    angle = torch.as_tensor(angle).to(torch.float32)
+    mat = torch.eye(3, device=angle.device)
+    mat[0, 0] = angle.cos()
+    mat[0, 1] = angle.sin()
+    mat[1, 0] = -angle.sin()
+    mat[1, 1] = angle.cos()
+    return mat
+
+#----------------------------------------------------------------------------
+# Apply integer translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.1.
+
+def apply_integer_translation(x, tx, ty):
+    _N, _C, H, W = x.shape
+    tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+    ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+    ix = tx.round().to(torch.int64)
+    iy = ty.round().to(torch.int64)
+
+    z = torch.zeros_like(x)
+    m = torch.zeros_like(x)
+    if abs(ix) < W and abs(iy) < H:
+        y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
+        z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
+        m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
+    return z, m
+
+#----------------------------------------------------------------------------
+# Apply fractional translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.2.
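+# (The fractional variant below splits the shift into an integer part and a
+#  sub-pixel remainder; the remainder is applied with a separable windowed-sinc
+#  interpolation filter of 2*a taps, built from sinc(t - f) * sinc((t - f) / a).)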
+ +def apply_fractional_translation(x, tx, ty, a=3): + _N, _C, H, W = x.shape + tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device) + ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device) + ix = tx.floor().to(torch.int64) + iy = ty.floor().to(torch.int64) + fx = tx - ix + fy = ty - iy + b = a - 1 + + z = torch.zeros_like(x) + zx0 = max(ix - b, 0) + zy0 = max(iy - b, 0) + zx1 = min(ix + a, 0) + W + zy1 = min(iy + a, 0) + H + if zx0 < zx1 and zy0 < zy1: + taps = torch.arange(a * 2, device=x.device) - b + filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0) + filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1) + y = x + y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0]) + y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a]) + y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)] + z[:, :, zy0:zy1, zx0:zx1] = y + + m = torch.zeros_like(x) + mx0 = max(ix + a, 0) + my0 = max(iy + a, 0) + mx1 = min(ix - b, 0) + W + my1 = min(iy - b, 0) + H + if mx0 < mx1 and my0 < my1: + m[:, :, my0:my1, mx0:mx1] = 1 + return z, m + +#---------------------------------------------------------------------------- +# Construct an oriented low-pass filter that applies the appropriate +# bandlimit with respect to the input and output of the given affine 2D +# image transformation. + +def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1): + assert a <= amax < aflt + mat = torch.as_tensor(mat).to(torch.float32) + + # Construct 2D filter taps in input & output coordinate spaces. + taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up) + yi, xi = torch.meshgrid(taps, taps) + xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2) + + # Convolution of two oriented 2D sinc filters. + fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in) + fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out) + f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real + + # Convolution of two oriented 2D Lanczos windows. + wi = lanczos_window(xi, a) * lanczos_window(yi, a) + wo = lanczos_window(xo, a) * lanczos_window(yo, a) + w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real + + # Construct windowed FIR filter. + f = f * w + + # Finalize. + c = (aflt - amax) * up + f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c] + f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up) + f = f / f.sum([0,2], keepdim=True) / (up ** 2) + f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1] + return f + +#---------------------------------------------------------------------------- +# Apply the given affine transformation to a batch of 2D images. + +def apply_affine_transformation(x, mat, up=4, **filter_kwargs): + _N, _C, H, W = x.shape + mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device) + + # Construct filter. + f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs) + assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1 + p = f.shape[0] // 2 + + # Construct sampling grid. + theta = mat.inverse() + theta[:2, 2] *= 2 + theta[0, 2] += 1 / up / W + theta[1, 2] += 1 / up / H + theta[0, :] *= W / (W + p / up * 2) + theta[1, :] *= H / (H + p / up * 2) + theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1]) + g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False) + + # Resample image. 
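+    # (upsample by `up` with the bandlimiting filter f, then warp with the affine
+    #  sampling grid g; resampling at the higher rate keeps the warp alias-free)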
+    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
+    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+    # Form mask.
+    m = torch.zeros_like(y)
+    c = p * 2 + 1
+    m[:, :, c:-c, c:-c] = 1
+    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
+    return z, m
+
+#----------------------------------------------------------------------------
+# Apply fractional rotation to a batch of 2D images. Corresponds to the
+# operator R_\alpha in Appendix E.3.
+
+def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
+    angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+    mat = rotation_matrix(angle)
+    return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
+
+#----------------------------------------------------------------------------
+# Modify the frequency content of a batch of 2D images as if they had undergone
+# fractional rotation -- but without actually rotating them. Corresponds to
+# the operator R^*_\alpha in Appendix E.3.
+
+def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
+    angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+    mat = rotation_matrix(-angle)
+    f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
+    y = upfirdn2d.filter2d(x=x, f=f)
+    m = torch.zeros_like(y)
+    c = f.shape[0] // 2
+    m[:, :, c:-c, c:-c] = 1
+    return y, m
+
+#----------------------------------------------------------------------------
+# Compute the selected equivariance metrics for the given generator.
+
+def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
+    assert compute_eqt_int or compute_eqt_frac or compute_eqr
+
+    # Setup generator and labels.
+    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+    I = torch.eye(3, device=opts.device)
+    M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
+    if M is None:
+        raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
+    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+    # Sampling loop.
+    sums = None
+    progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
+    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+        progress.update(batch_start)
+        s = []
+
+        # Randomize noise buffers, if any.
+        for name, buf in G.named_buffers():
+            if name.endswith('.noise_const'):
+                buf.copy_(torch.randn_like(buf))
+
+        # Run mapping network.
+        z = torch.randn([batch_size, G.z_dim], device=opts.device)
+        c = next(c_iter)
+        ws = G.mapping(z=z, c=c)
+
+        # Generate reference image.
+        M[:] = I
+        orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+
+        # Integer translation (EQ-T).
+        if compute_eqt_int:
+            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+            t = (t * G.img_resolution).round() / G.img_resolution
+            M[:] = I
+            M[:2, 2] = -t
+            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+            ref, mask = apply_integer_translation(orig, t[0], t[1])
+            s += [(ref - img).square() * mask, mask]
+
+        # Fractional translation (EQ-T_frac).
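+        # (same protocol as EQ-T above, except t is not rounded to whole pixels,
+        #  so the reference is built with apply_fractional_translation)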
+ if compute_eqt_frac: + t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max + M[:] = I + M[:2, 2] = -t + img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + ref, mask = apply_fractional_translation(orig, t[0], t[1]) + s += [(ref - img).square() * mask, mask] + + # Rotation (EQ-R). + if compute_eqr: + angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi) + M[:] = rotation_matrix(-angle) + img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs) + ref, ref_mask = apply_fractional_rotation(orig, angle) + pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle) + mask = ref_mask * pseudo_mask + s += [(ref - pseudo).square() * mask, mask] + + # Accumulate results. + s = torch.stack([x.to(torch.float64).sum() for x in s]) + sums = sums + s if sums is not None else s + progress.update(num_samples) + + # Compute PSNRs. + if opts.num_gpus > 1: + torch.distributed.all_reduce(sums) + sums = sums.cpu() + mses = sums[0::2] / sums[1::2] + psnrs = np.log10(2) * 20 - mses.log10() * 10 + psnrs = tuple(psnrs.numpy()) + return psnrs[0] if len(psnrs) == 1 else psnrs + +#---------------------------------------------------------------------------- diff --git a/metrics/frechet_inception_distance.py b/metrics/frechet_inception_distance.py new file mode 100644 index 0000000..f99c828 --- /dev/null +++ b/metrics/frechet_inception_distance.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Frechet Inception Distance (FID) from the paper +"GANs trained by a two time-scale update rule converge to a local Nash +equilibrium". Matches the original implementation by Heusel et al. at +https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" + +import numpy as np +import scipy.linalg +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_fid(opts, max_real, num_gen, swav=False, sfid=False): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' + detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
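+    # (FID below is the Frechet distance between Gaussians fitted to the two
+    #  feature sets: ||mu_g - mu_r||^2 + Tr(S_g + S_r - 2 * sqrtm(S_g @ S_r)))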
+ + mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, swav=swav, sfid=sfid).get_mean_cov() + + mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, swav=swav, sfid=sfid).get_mean_cov() + + if opts.rank != 0: + return float('nan') + + m = np.square(mu_gen - mu_real).sum() + s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member + fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) + return float(fid) + +#---------------------------------------------------------------------------- diff --git a/metrics/inception_score.py b/metrics/inception_score.py new file mode 100644 index 0000000..e0a3a44 --- /dev/null +++ b/metrics/inception_score.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Inception Score (IS) from the paper "Improved techniques for training +GANs". Matches the original implementation by Salimans et al. at +https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" + +import numpy as np +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_is(opts, num_gen, num_splits): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' + detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. + + gen_probs = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + capture_all=True, max_items=num_gen).get_all() + + if opts.rank != 0: + return float('nan'), float('nan') + + scores = [] + for i in range(num_splits): + part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] + kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) + kl = np.mean(np.sum(kl, axis=1)) + scores.append(np.exp(kl)) + return float(np.mean(scores)), float(np.std(scores)) + +#---------------------------------------------------------------------------- diff --git a/metrics/kernel_inception_distance.py b/metrics/kernel_inception_distance.py new file mode 100644 index 0000000..d69325c --- /dev/null +++ b/metrics/kernel_inception_distance.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +"""Kernel Inception Distance (KID) from the paper "Demystifying MMD +GANs". Matches the original implementation by Binkowski et al. at +https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" + +import numpy as np +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' + detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. + + real_features = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() + + gen_features = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() + + if opts.rank != 0: + return float('nan') + + n = real_features.shape[1] + m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) + t = 0 + for _subset_idx in range(num_subsets): + x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] + y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] + a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 + b = (x @ y.T / n + 1) ** 3 + t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m + kid = t / num_subsets / m + return float(kid) + +#---------------------------------------------------------------------------- diff --git a/metrics/metric_main.py b/metrics/metric_main.py new file mode 100644 index 0000000..27adc6e --- /dev/null +++ b/metrics/metric_main.py @@ -0,0 +1,151 @@ +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Main API for computing and reporting quality metrics.""" + +import os +import time +import json +import torch +import dnnlib + +from . import metric_utils +from . import frechet_inception_distance +from . import kernel_inception_distance +from . import precision_recall +from . import perceptual_path_length +from . import inception_score +from . import equivariance + +#---------------------------------------------------------------------------- + +_metric_dict = dict() # name => fn + +def register_metric(fn): + assert callable(fn) + _metric_dict[fn.__name__] = fn + return fn + +def is_valid_metric(metric): + return metric in _metric_dict + +def list_valid_metrics(): + return list(_metric_dict.keys()) + +#---------------------------------------------------------------------------- + +def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. + assert is_valid_metric(metric) + opts = metric_utils.MetricOptions(**kwargs) + + # Calculate. + start_time = time.time() + results = _metric_dict[metric](opts) + total_time = time.time() - start_time + + # Broadcast results. + for key, value in list(results.items()): + if opts.num_gpus > 1: + value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) + torch.distributed.broadcast(tensor=value, src=0) + value = float(value.cpu()) + results[key] = value + + # Decorate with metadata. 
+ return dnnlib.EasyDict( + results = dnnlib.EasyDict(results), + metric = metric, + total_time = total_time, + total_time_str = dnnlib.util.format_time(total_time), + num_gpus = opts.num_gpus, + ) + +#---------------------------------------------------------------------------- + +def report_metric(result_dict, run_dir=None, snapshot_pkl=None): + metric = result_dict['metric'] + assert is_valid_metric(metric) + if run_dir is not None and snapshot_pkl is not None: + snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) + + jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) + print(jsonl_line) + if run_dir is not None and os.path.isdir(run_dir): + with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: + f.write(jsonl_line + '\n') + +#---------------------------------------------------------------------------- +# Recommended metrics. + +@register_metric +def fid50k_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) + return dict(fid50k_full=fid) + +@register_metric +def fid10k_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000) + return dict(fid10k_full=fid) + +@register_metric +def kid50k_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) + return dict(kid50k_full=kid) + +@register_metric +def pr50k3_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) + return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) + +@register_metric +def ppl2_wend(opts): + ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) + return dict(ppl2_wend=ppl) + +@register_metric +def eqt50k_int(opts): + opts.G_kwargs.update(force_fp32=True) + psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) + return dict(eqt50k_int=psnr) + +@register_metric +def eqt50k_frac(opts): + opts.G_kwargs.update(force_fp32=True) + psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) + return dict(eqt50k_frac=psnr) + +@register_metric +def eqr50k(opts): + opts.G_kwargs.update(force_fp32=True) + psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) + return dict(eqr50k=psnr) + +# Legacy metrics. 
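Side note on the registry used throughout metric_main.py: `register_metric` stores each preset in `_metric_dict` under the function's `__name__`, so adding one is a single decorated function. A minimal sketch (`fid2k` is a hypothetical smoke-test preset, not part of this patch; the legacy presets resume below):

```python
# Hypothetical preset, assuming the names defined in metric_main.py are in scope.
@register_metric
def fid2k(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=2000, num_gen=2000)
    return dict(fid2k=fid)

assert is_valid_metric('fid2k')  # registered under the function's __name__
```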
+ +@register_metric +def fid50k(opts): + opts.dataset_kwargs.update(max_size=None) + fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) + return dict(fid50k=fid) + +@register_metric +def kid50k(opts): + opts.dataset_kwargs.update(max_size=None) + kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) + return dict(kid50k=kid) + +@register_metric +def pr50k3(opts): + opts.dataset_kwargs.update(max_size=None) + precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) + return dict(pr50k3_precision=precision, pr50k3_recall=recall) + +@register_metric +def is50k(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) + return dict(is50k_mean=mean, is50k_std=std) diff --git a/metrics/metric_utils.py b/metrics/metric_utils.py new file mode 100644 index 0000000..d7e3960 --- /dev/null +++ b/metrics/metric_utils.py @@ -0,0 +1,298 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Miscellaneous utilities used internally by the quality metrics.""" + +import os +import time +import hashlib +import pickle +import copy +import uuid +import numpy as np +import torch +import dnnlib +from tqdm import tqdm + +#---------------------------------------------------------------------------- + +class MetricOptions: + def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True, run_dir=None, cur_nimg=None, snapshot_pkl=None): + assert 0 <= rank < num_gpus + self.G = G + self.G_kwargs = dnnlib.EasyDict(G_kwargs) + self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs) + self.num_gpus = num_gpus + self.rank = rank + self.device = device if device is not None else torch.device('cuda', rank) + self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor() + self.cache = cache + self.run_dir = run_dir + self.cur_nimg = cur_nimg + self.snapshot_pkl = snapshot_pkl + +#---------------------------------------------------------------------------- + +_feature_detector_cache = dict() + +def get_feature_detector_name(url): + return os.path.splitext(url.split('/')[-1])[0] + +def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): + assert 0 <= rank < num_gpus + key = (url, device) + if key not in _feature_detector_cache: + is_leader = (rank == 0) + if not is_leader and num_gpus > 1: + torch.distributed.barrier() # leader goes first + with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f: + _feature_detector_cache[key] = pickle.load(f).to(device) + if is_leader and num_gpus > 1: + torch.distributed.barrier() # others follow + return _feature_detector_cache[key] + +#---------------------------------------------------------------------------- + +def iterate_random_labels(opts, batch_size): + if opts.G.c_dim == 0: + c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device) + while True: + yield c + else: + dataset = 
dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + while True: + c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)] + c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) + yield c + +#---------------------------------------------------------------------------- + +class FeatureStats: + def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None): + self.capture_all = capture_all + self.capture_mean_cov = capture_mean_cov + self.max_items = max_items + self.num_items = 0 + self.num_features = None + self.all_features = None + self.raw_mean = None + self.raw_cov = None + + def set_num_features(self, num_features): + if self.num_features is not None: + assert num_features == self.num_features + else: + self.num_features = num_features + self.all_features = [] + self.raw_mean = np.zeros([num_features], dtype=np.float64) + self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64) + + def is_full(self): + return (self.max_items is not None) and (self.num_items >= self.max_items) + + def append(self, x): + x = np.asarray(x, dtype=np.float32) + assert x.ndim == 2 + if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items): + if self.num_items >= self.max_items: + return + x = x[:self.max_items - self.num_items] + + self.set_num_features(x.shape[1]) + self.num_items += x.shape[0] + if self.capture_all: + self.all_features.append(x) + if self.capture_mean_cov: + x64 = x.astype(np.float64) + self.raw_mean += x64.sum(axis=0) + self.raw_cov += x64.T @ x64 + + def append_torch(self, x, num_gpus=1, rank=0): + assert isinstance(x, torch.Tensor) and x.ndim == 2 + assert 0 <= rank < num_gpus + if num_gpus > 1: + ys = [] + for src in range(num_gpus): + y = x.clone() + torch.distributed.broadcast(y, src=src) + ys.append(y) + x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples + self.append(x.cpu().numpy()) + + def get_all(self): + assert self.capture_all + return np.concatenate(self.all_features, axis=0) + + def get_all_torch(self): + return torch.from_numpy(self.get_all()) + + def get_mean_cov(self): + assert self.capture_mean_cov + mean = self.raw_mean / self.num_items + cov = self.raw_cov / self.num_items + cov = cov - np.outer(mean, mean) + return mean, cov + + def save(self, pkl_file): + with open(pkl_file, 'wb') as f: + pickle.dump(self.__dict__, f) + + @staticmethod + def load(pkl_file): + with open(pkl_file, 'rb') as f: + s = dnnlib.EasyDict(pickle.load(f)) + obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items) + obj.__dict__.update(s) + return obj + +#---------------------------------------------------------------------------- + +class ProgressMonitor: + def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000): + self.tag = tag + self.num_items = num_items + self.verbose = verbose + self.flush_interval = flush_interval + self.progress_fn = progress_fn + self.pfn_lo = pfn_lo + self.pfn_hi = pfn_hi + self.pfn_total = pfn_total + self.start_time = time.time() + self.batch_time = self.start_time + self.batch_items = 0 + if self.progress_fn is not None: + self.progress_fn(self.pfn_lo, self.pfn_total) + + def update(self, cur_items): + assert (self.num_items is None) or (cur_items <= self.num_items) + if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items): + return + cur_time = time.time() + total_time = cur_time - self.start_time + 
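+        # (per-item time is measured over the window since the last flush, not
+        #  since the start: delta time / delta items, with the item delta clamped to >= 1)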
time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1) + if (self.verbose) and (self.tag is not None): + print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}') + self.batch_time = cur_time + self.batch_items = cur_items + + if (self.progress_fn is not None) and (self.num_items is not None): + self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total) + + def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1): + return ProgressMonitor( + tag = tag, + num_items = num_items, + flush_interval = flush_interval, + verbose = self.verbose, + progress_fn = self.progress_fn, + pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo, + pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi, + pfn_total = self.pfn_total, + ) + +#---------------------------------------------------------------------------- + +def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, swav=False, sfid=False, **stats_kwargs): + dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + if data_loader_kwargs is None: + data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) + + # Try to lookup from cache. + cache_file = None + if opts.cache: + det_name = get_feature_detector_name(detector_url) + + # Choose cache file name. + args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs) + md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8')) + cache_tag = f'{dataset.name}-{det_name}-{md5.hexdigest()}' + cache_file = os.path.join('.', 'dnnlib', 'gan-metrics', cache_tag + '.pkl') + # cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl') + + # Check if the file exists (all processes must agree). + flag = os.path.isfile(cache_file) if opts.rank == 0 else False + if opts.num_gpus > 1: + flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device) + torch.distributed.broadcast(tensor=flag, src=0) + flag = (float(flag.cpu()) != 0) + + # Load. + if flag: + return FeatureStats.load(cache_file) + + print('Calculating the stats for this dataset the first time\n') + print(f'Saving them to {cache_file}') + + # Initialize. + num_items = len(dataset) + if max_items is not None: + num_items = min(num_items, max_items) + stats = FeatureStats(max_items=num_items, **stats_kwargs) + progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi) + + # get detector + detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) + + # Main loop. + item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] + for images, _labels in tqdm(torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs)): + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + + with torch.no_grad(): + features = detector(images.to(opts.device), **detector_kwargs) + + stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) + progress.update(stats.num_items) + + # Save to cache. + if cache_file is not None and opts.rank == 0: + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + temp_file = cache_file + '.' 
+ uuid.uuid4().hex + stats.save(temp_file) + os.replace(temp_file, cache_file) # atomic + return stats + +#---------------------------------------------------------------------------- + +def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, swav=False, sfid=False, **stats_kwargs): + if batch_gen is None: + batch_gen = min(batch_size, 4) + assert batch_size % batch_gen == 0 + + # Setup generator and labels. + G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) + c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen) + + # Initialize. + stats = FeatureStats(**stats_kwargs) + assert stats.max_items is not None + progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) + + # get detector + detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) + + # Main loop. + while not stats.is_full(): + images = [] + for _i in range(batch_size // batch_gen): + z = torch.randn([batch_gen, G.z_dim], device=opts.device) + # img = G(z=z, c=next(c_iter), truncation_psi=0.1, **opts.G_kwargs) + img = G(z=z, c=next(c_iter), **opts.G_kwargs) + img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) + images.append(img) + images = torch.cat(images) + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + + with torch.no_grad(): + features = detector(images.to(opts.device), **detector_kwargs) + + stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) + progress.update(stats.num_items) + return stats diff --git a/metrics/perceptual_path_length.py b/metrics/perceptual_path_length.py new file mode 100644 index 0000000..c68519f --- /dev/null +++ b/metrics/perceptual_path_length.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator +Architecture for Generative Adversarial Networks". Matches the original +implementation by Karras et al. at +https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" + +import copy +import numpy as np +import torch +from . import metric_utils + +#---------------------------------------------------------------------------- + +# Spherical interpolation of a batch of vectors. 
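+# (With unit vectors a, b and theta = arccos(<a, b>), the code below returns
+#  a * cos(t * theta) + c * sin(t * theta), where c is the unit component of b
+#  orthogonal to a; t = 0 gives a and t = 1 gives b.)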
+def slerp(a, b, t): + a = a / a.norm(dim=-1, keepdim=True) + b = b / b.norm(dim=-1, keepdim=True) + d = (a * b).sum(dim=-1, keepdim=True) + p = t * torch.acos(d) + c = b - d * a + c = c / c.norm(dim=-1, keepdim=True) + d = a * torch.cos(p) + c * torch.sin(p) + d = d / d.norm(dim=-1, keepdim=True) + return d + +#---------------------------------------------------------------------------- + +class PPLSampler(torch.nn.Module): + def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): + assert space in ['z', 'w'] + assert sampling in ['full', 'end'] + super().__init__() + self.G = copy.deepcopy(G) + self.G_kwargs = G_kwargs + self.epsilon = epsilon + self.space = space + self.sampling = sampling + self.crop = crop + self.vgg16 = copy.deepcopy(vgg16) + + def forward(self, c): + # Generate random latents and interpolation t-values. + t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) + z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) + + # Interpolate in W or Z. + if self.space == 'w': + w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) + wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) + wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) + else: # space == 'z' + zt0 = slerp(z0, z1, t.unsqueeze(1)) + zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) + wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) + + # Randomize noise buffers. + for name, buf in self.G.named_buffers(): + if name.endswith('.noise_const'): + buf.copy_(torch.randn_like(buf)) + + # Generate images. + img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) + + # Center crop. + if self.crop: + assert img.shape[2] == img.shape[3] + c = img.shape[2] // 8 + img = img[:, :, c*3 : c*7, c*2 : c*6] + + # Downsample to 256x256. + factor = self.G.img_resolution // 256 + if factor > 1: + img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) + + # Scale dynamic range from [-1,1] to [0,255]. + img = (img + 1) * (255 / 2) + if self.G.img_channels == 1: + img = img.repeat([1, 3, 1, 1]) + + # Evaluate differential LPIPS. + lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) + dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 + return dist + +#---------------------------------------------------------------------------- + +def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): + vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' + vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) + + # Setup sampler and labels. + sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) + sampler.eval().requires_grad_(False).to(opts.device) + c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) + + # Sampling loop. + dist = [] + progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) + for batch_start in range(0, num_samples, batch_size * opts.num_gpus): + progress.update(batch_start) + x = sampler(next(c_iter)) + for src in range(opts.num_gpus): + y = x.clone() + if opts.num_gpus > 1: + torch.distributed.broadcast(y, src=src) + dist.append(y) + progress.update(num_samples) + + # Compute PPL. 
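+    # (trim the top and bottom 1% of squared LPIPS distances before averaging,
+    #  matching the filtering step of the original StyleGAN implementation)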
+ if opts.rank != 0: + return float('nan') + dist = torch.cat(dist)[:num_samples].cpu().numpy() + lo = np.percentile(dist, 1, interpolation='lower') + hi = np.percentile(dist, 99, interpolation='higher') + ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() + return float(ppl) + +#---------------------------------------------------------------------------- diff --git a/metrics/precision_recall.py b/metrics/precision_recall.py new file mode 100644 index 0000000..120ef80 --- /dev/null +++ b/metrics/precision_recall.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Precision/Recall (PR) from the paper "Improved Precision and Recall +Metric for Assessing Generative Models". Matches the original implementation +by Kynkaanniemi et al. at +https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" + +import torch +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): + assert 0 <= rank < num_gpus + num_cols = col_features.shape[0] + num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus + col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) + dist_batches = [] + for col_batch in col_batches[rank :: num_gpus]: + dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] + for src in range(num_gpus): + dist_broadcast = dist_batch.clone() + if num_gpus > 1: + torch.distributed.broadcast(dist_broadcast, src=src) + dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) + return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None + +#---------------------------------------------------------------------------- + +def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): + detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' + detector_kwargs = dict(return_features=True) + + real_features = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) + + gen_features = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) + + results = dict() + for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: + kth = [] + for manifold_batch in manifold.split(row_batch_size): + dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) + kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) + kth = torch.cat(kth) if opts.rank == 0 else None + pred = [] + 
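+        # (a probe counts as covered when it falls within the k-th neighbour
+        #  radius `kth` of any manifold point: real manifold + fake probes gives
+        #  precision, fake manifold + real probes gives recall)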
for probes_batch in probes.split(row_batch_size): + dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) + pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) + results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') + return results['precision'], results['recall'] + +#---------------------------------------------------------------------------- diff --git a/speed_test.py b/speed_test.py index 0489784..bea143a 100644 --- a/speed_test.py +++ b/speed_test.py @@ -5,7 +5,7 @@ # Created Date: Thursday February 10th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 3rd March 2022 6:44:57 pm +# Last Modified: Sunday, 3rd April 2022 6:39:26 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -23,7 +23,7 @@ if __name__ == '__main__': # script = "Generator_modulation_up" # script = "Generator_modulation_up" # script = "Generator_Invobn_config3" - script = "Generator_ori_config" + script = "Generator_maskhead_config" # script = "Generator_ori_config" class_name = "Generator" arcface_ckpt= "arcface_ckpt/arcface_checkpoint.tar" @@ -31,11 +31,12 @@ if __name__ == '__main__': "id_dim": 512, "g_kernel_size": 3, "in_channel":64, - "res_num": 9, + "res_num": 0, # "up_mode": "nearest", "up_mode": "bilinear", "aggregator": "eca_invo", - "res_mode": "eca_invo" + "res_mode": "eca_invo", + "norm": "bn" } os.environ['CUDA_VISIBLE_DEVICES'] = str(0) print("GPU used : ", os.environ['CUDA_VISIBLE_DEVICES']) @@ -57,7 +58,7 @@ if __name__ == '__main__': id_latent = torch.rand((4,512)).cuda() # cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0] - attr = torch.rand((4,3,224,224)).cuda() + attr = torch.rand((4,3,512,512)).cuda() import datetime start_time = time.time() diff --git a/start_train.sh b/start_train.sh index 147a2df..852f856 100644 --- a/start_train.sh +++ b/start_train.sh @@ -1,3 +1 @@ - - -nohup python train_multigpu.py > cycle_res2.log 2>&1 & \ No newline at end of file +nohup python train_multigpu.py > 2maskloss2_1.log 2>&1 & \ No newline at end of file diff --git a/test.py b/test.py index 42cb60e..777ec55 100644 --- a/test.py +++ b/test.py @@ -5,7 +5,7 @@ # Created Date: Saturday July 3rd 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Friday, 25th March 2022 6:13:32 pm +# Last Modified: Saturday, 23rd April 2022 10:03:56 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -32,33 +32,39 @@ def getParameters(): parser = argparse.ArgumentParser() # general settings - parser.add_argument('-v', '--version', type=str, default='cycle_res1', #cycle_res1 cycle_res2 cycle_res3 cycle_lstu1 depthwise depthwise_config0 Invobn_resinvo1 + parser.add_argument('-v', '--version', type=str, default='maskhead_recfm_2', # maskhead_recfm_2 maskloss_2 resskip_recfm_1 maskhead_recfm_1 maskhead_recfm_2 resskip_2 resskip_3 resskip_4 resskip_9 cycle_res1 cycle_res2 cycle_res3 cycle_lstu1 depthwise depthwise_config0 Invobn_resinvo1 help="version name for train, test, finetune") parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU - parser.add_argument('-s', '--checkpoint_step', type=int, default=180000, + parser.add_argument('-s', 
'--checkpoint_step', type=int, default=480000,
                        help="checkpoint epoch for test phase or finetune phase")
     parser.add_argument('--start_checkpoint_step', type=int, default=10000,
                         help="checkpoint epoch for test phase or finetune phase")
 
     # test
-    parser.add_argument('-t', '--test_script_name', type=str, default='image_list') #image_list image_nofusion
+    parser.add_argument('-t', '--test_script_name', type=str, default='tester_video') # video image_w_mask image_list_w_mask image_list image_nofusion
     parser.add_argument('-b', '--batch_size', type=int, default=1)
-    parser.add_argument('-n', '--node_ip', type=str, default='101.33.242.26') # 101.33.242.26 2001:da8:8000:6880:f284:d61c:3c76:f9cb
+    parser.add_argument('-n', '--node_ip', type=str, default='localhost') # localhost 119.29.91.52 101.33.242.26 2001:da8:8000:6880:f284:d61c:3c76:f9cb
     parser.add_argument('--crop_mode', type=str, default="vggface", choices=['ffhq','vggface'], help='crop mode for face detector')
 
-    parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg') # 'G:\\swap_data\\FF++\\996_img_00288.jpg' G:\\swap_data\\ID\\hinton.jpg
+    parser.add_argument('-i', '--id_imgs', type=str, default='G:/simswap/inputdata/2/2/10.jpg') # G:/simswap/inputdata/2/2/10.jpg G:\\swap_data\\ID\\dlrb2.jpeg 'G:\\swap_data\\FF++\\996_img_00288.jpg' G:\\swap_data\\ID\\hinton.jpg
     # parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg')
-    parser.add_argument('-a', '--attr_files', type=str, default='G:/swap_data/video/1', # G:\\swap_data\\ID\\bengio.jpg G:\\swap_data\\FF++\\056_img_00228.jpg
-                        help="file path for attribute images or video")
+    parser.add_argument('-a', '--attr_files', type=str, default='G:/simswap/inputdata/3/100297.mp4', # G:/swap_data/video/1 G:\\swap_data\\ID\\bengio.jpg G:\\swap_data\\FF++\\056_img_00228.jpg
+                        help="file path for attribute images or video") # G:/swap_data/video/2/G2218_Trim.mp4
     parser.add_argument('--img_list_txt', type=str, default='./test_imgs_list.txt', # G:\\swap_data\\ID\\bengio.jpg G:\\swap_data\\FF++\\056_img_00228.jpg
                         help="file path for image list txt")
-
+    parser.add_argument('--record_metric', type=str2bool, default='False',
+                        help="Whether to record the cosine similarity")
+    parser.add_argument('--save_mask', type=str2bool, default='False',
+                        help="Whether to save the mask")
+
+    parser.add_argument('--preprocess', type=str2bool, default='False', help='Whether to employ preprocessing')
+
     parser.add_argument('--use_specified_data', action='store_true')
     parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
-    parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir')
-    parser.add_argument('--specified_save_path', type=str, default="", help='save results to specified dir')
+    parser.add_argument('--use_specified_data_paths', type=str2bool, default='False', help='use the specified data paths')
+    parser.add_argument('--specified_save_path', type=str, default="G:/swap_data/video/results3", help='save results to specified dir')
 
     # # logs (does not to be changed in most time)
     # parser.add_argument('--dataloader_workers', type=int, default=6)
@@ -266,7 +272,7 @@ def main():
 
     # Display the test information
     # TODO modify below lines to display your configuration information
-    moduleName = "test_scripts.tester_" + sys_state["test_script_name"]
+    moduleName = "test_scripts."
+ sys_state["test_script_name"] print("Start to run test script: {}".format(moduleName)) print("Test version: %s"%sys_state["version"]) print("Test Script Name: %s"%sys_state["test_script_name"]) diff --git a/test_imgs_list.txt b/test_imgs_list.txt index 1236966..cc7c842 100644 --- a/test_imgs_list.txt +++ b/test_imgs_list.txt @@ -1,17 +1,47 @@ -G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/cruise.jpg;fusion +G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/tom-cruise-wallpaper-hd-wallpaper-43864908.jpg;fusion G:/swap_data/ID/06.jpg;G:/swap_data/ID/hm.jpg;fusion -G:/swap_data/ID/1.png;G:/swap_data/ID/2.png;fusion +G:/swap_data/ID/1.png;G:/swap_data/ID/2.jpg;fusion G:/swap_data/ID/hinton.jpg;G:/swap_data/ID/bengio.jpg;fusion G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/ts1.jpg;fusion G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/lyf2.jpeg;fusion G:/swap_data/FF++/996_img_00288.jpg;G:/swap_data/FF++/056_img_00228.jpg;no G:/swap_data/ID/gxt3.jpeg;G:/swap_data/ID/lyf5.jpeg;fusion G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/hb.jpeg;fusion -G:/swap_data/ID/06.jpg;G:/swap_data/ID/cruise2.jpeg;fusion +G:/swap_data/ID/06.jpg;G:/swap_data/ID/2130429-1216_tom_cruise_genes.jpg;fusion G:/swap_data/FF++/019_img_00139.jpg;G:/swap_data/FF++/018_img_00088.jpg;no G:/swap_data/FF++/052_img_00033.jpg;G:/swap_data/FF++/108_img_00150.jpg;no G:/swap_data/FF++/011_img_00448.jpg;G:/swap_data/FF++/805_img_00252.jpg;no G:/swap_data/FF++/638_img_00000.jpg;G:/swap_data/FF++/640_img_00248.jpg;no G:/swap_data/FF++/819_img_00651.jpg;G:/swap_data/FF++/786_img_00156.jpg;no G:/swap_data/FF++/416_img_00032.jpg;G:/swap_data/FF++/342_img_00062.jpg;no -G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/lyf4.jpeg;fusion \ No newline at end of file +G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/lyf4.jpeg;fusion +G:/swap_data/ID/lxq.jpeg;G:/swap_data/ID/zyq2.jpeg;fusion +G:/swap_data/ID/dlrb2.jpeg;G:/swap_data/ID/zyq.jpeg;fusion +G:/swap_data/ID/jl.jpg;G:/swap_data/ID/fbb2.jpg;fusion +G:/swap_data/ID/bengio.jpg;G:/swap_data/ID/messi.jpg;fusion +G:/swap_data/ID/elon-musk-hero-image.jpeg;G:/swap_data/ID/pexels-ichad-windhiagiri-3989151.jpg;fusion +G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/ts.jpeg;fusion +G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/zzq.jpeg;fusion +G:/swap_data/ID/06.jpg;G:/swap_data/ID/audrey-hepburn-63115_960_720.jpg;fusion +G:/swap_data/ID/ScarlettJohansson1.jpg;G:/swap_data/ID/chris-evans-captain-america.jpg;fusion +G:/swap_data/ID/sjch.jpeg;G:/swap_data/ID/zym9.jpeg;fusion +G:/swap_data/ID/ScarlettJohansson1.jpg;G:/swap_data/ID/captainamerica.jpg;fusion +G:/swap_data/ID/lyf2.jpeg;G:/swap_data/ID/Lane-Ten-Things-about-Wonder-Woman.jpg;fusion +G:/swap_data/ID/dlrb2.jpeg;G:/swap_data/ID/lyf2.jpeg;fusion +G:/swap_data/ID/hsw.jpg;G:/swap_data/ID/hb.jpeg;fusion +G:/swap_data/ID/RobertDowneyJr2.jpg;G:/swap_data/ID/GettyImages.png;fusion +G:/swap_data/ID/RobertDowneyJr2.jpg;G:/swap_data/ID/leonardo.jpg;fusion +G:/swap_data/ID/06.jpg;G:/swap_data/ID/RobertDowneyJr2.jpg;fusion +G:/swap_data/ID/lyf2.jpeg;G:/swap_data/video/5/FMQpZa5aIAAGsvb.jpg;fusion +G:/swap_data/ID/dlrb2.jpeg;G:/swap_data/video/5/FMQpZa5aIAAGsvb.jpg;fusion +G:/swap_data/ID/lyf2.jpeg;G:/swap_data/video/5/FIpgTdIaIAAsA6f.jpg;fusion +G:/swap_data/ID/dlrb2.jpeg;G:/swap_data/video/5/FIpgTdIaIAAsA6f.jpg;fusion +G:/swap_data/ID/dlrb2.jpeg;G:/swap_data/video/5/FKilL2takAAESu4.jpg;fusion +G:/swap_data/ID/lyf2.jpeg;G:/swap_data/video/5/FKilL2takAAESu4.jpg;fusion +G:/swap_data/ID/dlrb2.jpeg;G:/swap_data/video/5/FKilL2xaIAAup9D.jpg;fusion +G:/swap_data/ID/lyf2.jpeg;G:/swap_data/video/5/FKilL2xaIAAup9D.jpg;fusion 
+G:/swap_data/ID/dlrb2.jpeg;G:/swap_data/video/5/FKilL2xaQAA9zzV.jpg;fusion +G:/swap_data/ID/lyf2.jpeg;G:/swap_data/video/5/FKilL2xaQAA9zzV.jpg;fusion +G:/swap_data/ID/dlrb2.jpeg;G:/swap_data/video/5/FLy-vUWaMAMsQVf.jpg;fusion +G:/swap_data/ID/lyf2.jpeg;G:/swap_data/video/5/FLy-vUWaMAMsQVf.jpg;fusion +G:/swap_data/ID/dlrb2.jpeg;G:/swap_data/video/5/waaa00103jp-1.jpg;fusion \ No newline at end of file diff --git a/test_json.py b/test_json.py new file mode 100644 index 0000000..42214d8 --- /dev/null +++ b/test_json.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: test.py +# Created Date: Saturday July 3rd 2021 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 6th April 2022 4:09:28 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2021 Shanghai Jiao Tong University +############################################################# + + +import os +import argparse +from torch.backends import cudnn +from utilities.json_config import readConfig +from utilities.reporter import Reporter +from utilities.sshupload import fileUploaderClass +import warnings + +warnings.filterwarnings('ignore') + +def str2bool(v): + return v.lower() in ('true') + +#################################################################################### +# To configure the seting of training\finetune\test +# +#################################################################################### +def getParameters(): + + parser = argparse.ArgumentParser() + # general settings + parser.add_argument('-v', '--version', type=str, default='maskhead_recfm_2', # maskhead_recfm_2 maskloss_2 resskip_recfm_1 maskhead_recfm_1 maskhead_recfm_2 resskip_2 resskip_3 resskip_4 resskip_9 cycle_res1 cycle_res2 cycle_res3 cycle_lstu1 depthwise depthwise_config0 Invobn_resinvo1 + help="version name for train, test, finetune") + + parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU + parser.add_argument('-s', '--checkpoint_step', type=int, default=480000, + help="checkpoint epoch for test phase or finetune phase") + parser.add_argument('--start_checkpoint_step', type=int, default=10000, + help="checkpoint epoch for test phase or finetune phase") + + # test + parser.add_argument('-t', '--test_script_name', type=str, default='video') # video image_w_mask image_list_w_mask image_list image_nofusion + parser.add_argument('-b', '--batch_size', type=int, default=1) + parser.add_argument('-n', '--node_ip', type=str, default='2001:da8:8000:6880:f284:d61c:3c76:f9cb') # localhost 119.29.91.52 101.33.242.26 2001:da8:8000:6880:f284:d61c:3c76:f9cb + parser.add_argument('--crop_mode', type=str, default="vggface", choices=['ffhq','vggface'], help='crop mode for face detector') + + + parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg') # G:\\swap_data\\ID\\dlrb2.jpeg 'G:\\swap_data\\FF++\\996_img_00288.jpg' G:\\swap_data\\ID\\hinton.jpg + # parser.add_argument('-i', '--id_imgs', type=str, default='G:\\VGGFace2-HQ\\VGGface2_ffhq_align_256_9_28_512_bygfpgan\\n000002\\0027_01.jpg') + parser.add_argument('-a', '--attr_files', type=str, default='G:/swap_data/video/2/G2218_Trim.mp4', # G:/swap_data/video/1 G:\\swap_data\\ID\\bengio.jpg G:\\swap_data\\FF++\\056_img_00228.jpg + help="file path for attribute images or video") # G:/swap_data/video/2/G2218_Trim.mp4 + parser.add_argument('--img_list_txt', type=str, default='./test_imgs_list.txt', # 
G:\\swap_data\\ID\\bengio.jpg G:\\swap_data\\FF++\\056_img_00228.jpg
+                        help="file path of the image list txt")
+    parser.add_argument('--record_metric', type=str2bool, default='False', choices=['True', 'False'],
+                        help="whether to record the cosine similarity")
+
+    parser.add_argument('--use_specified_data', action='store_true')
+    parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
+    parser.add_argument('--use_specified_data_paths', type=str2bool, default='True', choices=['True', 'False'], help='use the specified save dir')
+    parser.add_argument('--specified_save_path', type=str, default="G:/swap_data/video/results3", help='save results to the specified dir')
+
+    # # logs (does not need to be changed most of the time)
+    # parser.add_argument('--dataloader_workers', type=int, default=6)
+    # parser.add_argument('--use_tensorboard', type=str2bool, default='True',
+    #                     choices=['True', 'False'], help='enable the tensorboard')
+    # parser.add_argument('--log_step', type=int, default=100)
+    # parser.add_argument('--sample_step', type=int, default=100)
+
+    # # template (once editing is finished, it should be deleted)
+    # parser.add_argument('--str_parameter', type=str, default="default", help='str parameter')
+    # parser.add_argument('--str_parameter_choices', type=str,
+    #                     default="default", choices=['choice1', 'choice2','choice3'], help='str parameter with choices list')
+    # parser.add_argument('--int_parameter', type=int, default=0, help='int parameter')
+    # parser.add_argument('--float_parameter', type=float, default=0.0, help='float parameter')
+    # parser.add_argument('--bool_parameter', type=str2bool, default='True', choices=['True', 'False'], help='bool parameter')
+    # parser.add_argument('--list_str_parameter', type=str, nargs='+', default=["element1","element2"], help='str list parameter')
+    # parser.add_argument('--list_int_parameter', type=int, nargs='+', default=[0,1], help='int list parameter')
+    return parser.parse_args()
+
+ignoreKey = [
+        "dataloader_workers",
+        "log_root_path",
+        "project_root",
+        "project_summary",
+        "project_checkpoints",
+        "project_samples",
+        "project_scripts",
+        "reporter_path",
+        "use_specified_data",
+        "specified_data_paths",
+        "dataset_path","cuda",
+        "test_script_name",
+        "test_dataloader",
+        "test_dataset_path",
+        "save_test_result",
+        "test_batch_size",
+        "node_name",
+        "checkpoint_epoch",
+        "test_dataset_name",
+        "use_my_test_date"]
+
+####################################################################################
+# This function will create the related directories before the
+# training\finetune\test starts
+# Your_log_root (version name)
+#   |---summary/...
+#   |---samples/... (save evaluated images)
+#   |---checkpoints/...
+#   |---scripts/...
+# +#################################################################################### +def createDirs(sys_state): + # the base dir + if not os.path.exists(sys_state["log_root_path"]): + os.makedirs(sys_state["log_root_path"]) + + # create dirs + sys_state["project_root"] = os.path.join(sys_state["log_root_path"], + sys_state["version"]) + + project_root = sys_state["project_root"] + if not os.path.exists(project_root): + os.makedirs(project_root) + + sys_state["project_summary"] = os.path.join(project_root, "summary") + if not os.path.exists(sys_state["project_summary"]): + os.makedirs(sys_state["project_summary"]) + + sys_state["project_checkpoints"] = os.path.join(project_root, "checkpoints") + if not os.path.exists(sys_state["project_checkpoints"]): + os.makedirs(sys_state["project_checkpoints"]) + + sys_state["project_samples"] = os.path.join(project_root, "samples") + if not os.path.exists(sys_state["project_samples"]): + os.makedirs(sys_state["project_samples"]) + + sys_state["project_scripts"] = os.path.join(project_root, "scripts") + if not os.path.exists(sys_state["project_scripts"]): + os.makedirs(sys_state["project_scripts"]) + + sys_state["reporter_path"] = os.path.join(project_root,sys_state["version"]+"_report") + +def main(): + + config = getParameters() + # speed up the program + cudnn.benchmark = True + + sys_state = {} + + # set the GPU number + if config.cuda >= 0: + os.environ["CUDA_VISIBLE_DEVICES"] = str(config.cuda) + + # read system environment paths + env_config = readConfig('env/env.json') + env_config = env_config["path"] + sys_state["env_config"] = env_config + + # obtain all configurations in argparse + config_dic = vars(config) + for config_key in config_dic.keys(): + sys_state[config_key] = config_dic[config_key] + + #=======================Test Phase=========================# + + # TODO modify below lines to obtain the configuration + sys_state["log_root_path"] = env_config["train_log_root"] + + sys_state["test_samples_path"] = os.path.join(env_config["test_log_root"], + sys_state["version"] , "samples") + # if not config.use_my_test_date: + # print("Use public benchmark...") + # data_key = config.test_dataset_name.lower() + # sys_state["test_dataset_path"] = env_config["test_dataset_paths"][data_key] + # if config.test_dataset_name.lower() == "set5" or config.test_dataset_name.lower() =="set14": + # sys_state["test_dataloader"] = "setx" + # else: + # sys_state["test_dataloader"] = config.test_dataset_name.lower() + + # sys_state["test_dataset_name"] = config.test_dataset_name + + if not os.path.exists(sys_state["test_samples_path"]): + os.makedirs(sys_state["test_samples_path"]) + + # Create dirs + createDirs(sys_state) + config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"]) + + #fetch checkpoints, model_config.json and scripts from remote machine + if sys_state["node_ip"]!="localhost": + machine_config = env_config["machine_config"] + machine_config = readConfig(machine_config) + nodeinf = None + for item in machine_config: + if item["ip"] == sys_state["node_ip"]: + nodeinf = item + break + if not nodeinf: + raise Exception(print("Configuration of node %s is unavaliable"%sys_state["node_ip"])) + sys_state["remote_machine"] = nodeinf + print("ready to fetch related files from server: %s ......"%nodeinf["ip"]) + uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"]) + + remotebase = os.path.join(nodeinf['path'],"train_logs",sys_state["version"]).replace('\\','/') + + # Get the config.json + 
print("ready to get the config.json...") + remoteFile = os.path.join(remotebase, env_config["config_json_name"]).replace('\\','/') + localFile = config_json + + ssh_state = uploader.sshScpGet(remoteFile, localFile) + if not ssh_state: + raise Exception(print("Get file %s failed! config.json does not exist!"%remoteFile)) + print("success get the config.json from server %s"%nodeinf['ip']) + + # Get scripts + remoteDir = os.path.join(remotebase, "scripts").replace('\\','/') + localDir = os.path.join(sys_state["project_scripts"]) + ssh_state = uploader.sshScpGetDir(remoteDir, localDir) + if not ssh_state: + raise Exception(print("Get file %s failed! Program exists!"%remoteFile)) + print("Get the scripts successful!") + # Read model_config.json + json_obj = readConfig(config_json) + for item in json_obj.items(): + if item[0] in ignoreKey: + pass + else: + sys_state[item[0]] = item[1] + + # Get checkpoints + if sys_state["node_ip"]!="localhost": + + ckpt_name = "step%d_%s.pth"%(sys_state["checkpoint_step"], + sys_state["checkpoint_names"]["generator_name"]) + localFile = os.path.join(sys_state["project_checkpoints"],ckpt_name) + if not os.path.exists(localFile): + + remoteFile = os.path.join(remotebase, "checkpoints", ckpt_name).replace('\\','/') + ssh_state = uploader.sshScpGet(remoteFile, localFile, True) + if not ssh_state: + raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile)) + print("Get the checkpoint %s successfully!"%(ckpt_name)) + else: + print("%s exists!"%(ckpt_name)) + + + # TODO get the checkpoint file path + sys_state["ckp_name"] = {} + # for data_key in sys_state["checkpoint_names"].keys(): + # sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"], + # "%d_%s.pth"%(sys_state["checkpoint_epoch"], + # sys_state["checkpoint_names"][data_key])) + + # Get the test configurations + sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"] + + # make a reporter + report_path = os.path.join(env_config["test_log_root"], sys_state["version"], + sys_state["version"]+"_report") + reporter = Reporter(report_path) + reporter.writeConfig(sys_state) + + # Display the test information + # TODO modify below lines to display your configuration information + moduleName = "test_scripts.tester_" + sys_state["test_script_name"] + print("Start to run test script: {}".format(moduleName)) + print("Test version: %s"%sys_state["version"]) + print("Test Script Name: %s"%sys_state["test_script_name"]) + + package = __import__(moduleName, fromlist=True) + testerClass = getattr(package, 'Tester') + tester = testerClass(sys_state,reporter) + tester.test() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/test_scripts/tester_image_list_w_2mask.py b/test_scripts/tester_image_list_w_2mask.py new file mode 100644 index 0000000..bb6c658 --- /dev/null +++ b/test_scripts/tester_image_list_w_2mask.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: tester_commonn.py +# Created Date: Saturday July 3rd 2021 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Tuesday, 12th April 2022 9:04:01 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2021 Shanghai Jiao Tong University +############################################################# + + + +import os +import cv2 +import time +import glob + +import torch +import torch.nn.functional as F +from torchvision import transforms + +import numpy as np +from 
+
+from insightface_func.face_detect_crop_single import Face_detect_crop
+
+class Tester(object):
+    def __init__(self, config, reporter):
+
+        self.config = config
+        # logger
+        self.reporter = reporter
+
+        self.transformer_Arcface = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ])
+        self.imagenet_std  = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
+        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
+
+
+    def __init_framework__(self):
+        '''
+        This function builds the model framework and
+        writes the framework information into the log file.
+        '''
+        #===============build models================#
+        print("build models...")
+        # TODO [import models here]
+        model_config = self.config["model_configs"]
+        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
+        class_name   = model_config["g_model"]["class_name"]
+        package      = __import__(gscript_name, fromlist=True)
+        gen_class    = getattr(package, class_name)
+
+        # TODO replace below lines to define the model framework
+        self.network = gen_class(**model_config["g_model"]["module_params"])
+        self.network = self.network.eval()
+        # for name in self.network.state_dict():
+        #     print(name)
+        self.features = {}
+        mapping_layers = [
+            "first_layer",
+            "down4",
+            "BottleNeck.2"
+        ]
+
+        # print and record model structure
+        self.reporter.writeInfo("Model structure:")
+        self.reporter.writeModel(self.network.__str__())
+
+        arcface1     = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
+        self.arcface = arcface1['model'].module
+        self.arcface.eval()
+        self.arcface.requires_grad_(False)
+
+        model_path = os.path.join(self.config["project_checkpoints"],
+                                    "step%d_%s.pth"%(self.config["checkpoint_step"],
+                                    self.config["checkpoint_names"]["generator_name"]))
+        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
+
+        # move to GPU
+        if self.config["cuda"] >= 0:
+            self.network = self.network.cuda()
+            self.arcface = self.arcface.cuda()
+
+
+
+    def test(self):
+
+        save_dir      = self.config["test_samples_path"]
+        ckp_step      = self.config["checkpoint_step"]
+        version       = self.config["version"]
+        crop_mode     = self.config["crop_mode"]
+        list_txt      = self.config["img_list_txt"]
+        record_metric = self.config["record_metric"]
+        specified_save_path = self.config["specified_save_path"]
+        self.arcface_ckpt   = self.config["arcface_ckpt"]
+        imgs_list = []
+
+        self.reporter.writeInfo("Version %s"%version)
+
+        if os.path.isdir(specified_save_path):
+            print("Use the specified save path: %s"%specified_save_path)
+            save_dir = specified_save_path
+        with open(list_txt,'r') as logf:
+            for line in logf:
+                cells = line.split(";")
+                imgs_list.append([cells[0],cells[1],cells[2].replace("\n","")])
+
+        # models
+        self.__init_framework__()
+
+        mode = crop_mode.lower()
+        if mode == "vggface":
+            mode = "none"
+        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
+        self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640), mode = mode)
+
+        cos_loss = torch.nn.CosineSimilarity()
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        # Start time
+        import datetime
+        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
+        print('Start =================================== test...')
+        start_time = time.time()
+        self.network.eval()
+        cos_dict = {}
+        average_cos = 0
+        with torch.no_grad():
+            for img in imgs_list:
+                id_img_n, attr_img_n, fusion = img
+                print("id image: %s --- attr image: %s --- mode: %s"%(id_img_n, attr_img_n, fusion))
+                id_img = cv2.imread(id_img_n)
+                if fusion.lower() == "fusion":
+                    try:
+                        id_img_align_crop, _ = self.detect.get(id_img,512)
+                    except Exception:
+                        print("No face detected in image %s!"%id_img_n)
+                        continue
+                    # id_basename = os.path.splitext(os.path.basename(id_img_n))[0]
+                    # cv2.imwrite(os.path.join(save_dir, "id_%s.png"%(id_basename)),id_img_align_crop[0])
+                    id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
+                else:
+                    id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img,cv2.COLOR_BGR2RGB))
+
+                id_img = self.transformer_Arcface(id_img_align_crop_pil)
+                id_img = id_img.unsqueeze(0).cuda()
+
+                # create latent id
+                id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
+                latent_id = self.arcface(id_img)
+                latent_id = F.normalize(latent_id, p=2, dim=1)
+                attr_img_ori = cv2.imread(attr_img_n)
+
+                if fusion.lower() == "fusion":
+                    try:
+                        attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
+                    except Exception:
+                        print("No face detected in image %s!"%attr_img_n)
+                        continue
+
+                    # attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
+                    # cv2.imwrite(os.path.join(save_dir, "attr_%s.png"%(attr_basename)),attr_img_align_crop[0])
+                    attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0],cv2.COLOR_BGR2RGB))
+
+                else:
+                    attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_ori,cv2.COLOR_BGR2RGB))
+
+                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
+
+                attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
+
+                attr_id = self.arcface(attr_img_arc)
+                attr_id = F.normalize(attr_id, p=2, dim=1)
+                cos_dis = 1 - cos_loss(latent_id, attr_id)
+
+                results, mask_lr, mask_hr = self.network(attr_img, latent_id)
+
+                mask_lr = mask_lr.cpu().permute(0,2,3,1)[0,...]
+                mask_lr = mask_lr.numpy()
+                # mask_lr = (mask_lr - np.min(mask_lr))/np.max(mask_lr)
+                mask_lr = np.clip(mask_lr,0.0,1.0) * 255
+                mask_hr = mask_hr.cpu().permute(0,2,3,1)[0,...]
+                mask_hr = mask_hr.numpy()
+                # mask_hr = (mask_hr - np.min(mask_hr))/np.max(mask_hr)
+                mask_hr = np.clip(mask_hr,0.0,1.0) * 255
+
+                results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
+                results_arc = self.arcface(results_arc)
+                results_arc = F.normalize(results_arc, p=2, dim=1)
+                results_cos_dis = 1 - cos_loss(latent_id, results_arc)
+                average_cos += results_cos_dis
+
+                results = results * self.imagenet_std + self.imagenet_mean
+                results = results.cpu().permute(0,2,3,1)[0,...]
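+                # The two lines above undo the ImageNet normalization applied by
+                # transformer_Arcface: since x = (img - mean) / std, the generator
+                # output is mapped back to a [0,1] RGB image via img = x * std + mean.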
+                results = results.numpy()
+                results = np.clip(results,0.0,1.0)
+                if fusion.lower() == "fusion":
+                    mat = mat[0]
+                    img_white = np.full((512,512), 255, dtype=float)
+
+                    # invert the affine transformation matrix
+                    mat_rev = np.zeros([2,3])
+                    div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
+                    mat_rev[0][0] = mat[1][1]/div1
+                    mat_rev[0][1] = -mat[0][1]/div1
+                    mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
+                    div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
+                    mat_rev[1][0] = mat[1][0]/div2
+                    mat_rev[1][1] = -mat[0][0]/div2
+                    mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
+
+                    orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
+
+                    target_image = cv2.warpAffine(results, mat_rev, orisize)
+
+                    img_white = cv2.warpAffine(img_white, mat_rev, orisize)
+
+                    img_white[img_white>20] = 255
+
+                    img_mask = img_white
+
+                    kernel = np.ones((40,40),np.uint8)
+                    img_mask = cv2.erode(img_mask,kernel,iterations = 1)
+                    kernel_size = (20, 20)
+                    blur_size = tuple(2*i+1 for i in kernel_size)
+                    img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
+
+                    img_mask /= 255
+
+                    img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
+
+                    target_image = np.array(target_image, dtype=np.float64)[..., ::-1] * 255
+
+                    img1 = np.array(attr_img_ori, dtype=np.float64)
+                    img1 = img_mask * target_image + (1-img_mask) * img1
+                else:
+                    results = results*255
+                    img1 = cv2.cvtColor(results,cv2.COLOR_RGB2BGR)
+
+                final_img = img1.astype(np.uint8)
+                id_basename   = os.path.splitext(os.path.basename(id_img_n))[0]
+                attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0]
+                if record_metric:
+                    final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
+                    final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
+                save_filename = os.path.join(save_dir,
+                                    "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
+                                    attr_basename,ckp_step,version))
+
+                cv2.imwrite(save_filename, final_img)
+                save_filename = os.path.join(save_dir,
+                                    "id_%s--attr_%s_ckp_%s_v_%s_mask_lr.png"%(id_basename,
+                                    attr_basename,ckp_step,version))
+                cv2.imwrite(save_filename,mask_lr)
+                save_filename = os.path.join(save_dir,
+                                    "id_%s--attr_%s_ckp_%s_v_%s_mask_hr.png"%(id_basename,
+                                    attr_basename,ckp_step,version))
+                cv2.imwrite(save_filename,mask_hr)
+
+        average_cos /= len(imgs_list)
+        elapsed = time.time() - start_time
+        elapsed = str(datetime.timedelta(seconds=elapsed))
+        print("Elapsed [{}]".format(elapsed))
+        print("Average cosine distance between ID and results [{}]".format(average_cos.item()))
+        self.reporter.writeInfo("Average cosine distance between ID and results [{}]".format(average_cos.item()))
\ No newline at end of file
diff --git a/test_scripts/tester_image_list_w_mask.py b/test_scripts/tester_image_list_w_mask.py
new file mode 100644
index 0000000..de762d8
--- /dev/null
+++ b/test_scripts/tester_image_list_w_mask.py
@@ -0,0 +1,328 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: tester_image_list_w_mask.py
+# Created Date: Saturday July 3rd 2021
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Saturday, 23rd April 2022 10:04:51 pm
+# Modified By: Chen Xuanhong
+# Copyright (c) 2021 Shanghai Jiao Tong University
+#############################################################
+
+
+
+import os
+import cv2
+import time
+import glob
+
+import torch
+import torch.nn.functional as F
+from torchvision import transforms + +import numpy as np +from PIL import Image + +from insightface_func.face_detect_crop_single import Face_detect_crop + +class Tester(object): + def __init__(self, config, reporter): + + self.config = config + # logger + self.reporter = reporter + + self.transformer_Arcface = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1) + self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1) + + + def __init_framework__(self): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + model_config = self.config["model_configs"] + gscript_name = self.config["com_base"] + model_config["g_model"]["script"] + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + self.network = gen_class(**model_config["g_model"]["module_params"]) + + # TODO replace below lines to define the model framework + self.network = gen_class(**model_config["g_model"]["module_params"]) + self.network = self.network.eval() + # for name in self.network.state_dict(): + # print(name) + self.features = {} + mapping_layers = [ + "first_layer", + "down4", + "BottleNeck.2" + ] + + + + # print and recorde model structure + self.reporter.writeInfo("Model structure:") + self.reporter.writeModel(self.network.__str__()) + + arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu")) + self.arcface = arcface1['model'].module + self.arcface.eval() + self.arcface.requires_grad_(False) + + model_path = os.path.join(self.config["project_checkpoints"], + "step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["generator_name"])) + self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"])) + + + if self.config["preprocess"]: + print("Employ GFPGAN to upsampling detected face images!") + from face_enhancer.gfpgan import GFPGANer + version = '1.2' + if version == '1': + arch = 'original' + channel_multiplier = 1 + model_name = 'GFPGANv1' + elif version == '1.2': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANCleanv1-NoCE-C2' + elif version == '1.3': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANv1.3' + + # determine model paths + model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth') + if not os.path.isfile(model_path): + model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth') + if not os.path.isfile(model_path): + raise ValueError(f'Model {model_name} does not exist.') + + self.restorer = GFPGANer( + model_path=model_path, + upscale=1, + arch=arch, + channel_multiplier=channel_multiplier, + bg_upsampler=None) + + # train in GPU + if self.config["cuda"] >=0: + self.network = self.network.cuda() + self.arcface = self.arcface.cuda() + + + + def test(self): + + save_dir = self.config["test_samples_path"] + ckp_step = self.config["checkpoint_step"] + version = self.config["version"] + crop_mode = self.config["crop_mode"] + list_txt = self.config["img_list_txt"] + record_metric= self.config["record_metric"] + specified_save_path = 
self.config["specified_save_path"] + self.arcface_ckpt= self.config["arcface_ckpt"] + imgs_list = [] + + self.reporter.writeInfo("Version %s"%version) + + if os.path.isdir(specified_save_path): + print("Input a legal specified save path!") + save_dir = specified_save_path + imgs_list = [] + with open(list_txt,'r') as logf: + for line in logf: + cells = line.split(";") + imgs_list.append([cells[0],cells[1],cells[2].replace("\n","")]) + + # models + self.__init_framework__() + + mode = crop_mode.lower() + if mode == "vggface": + mode = "none" + self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models') + self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode) + + cos_loss = torch.nn.CosineSimilarity() + font = cv2.FONT_HERSHEY_SIMPLEX + # Start time + import datetime + print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + print('Start =================================== test...') + start_time = time.time() + self.network.eval() + cos_dict = {} + average_cos = 0 + with torch.no_grad(): + for img in imgs_list: + id_img_n, attr_img_n, fusion= img + print("id image:%s---attr image:%s"%(id_img_n, attr_img_n)) + id_img = cv2.imread(id_img_n) + print(fusion) + if fusion.lower() == "fusion": + try: + id_img_align_crop, _ = self.detect.get(id_img,512) + except: + print("Image %s Do not detect a face!"%id_img_n) + continue + + id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB)) + else: + id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img,cv2.COLOR_BGR2RGB)) + + id_img = self.transformer_Arcface(id_img_align_crop_pil) + id_img = id_img.unsqueeze(0).cuda() + + #create latent id + id_img = F.interpolate(id_img,size=(112,112), mode='bicubic') + latend_id = self.arcface(id_img) + latend_id = F.normalize(latend_id, p=2, dim=1) + attr_img_ori= cv2.imread(attr_img_n) + + if fusion.lower() == "fusion": + try: + attr_img_align_crop, mat = self.detect.get(attr_img_ori,512) + except: + print("Image %s Do not detect a face!"%attr_img_n) + continue + + # attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0] + # cv2.imwrite(os.path.join(save_dir, "attr_%s.png"%(attr_basename)),attr_img_align_crop[0]) + restored_face = attr_img_align_crop[0] + if self.config["preprocess"]: + _, _, restored_face = self.restorer.enhance( + restored_face, has_aligned=False, only_center_face=True, paste_back=True) + attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face,cv2.COLOR_BGR2RGB)) + + else: + attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_ori,cv2.COLOR_BGR2RGB)) + + attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda() + + attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic') + + attr_id = self.arcface(attr_img_arc) + attr_id = F.normalize(attr_id, p=2, dim=1) + cos_dis = 1 - cos_loss(latend_id, attr_id) + + + # results,mask= self.network(attr_img, latend_id) + pred = self.network(attr_img, latend_id) + results = pred[0] + + + + results_arc = F.interpolate(results,size=(112,112), mode='bicubic') + results_arc = self.arcface(results_arc) + results_arc = F.normalize(results_arc, p=2, dim=1) + results_cos_dis = 1 - cos_loss(latend_id, results_arc) + average_cos += results_cos_dis + + results = results * self.imagenet_std + self.imagenet_mean + results = results.cpu().permute(0,2,3,1)[0,...] 
+ results = results.numpy() + results = np.clip(results,0.0,1.0) + if fusion.lower() == "fusion": + mat = mat[0] + img_white = np.full((512,512), 255, dtype=float) + + # inverse the Affine transformation matrix + mat_rev = np.zeros([2,3]) + div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0] + mat_rev[0][0] = mat[1][1]/div1 + mat_rev[0][1] = -mat[0][1]/div1 + mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1 + div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1] + mat_rev[1][0] = mat[1][0]/div2 + mat_rev[1][1] = -mat[0][0]/div2 + mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2 + + orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0]) + + target_image = cv2.warpAffine(results, mat_rev, orisize) + + img_white = cv2.warpAffine(img_white, mat_rev, orisize) + + + img_white[img_white>20] =255 + + img_mask = img_white + + kernel = np.ones((40,40),np.uint8) + img_mask = cv2.erode(img_mask,kernel,iterations = 1) + kernel_size = (20, 20) + blur_size = tuple(2*i+1 for i in kernel_size) + img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) + + img_mask /= 255 + + img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) + + target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255 + + img1 = np.array(attr_img_ori, dtype=np.float) + img1 = img_mask * target_image + (1-img_mask) * img1 + else: + results = results*255 + img1 = cv2.cvtColor(results,cv2.COLOR_RGB2BGR) + + final_img = img1.astype(np.uint8) + id_basename = os.path.basename(id_img_n) + id_basename = os.path.splitext(os.path.basename(id_img_n))[0] + attr_basename = os.path.splitext(os.path.basename(attr_img_n))[0] + if record_metric: + final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2) + final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2) + print(save_dir) + if self.config["preprocess"]: + save_filename = os.path.join(save_dir, + "id_%s--attr_%s_ckp_%s_v_%s_gfpgan.png"%(id_basename, + attr_basename,ckp_step,version)) + else: + save_filename = os.path.join(save_dir, + "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename, + attr_basename,ckp_step,version)) + + cv2.imwrite(save_filename, final_img) + + if self.config["save_mask"]: + num = 0 + + for mask in pred[1:]: + + mask = mask.cpu().permute(0,2,3,1)[0,...] 
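+                        # Caution: the rescaling below divides the shifted mask by
+                        # np.max(mask) instead of (np.max(mask) - np.min(mask)), so it is
+                        # a true min-max normalization only when np.min(mask) == 0; the
+                        # conventional form would be
+                        #   mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)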
+ mask = mask.numpy() + mask = (mask - np.min(mask))/np.max(mask) + mask = np.clip(mask,0.0,1.0) * 255 + + if self.config["preprocess"]: + save_filename = os.path.join(save_dir, + "id_%s--attr_%s_ckp_%s_v_%s_mask%d_gfpgan.png"%(id_basename, + attr_basename,ckp_step,version,num)) + else: + save_filename = os.path.join(save_dir, + "id_%s--attr_%s_ckp_%s_v_%s_mask%d.png"%(id_basename, + attr_basename,ckp_step,version,num)) + + + cv2.imwrite(save_filename,mask) + num += 1 + average_cos /= len(imgs_list) + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + print("Elapsed [{}]".format(elapsed)) + print("Average cosin similarity between ID and results [{}]".format(average_cos.item())) + self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item())) \ No newline at end of file diff --git a/test_scripts/tester_image_w_2mask.py b/test_scripts/tester_image_w_2mask.py new file mode 100644 index 0000000..0a1a65f --- /dev/null +++ b/test_scripts/tester_image_w_2mask.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: tester_commonn.py +# Created Date: Saturday July 3rd 2021 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Tuesday, 12th April 2022 10:09:21 am +# Modified By: Chen Xuanhong +# Copyright (c) 2021 Shanghai Jiao Tong University +############################################################# + + + +import os +import cv2 +import time +import glob + +import torch +import torch.nn.functional as F +from torchvision import transforms + +import numpy as np +from PIL import Image + +from insightface_func.face_detect_crop_single import Face_detect_crop + +class Tester(object): + def __init__(self, config, reporter): + + self.config = config + # logger + self.reporter = reporter + + self.transformer_Arcface = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1) + self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1) + + + def __init_framework__(self): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + model_config = self.config["model_configs"] + gscript_name = self.config["com_base"] + model_config["g_model"]["script"] + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + self.network = gen_class(**model_config["g_model"]["module_params"]) + + # TODO replace below lines to define the model framework + self.network = gen_class(**model_config["g_model"]["module_params"]) + self.network = self.network.eval() + # for name in self.network.state_dict(): + # print(name) + self.features = {} + mapping_layers = [ + "first_layer", + "down4", + "BottleNeck.2" + ] + + + + # print and recorde model structure + self.reporter.writeInfo("Model structure:") + self.reporter.writeModel(self.network.__str__()) + + arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu")) + self.arcface = arcface1['model'].module + self.arcface.eval() + self.arcface.requires_grad_(False) + + model_path = os.path.join(self.config["project_checkpoints"], + 
"step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["generator_name"])) + self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"])) + + # train in GPU + if self.config["cuda"] >=0: + self.network = self.network.cuda() + self.arcface = self.arcface.cuda() + + + + def test(self): + + save_dir = self.config["test_samples_path"] + ckp_step = self.config["checkpoint_step"] + version = self.config["version"] + id_imgs = self.config["id_imgs"] + crop_mode = self.config["crop_mode"] + attr_files = self.config["attr_files"] + specified_save_path = self.config["specified_save_path"] + self.arcface_ckpt= self.config["arcface_ckpt"] + imgs_list = [] + + self.reporter.writeInfo("Version %s"%version) + + if os.path.isdir(specified_save_path): + print("Input a legal specified save path!") + save_dir = specified_save_path + + if os.path.isdir(attr_files): + print("Input a dir....") + imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True) + for item in imgs: + imgs_list.append(item) + print(imgs_list) + else: + print("Input an image....") + imgs_list.append(attr_files) + id_basename = os.path.basename(id_imgs) + id_basename = os.path.splitext(os.path.basename(id_imgs))[0] + + # models + self.__init_framework__() + + mode = crop_mode.lower() + if mode == "vggface": + mode = "none" + self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models') + self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode) + + id_img = cv2.imread(id_imgs) + id_img_align_crop, _ = self.detect.get(id_img,512) + id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB)) + id_img = self.transformer_Arcface(id_img_align_crop_pil) + id_img = id_img.unsqueeze(0).cuda() + + #create latent id + id_img = F.interpolate(id_img,size=(112,112), mode='bicubic') + latend_id = self.arcface(id_img) + latend_id = F.normalize(latend_id, p=2, dim=1) + cos_loss = torch.nn.CosineSimilarity() + font = cv2.FONT_HERSHEY_SIMPLEX + # Start time + import datetime + print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + print('Start =================================== test...') + start_time = time.time() + self.network.eval() + cos_dict = {} + average_cos = 0 + with torch.no_grad(): + for img in imgs_list: + print(img) + attr_img_ori= cv2.imread(img) + try: + attr_img_align_crop, mat = self.detect.get(attr_img_ori,512) + except: + continue + attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0],cv2.COLOR_BGR2RGB)) + attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda() + + attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic') + # cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0]) + attr_id = self.arcface(attr_img_arc) + attr_id = F.normalize(attr_id, p=2, dim=1) + cos_dis = 1 - cos_loss(latend_id, attr_id) + + mat = mat[0] + results,mask_lr,mask_hr= self.network(attr_img, latend_id) + + mask_lr = mask_lr.cpu().permute(0,2,3,1)[0,...] + mask_lr = mask_lr.numpy() + # mask_lr = (mask_lr - np.min(mask_lr))/np.max(mask_lr) + mask_lr = np.clip(mask_lr,0.0,1.0) * 255 + mask_hr = mask_hr.cpu().permute(0,2,3,1)[0,...] 
+ mask_hr = mask_hr.numpy() + # mask_hr = (mask_hr - np.min(mask_hr))/np.max(mask_hr) + mask_hr = np.clip(mask_hr,0.0,1.0) * 255 + + results_arc = F.interpolate(results,size=(112,112), mode='bicubic') + results_arc = self.arcface(results_arc) + results_arc = F.normalize(results_arc, p=2, dim=1) + results_cos_dis = 1 - cos_loss(latend_id, results_arc) + average_cos += results_cos_dis + + results = results * self.imagenet_std + self.imagenet_mean + results = results.cpu().permute(0,2,3,1)[0,...] + results = results.numpy() + results = np.clip(results,0.0,1.0) + img_white = np.full((512,512), 255, dtype=float) + + # inverse the Affine transformation matrix + mat_rev = np.zeros([2,3]) + div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0] + mat_rev[0][0] = mat[1][1]/div1 + mat_rev[0][1] = -mat[0][1]/div1 + mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1 + div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1] + mat_rev[1][0] = mat[1][0]/div2 + mat_rev[1][1] = -mat[0][0]/div2 + mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2 + + orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0]) + + target_image = cv2.warpAffine(results, mat_rev, orisize) + + img_white = cv2.warpAffine(img_white, mat_rev, orisize) + + + img_white[img_white>20] =255 + + img_mask = img_white + + kernel = np.ones((40,40),np.uint8) + img_mask = cv2.erode(img_mask,kernel,iterations = 1) + kernel_size = (20, 20) + blur_size = tuple(2*i+1 for i in kernel_size) + img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) + + img_mask /= 255 + + img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) + + target_image = np.array(target_image, dtype=np.float)[..., ::-1] * 255 + + img1 = np.array(attr_img_ori, dtype=np.float) + img1 = img_mask * target_image + (1-img_mask) * img1 + final_img = img1.astype(np.uint8) + attr_basename = os.path.splitext(os.path.basename(img))[0] + final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2) + final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2) + save_filename = os.path.join(save_dir, + "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename, + attr_basename,ckp_step,version)) + + cv2.imwrite(save_filename, final_img) + + save_filename = os.path.join(save_dir, + "id_%s--attr_%s_ckp_%s_v_%s_mask_lr.png"%(id_basename, + attr_basename,ckp_step,version)) + cv2.imwrite(save_filename,mask_lr) + save_filename = os.path.join(save_dir, + "id_%s--attr_%s_ckp_%s_v_%s_mask_hr.png"%(id_basename, + attr_basename,ckp_step,version)) + cv2.imwrite(save_filename,mask_hr) + + average_cos /= len(imgs_list) + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + print("Elapsed [{}]".format(elapsed)) + print("Average cosin similarity between ID and results [{}]".format(average_cos.item())) + self.reporter.writeInfo("Average cosin similarity between ID and results [{}]".format(average_cos.item())) \ No newline at end of file diff --git a/test_scripts/tester_image_w_2mask_gfpgan.py b/test_scripts/tester_image_w_2mask_gfpgan.py new file mode 100644 index 0000000..17c2cf7 --- /dev/null +++ b/test_scripts/tester_image_w_2mask_gfpgan.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: tester_commonn.py +# Created Date: Saturday July 3rd 2021 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Saturday, 16th April 2022 5:20:54 pm +# Modified By: Chen 
Xuanhong +# Copyright (c) 2021 Shanghai Jiao Tong University +############################################################# + + + +import os +import cv2 +import time +import glob + +import torch +import torch.nn.functional as F +from torchvision import transforms + +import numpy as np +from PIL import Image + +from insightface_func.face_detect_crop_single import Face_detect_crop +from face_enhancer.gfpgan import GFPGANer + +class Tester(object): + def __init__(self, config, reporter): + + self.config = config + # logger + self.reporter = reporter + + self.transformer_Arcface = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1) + self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1) + + + def __init_framework__(self): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + model_config = self.config["model_configs"] + gscript_name = self.config["com_base"] + model_config["g_model"]["script"] + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + self.network = gen_class(**model_config["g_model"]["module_params"]) + + # TODO replace below lines to define the model framework + self.network = gen_class(**model_config["g_model"]["module_params"]) + self.network = self.network.eval() + # for name in self.network.state_dict(): + # print(name) + self.features = {} + mapping_layers = [ + "first_layer", + "down4", + "BottleNeck.2" + ] + + + + # print and recorde model structure + self.reporter.writeInfo("Model structure:") + self.reporter.writeModel(self.network.__str__()) + + arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu")) + self.arcface = arcface1['model'].module + self.arcface.eval() + self.arcface.requires_grad_(False) + + model_path = os.path.join(self.config["project_checkpoints"], + "step%d_%s.pth"%(self.config["checkpoint_step"], + self.config["checkpoint_names"]["generator_name"])) + self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"])) + + version = '1.2' + if version == '1': + arch = 'original' + channel_multiplier = 1 + model_name = 'GFPGANv1' + elif version == '1.2': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANCleanv1-NoCE-C2' + elif version == '1.3': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANv1.3' + + # determine model paths + model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth') + if not os.path.isfile(model_path): + model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth') + if not os.path.isfile(model_path): + raise ValueError(f'Model {model_name} does not exist.') + + self.restorer = GFPGANer( + model_path=model_path, + upscale=1, + arch=arch, + channel_multiplier=channel_multiplier, + bg_upsampler=None) + + # train in GPU + if self.config["cuda"] >=0: + self.network = self.network.cuda() + self.arcface = self.arcface.cuda() + + + + def test(self): + + save_dir = self.config["test_samples_path"] + ckp_step = self.config["checkpoint_step"] + version = self.config["version"] + id_imgs = 
self.config["id_imgs"] + crop_mode = self.config["crop_mode"] + attr_files = self.config["attr_files"] + specified_save_path = self.config["specified_save_path"] + self.arcface_ckpt= self.config["arcface_ckpt"] + imgs_list = [] + + self.reporter.writeInfo("Version %s"%version) + + if os.path.isdir(specified_save_path): + print("Input a legal specified save path!") + save_dir = specified_save_path + + if os.path.isdir(attr_files): + print("Input a dir....") + imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True) + for item in imgs: + imgs_list.append(item) + print(imgs_list) + else: + print("Input an image....") + imgs_list.append(attr_files) + id_basename = os.path.basename(id_imgs) + id_basename = os.path.splitext(os.path.basename(id_imgs))[0] + + # models + self.__init_framework__() + + mode = crop_mode.lower() + if mode == "vggface": + mode = "none" + self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models') + self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode) + + id_img = cv2.imread(id_imgs) + id_img_align_crop, _ = self.detect.get(id_img,512) + id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB)) + id_img = self.transformer_Arcface(id_img_align_crop_pil) + id_img = id_img.unsqueeze(0).cuda() + + #create latent id + id_img = F.interpolate(id_img,size=(112,112), mode='bicubic') + latend_id = self.arcface(id_img) + latend_id = F.normalize(latend_id, p=2, dim=1) + cos_loss = torch.nn.CosineSimilarity() + font = cv2.FONT_HERSHEY_SIMPLEX + # Start time + import datetime + print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + print('Start =================================== test...') + start_time = time.time() + self.network.eval() + cos_dict = {} + average_cos = 0 + with torch.no_grad(): + for img in imgs_list: + print(img) + attr_img_ori= cv2.imread(img) + try: + attr_img_align_crop, mat = self.detect.get(attr_img_ori,512) + except: + continue + _, _, restored_face = self.restorer.enhance( + attr_img_align_crop[0], has_aligned=False, only_center_face=True, paste_back=True) + attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face,cv2.COLOR_BGR2RGB)) + attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda() + + attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic') + # cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0]) + attr_id = self.arcface(attr_img_arc) + attr_id = F.normalize(attr_id, p=2, dim=1) + cos_dis = 1 - cos_loss(latend_id, attr_id) + + mat = mat[0] + results,mask_lr,mask_hr= self.network(attr_img, latend_id) + + mask_lr = mask_lr.cpu().permute(0,2,3,1)[0,...] + mask_lr = mask_lr.numpy() + # mask_lr = (mask_lr - np.min(mask_lr))/np.max(mask_lr) + mask_lr = np.clip(mask_lr,0.0,1.0) * 255 + mask_hr = mask_hr.cpu().permute(0,2,3,1)[0,...] + mask_hr = mask_hr.numpy() + # mask_hr = (mask_hr - np.min(mask_hr))/np.max(mask_hr) + mask_hr = np.clip(mask_hr,0.0,1.0) * 255 + + results_arc = F.interpolate(results,size=(112,112), mode='bicubic') + results_arc = self.arcface(results_arc) + results_arc = F.normalize(results_arc, p=2, dim=1) + results_cos_dis = 1 - cos_loss(latend_id, results_arc) + average_cos += results_cos_dis + + results = results * self.imagenet_std + self.imagenet_mean + results = results.cpu().permute(0,2,3,1)[0,...] 
+                    results = results.numpy()
+                    results = np.clip(results,0.0,1.0)
+                    img_white = np.full((512,512), 255, dtype=float)
+
+                    # inverse the Affine transformation matrix
+                    mat_rev = np.zeros([2,3])
+                    div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
+                    mat_rev[0][0] = mat[1][1]/div1
+                    mat_rev[0][1] = -mat[0][1]/div1
+                    mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
+                    div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
+                    mat_rev[1][0] = mat[1][0]/div2
+                    mat_rev[1][1] = -mat[0][0]/div2
+                    mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
+
+                    orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
+
+                    target_image = cv2.warpAffine(results, mat_rev, orisize)
+
+                    img_white = cv2.warpAffine(img_white, mat_rev, orisize)
+
+                    img_white[img_white>20] = 255
+
+                    img_mask = img_white
+
+                    kernel = np.ones((40,40),np.uint8)
+                    img_mask = cv2.erode(img_mask,kernel,iterations = 1)
+                    kernel_size = (20, 20)
+                    blur_size = tuple(2*i+1 for i in kernel_size)
+                    img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
+
+                    img_mask /= 255
+
+                    img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
+
+                    target_image = np.array(target_image, dtype=float)[..., ::-1] * 255
+
+                    img1 = np.array(attr_img_ori, dtype=float)
+                    img1 = img_mask * target_image + (1-img_mask) * img1
+                    final_img = img1.astype(np.uint8)
+                    attr_basename = os.path.splitext(os.path.basename(img))[0]
+                    final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
+                    final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
+                    save_filename = os.path.join(save_dir,
+                                "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
+                                attr_basename,ckp_step,version))
+
+                    cv2.imwrite(save_filename, final_img)
+
+                    save_filename = os.path.join(save_dir,
+                                "id_%s--attr_%s_ckp_%s_v_%s_mask_lr.png"%(id_basename,
+                                attr_basename,ckp_step,version))
+                    cv2.imwrite(save_filename,mask_lr)
+                    save_filename = os.path.join(save_dir,
+                                "id_%s--attr_%s_ckp_%s_v_%s_mask_hr.png"%(id_basename,
+                                attr_basename,ckp_step,version))
+                    cv2.imwrite(save_filename,mask_hr)
+
+            average_cos /= len(imgs_list)
+            elapsed = time.time() - start_time
+            elapsed = str(datetime.timedelta(seconds=elapsed))
+            print("Elapsed [{}]".format(elapsed))
+            print("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
+            self.reporter.writeInfo("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
\ No newline at end of file
diff --git a/test_scripts/tester_image_w_mask.py b/test_scripts/tester_image_w_mask.py
new file mode 100644
index 0000000..9588789
--- /dev/null
+++ b/test_scripts/tester_image_w_mask.py
@@ -0,0 +1,301 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: tester_image_w_mask.py
+# Created Date: Saturday July 3rd 2021
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Saturday, 23rd April 2022 10:05:22 pm
+# Modified By: Chen Xuanhong
+# Copyright (c) 2021 Shanghai Jiao Tong University
+#############################################################
+
+
+
+import os
+import cv2
+import time
+import glob
+
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+
+import numpy as np
+from PIL import Image
+
+from insightface_func.face_detect_crop_single import Face_detect_crop
+
+class Tester(object):
+    def __init__(self, config, reporter):
+
+        self.config = config
+        # logger
+        self.reporter = reporter
+
+        self.transformer_Arcface = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ])
+        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
+        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
+
+
+    def __init_framework__(self):
+        '''
+        This function is designed to define the framework,
+        and print the framework information into the log file
+        '''
+        #===============build models================#
+        print("build models...")
+        # TODO [import models here]
+        model_config = self.config["model_configs"]
+        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
+        class_name = model_config["g_model"]["class_name"]
+        package = __import__(gscript_name, fromlist=True)
+        gen_class = getattr(package, class_name)
+
+        # TODO replace the lines below to define the model framework
+        self.network = gen_class(**model_config["g_model"]["module_params"])
+        self.network = self.network.eval()
+        # for name in self.network.state_dict():
+        #     print(name)
+        self.features = {}
+        mapping_layers = [
+            "first_layer",
+            "down4",
+            "BottleNeck.2"
+        ]
+
+
+
+        # print and record model structure
+        self.reporter.writeInfo("Model structure:")
+        self.reporter.writeModel(self.network.__str__())
+
+        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
+        self.arcface = arcface1['model'].module
+        self.arcface.eval()
+        self.arcface.requires_grad_(False)
+
+        model_path = os.path.join(self.config["project_checkpoints"],
+                                    "step%d_%s.pth"%(self.config["checkpoint_step"],
+                                    self.config["checkpoint_names"]["generator_name"]))
+        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
+
+        if self.config["preprocess"]:
+            print("Employ GFPGAN to upsample detected face images!")
+            from face_enhancer.gfpgan import GFPGANer
+            version = '1.2'
+            if version == '1':
+                arch = 'original'
+                channel_multiplier = 1
+                model_name = 'GFPGANv1'
+            elif version == '1.2':
+                arch = 'clean'
+                channel_multiplier = 2
+                model_name = 'GFPGANCleanv1-NoCE-C2'
+            elif version == '1.3':
+                arch = 'clean'
+                channel_multiplier = 2
+                model_name = 'GFPGANv1.3'
+
+            # determine model paths
+            model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
+            if not os.path.isfile(model_path):
+                model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
+            if not os.path.isfile(model_path):
+                raise ValueError(f'Model {model_name} does not exist.')
+
+            self.restorer = GFPGANer(
+                model_path=model_path,
+                upscale=1,
+                arch=arch,
+                channel_multiplier=channel_multiplier,
+                bg_upsampler=None)
+
+        # run models on GPU
+        if self.config["cuda"] >=0:
+            self.network = self.network.cuda()
+            self.arcface = self.arcface.cuda()
+
+
+
+    def test(self):
+
+        save_dir = self.config["test_samples_path"]
+        ckp_step = self.config["checkpoint_step"]
+        version = self.config["version"]
+        id_imgs = self.config["id_imgs"]
+        crop_mode = self.config["crop_mode"]
+        attr_files = self.config["attr_files"]
+        specified_save_path = self.config["specified_save_path"]
+        self.arcface_ckpt= self.config["arcface_ckpt"]
+        imgs_list = []
+
+        self.reporter.writeInfo("Version %s"%version)
+
+        if os.path.isdir(specified_save_path):
+            print("Specified save path is a valid directory, results will be saved there.")
+            save_dir = specified_save_path
+
+        if os.path.isdir(attr_files):
+            print("Attribute input is a directory....")
+            imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
+            for item in imgs:
+                imgs_list.append(item)
+            print(imgs_list)
+        else:
+            print("Attribute input is a single image....")
+            imgs_list.append(attr_files)
+        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]
+
+        # models
+        self.__init_framework__()
+
+        mode = crop_mode.lower()
+        if mode == "vggface":
+            mode = "none"
+        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
+        self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
+
+        id_img = cv2.imread(id_imgs)
+        id_img_align_crop, _ = self.detect.get(id_img,512)
+        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
+        id_img = self.transformer_Arcface(id_img_align_crop_pil)
+        id_img = id_img.unsqueeze(0).cuda()
+
+        #create latent id
+        id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
+        latend_id = self.arcface(id_img)
+        latend_id = F.normalize(latend_id, p=2, dim=1)
+        cos_loss = torch.nn.CosineSimilarity()
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        # Start time
+        import datetime
+        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
+        print('Start =================================== test...')
+        start_time = time.time()
+        self.network.eval()
+        cos_dict = {}
+        average_cos = 0
+        with torch.no_grad():
+            for img in imgs_list:
+                print(img)
+                attr_img_ori= cv2.imread(img)
+                try:
+                    attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
+                except Exception:
+                    continue
+                restored_face = attr_img_align_crop[0]
+                if self.config["preprocess"]:
+                    _, _, restored_face = self.restorer.enhance(
+                        restored_face, has_aligned=False, only_center_face=True, paste_back=True)
+                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face,cv2.COLOR_BGR2RGB))
+                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
+
+                attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
+                # cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0])
+                attr_id = self.arcface(attr_img_arc)
+                attr_id = F.normalize(attr_id, p=2, dim=1)
+                cos_dis = 1 - cos_loss(latend_id, attr_id)
+
+                mat = mat[0]
+                pred = self.network(attr_img, latend_id)
+                results = pred[0]
+
+                results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
+                results_arc = self.arcface(results_arc)
+                results_arc = F.normalize(results_arc, p=2, dim=1)
+                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
+                average_cos += results_cos_dis
+
+                results = results * self.imagenet_std + self.imagenet_mean
+                results = results.cpu().permute(0,2,3,1)[0,...]
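+                # The block below converts the generator output to a uint8 BGR
+                # image and pastes it back into the original frame. The 2x3
+                # affine matrix `mat` from the detector is inverted by hand
+                # here; a shorter, equivalent sketch (assuming `mat` is the
+                # same 2x3 float matrix) would simply be:
+                #     mat_rev = cv2.invertAffineTransform(mat)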
+                results = results.numpy()
+                results = np.clip(results,0.0,1.0)
+                img_white = np.full((512,512), 255, dtype=float)
+
+                # inverse the Affine transformation matrix
+                mat_rev = np.zeros([2,3])
+                div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
+                mat_rev[0][0] = mat[1][1]/div1
+                mat_rev[0][1] = -mat[0][1]/div1
+                mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
+                div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
+                mat_rev[1][0] = mat[1][0]/div2
+                mat_rev[1][1] = -mat[0][0]/div2
+                mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
+
+                orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
+
+                target_image = cv2.warpAffine(results, mat_rev, orisize)
+
+                img_white = cv2.warpAffine(img_white, mat_rev, orisize)
+
+                img_white[img_white>20] = 255
+
+                img_mask = img_white
+
+                kernel = np.ones((40,40),np.uint8)
+                img_mask = cv2.erode(img_mask,kernel,iterations = 1)
+                kernel_size = (20, 20)
+                blur_size = tuple(2*i+1 for i in kernel_size)
+                img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
+
+                img_mask /= 255
+
+                img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
+
+                target_image = np.array(target_image, dtype=float)[..., ::-1] * 255
+
+                img1 = np.array(attr_img_ori, dtype=float)
+                img1 = img_mask * target_image + (1-img_mask) * img1
+                final_img = img1.astype(np.uint8)
+                attr_basename = os.path.splitext(os.path.basename(img))[0]
+                if self.config["record_metric"]:
+                    final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
+                    final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
+                if self.config["preprocess"]:
+                    save_filename = os.path.join(save_dir,
+                                "id_%s--attr_%s_ckp_%s_v_%s_gfpgan.png"%(id_basename,
+                                attr_basename,ckp_step,version))
+                else:
+                    save_filename = os.path.join(save_dir,
+                                "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
+                                attr_basename,ckp_step,version))
+                cv2.imwrite(save_filename, final_img)
+
+                if self.config["save_mask"]:
+                    num = 0
+
+                    for mask in pred[1:]:
+
+                        mask = mask.cpu().permute(0,2,3,1)[0,...]
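+                        # Min-max normalise each predicted attention mask to
+                        # [0, 1] before scaling it to 8-bit for saving; the
+                        # denominator below uses (max - min) so masks whose
+                        # minimum is non-zero still span the full range.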
+                        mask = mask.numpy()
+                        mask = (mask - np.min(mask))/(np.max(mask) - np.min(mask))
+                        mask = np.clip(mask,0.0,1.0) * 255
+
+                        if self.config["preprocess"]:
+                            save_filename = os.path.join(save_dir,
+                                    "id_%s--attr_%s_ckp_%s_v_%s_mask%d_gfpgan.png"%(id_basename,
+                                    attr_basename,ckp_step,version,num))
+                        else:
+                            save_filename = os.path.join(save_dir,
+                                    "id_%s--attr_%s_ckp_%s_v_%s_mask%d.png"%(id_basename,
+                                    attr_basename,ckp_step,version,num))
+
+
+                        cv2.imwrite(save_filename,mask)
+                        num += 1
+
+            average_cos /= len(imgs_list)
+            elapsed = time.time() - start_time
+            elapsed = str(datetime.timedelta(seconds=elapsed))
+            print("Elapsed [{}]".format(elapsed))
+            print("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
+            self.reporter.writeInfo("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
\ No newline at end of file
diff --git a/test_scripts/tester_image_w_mask_gfpgan.py b/test_scripts/tester_image_w_mask_gfpgan.py
new file mode 100644
index 0000000..40ce8b3
--- /dev/null
+++ b/test_scripts/tester_image_w_mask_gfpgan.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: tester_image_w_mask_gfpgan.py
+# Created Date: Saturday July 3rd 2021
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Thursday, 14th April 2022 1:48:18 am
+# Modified By: Chen Xuanhong
+# Copyright (c) 2021 Shanghai Jiao Tong University
+#############################################################
+
+
+
+import os
+import cv2
+import time
+import glob
+
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+
+import numpy as np
+from PIL import Image
+
+from insightface_func.face_detect_crop_single import Face_detect_crop
+from face_enhancer.gfpgan import GFPGANer
+
+class Tester(object):
+    def __init__(self, config, reporter):
+
+        self.config = config
+        # logger
+        self.reporter = reporter
+
+        self.transformer_Arcface = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ])
+        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
+        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
+
+
+    def __init_framework__(self):
+        '''
+        This function is designed to define the framework,
+        and print the framework information into the log file
+        '''
+        #===============build models================#
+        print("build models...")
+        # TODO [import models here]
+        model_config = self.config["model_configs"]
+        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
+        class_name = model_config["g_model"]["class_name"]
+        package = __import__(gscript_name, fromlist=True)
+        gen_class = getattr(package, class_name)
+
+        # TODO replace the lines below to define the model framework
+        self.network = gen_class(**model_config["g_model"]["module_params"])
+        self.network = self.network.eval()
+        # for name in self.network.state_dict():
+        #     print(name)
+        self.features = {}
+        mapping_layers = [
+            "first_layer",
+            "down4",
+            "BottleNeck.2"
+        ]
+
+
+
+        # print and record model structure
+        self.reporter.writeInfo("Model structure:")
+        self.reporter.writeModel(self.network.__str__())
+
+        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
+        self.arcface = arcface1['model'].module
+        self.arcface.eval()
+        self.arcface.requires_grad_(False)
+
+        model_path = os.path.join(self.config["project_checkpoints"],
+                                    "step%d_%s.pth"%(self.config["checkpoint_step"],
+                                    self.config["checkpoint_names"]["generator_name"]))
+        self.network.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
+
+
+        version = '1.2'
+        if version == '1':
+            arch = 'original'
+            channel_multiplier = 1
+            model_name = 'GFPGANv1'
+        elif version == '1.2':
+            arch = 'clean'
+            channel_multiplier = 2
+            model_name = 'GFPGANCleanv1-NoCE-C2'
+        elif version == '1.3':
+            arch = 'clean'
+            channel_multiplier = 2
+            model_name = 'GFPGANv1.3'
+
+        # determine model paths
+        model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
+        if not os.path.isfile(model_path):
+            model_path = os.path.join('./face_enhancer/realesrgan/weights', model_name + '.pth')
+        if not os.path.isfile(model_path):
+            raise ValueError(f'Model {model_name} does not exist.')
+
+        self.restorer = GFPGANer(
+            model_path=model_path,
+            upscale=1,
+            arch=arch,
+            channel_multiplier=channel_multiplier,
+            bg_upsampler=None)
+
+        # run models on GPU
+        if self.config["cuda"] >=0:
+            self.network = self.network.cuda()
+            self.arcface = self.arcface.cuda()
+
+
+
+    def test(self):
+
+        save_dir = self.config["test_samples_path"]
+        ckp_step = self.config["checkpoint_step"]
+        version = self.config["version"]
+        id_imgs = self.config["id_imgs"]
+        crop_mode = self.config["crop_mode"]
+        attr_files = self.config["attr_files"]
+        specified_save_path = self.config["specified_save_path"]
+        self.arcface_ckpt= self.config["arcface_ckpt"]
+        imgs_list = []
+
+        self.reporter.writeInfo("Version %s"%version)
+
+        if os.path.isdir(specified_save_path):
+            print("Specified save path is a valid directory, results will be saved there.")
+            save_dir = specified_save_path
+
+        if os.path.isdir(attr_files):
+            print("Attribute input is a directory....")
+            imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
+            for item in imgs:
+                imgs_list.append(item)
+            print(imgs_list)
+        else:
+            print("Attribute input is a single image....")
+            imgs_list.append(attr_files)
+        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]
+
+        # models
+        self.__init_framework__()
+
+        mode = crop_mode.lower()
+        if mode == "vggface":
+            mode = "none"
+        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
+        self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
+
+        id_img = cv2.imread(id_imgs)
+        id_img_align_crop, _ = self.detect.get(id_img,512)
+        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
+        id_img = self.transformer_Arcface(id_img_align_crop_pil)
+        id_img = id_img.unsqueeze(0).cuda()
+
+        #create latent id
+        id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
+        latend_id = self.arcface(id_img)
+        latend_id = F.normalize(latend_id, p=2, dim=1)
+        cos_loss = torch.nn.CosineSimilarity()
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        # Start time
+        import datetime
+        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
+        print('Start =================================== test...')
+        start_time = time.time()
+        self.network.eval()
+        cos_dict = {}
+        average_cos = 0
+        with torch.no_grad():
+            for img in imgs_list:
+                print(img)
+                attr_img_ori= cv2.imread(img)
+                try:
+                    attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
+                except Exception:
+                    continue
+                _, _, restored_face = self.restorer.enhance(
+                    attr_img_align_crop[0], has_aligned=False, only_center_face=True, paste_back=True)
+                # cv2.imwrite("id_wocao.png",restored_face)
+                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(restored_face,cv2.COLOR_BGR2RGB))
+                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()
+
+                attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
+                # cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0])
+                attr_id = self.arcface(attr_img_arc)
+                attr_id = F.normalize(attr_id, p=2, dim=1)
+                cos_dis = 1 - cos_loss(latend_id, attr_id)
+
+                mat = mat[0]
+                results,mask= self.network(attr_img, latend_id)
+
+                # min-max normalise the mask before saving it as an 8-bit image
+                mask = mask.cpu().permute(0,2,3,1)[0,...]
+                mask = mask.numpy()
+                mask = (mask - np.min(mask))/(np.max(mask) - np.min(mask))
+                mask = np.clip(mask,0.0,1.0) * 255
+
+                results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
+                results_arc = self.arcface(results_arc)
+                results_arc = F.normalize(results_arc, p=2, dim=1)
+                results_cos_dis = 1 - cos_loss(latend_id, results_arc)
+                average_cos += results_cos_dis
+
+                results = results * self.imagenet_std + self.imagenet_mean
+                results = results.cpu().permute(0,2,3,1)[0,...]
+                results = results.numpy()
+                results = np.clip(results,0.0,1.0)
+                img_white = np.full((512,512), 255, dtype=float)
+
+                # inverse the Affine transformation matrix
+                mat_rev = np.zeros([2,3])
+                div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
+                mat_rev[0][0] = mat[1][1]/div1
+                mat_rev[0][1] = -mat[0][1]/div1
+                mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
+                div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
+                mat_rev[1][0] = mat[1][0]/div2
+                mat_rev[1][1] = -mat[0][0]/div2
+                mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
+
+                orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
+
+                target_image = cv2.warpAffine(results, mat_rev, orisize)
+
+                img_white = cv2.warpAffine(img_white, mat_rev, orisize)
+
+                img_white[img_white>20] = 255
+
+                img_mask = img_white
+
+                kernel = np.ones((40,40),np.uint8)
+                img_mask = cv2.erode(img_mask,kernel,iterations = 1)
+                kernel_size = (20, 20)
+                blur_size = tuple(2*i+1 for i in kernel_size)
+                img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
+
+                img_mask /= 255
+
+                img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
+
+                target_image = np.array(target_image, dtype=float)[..., ::-1] * 255
+
+                img1 = np.array(attr_img_ori, dtype=float)
+                img1 = img_mask * target_image + (1-img_mask) * img1
+                final_img = img1.astype(np.uint8)
+                attr_basename = os.path.splitext(os.path.basename(img))[0]
+                final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
+                final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
+                save_filename = os.path.join(save_dir,
+                            "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
+                            attr_basename,ckp_step,version))
+
+                cv2.imwrite(save_filename, final_img)
+
+                save_filename = os.path.join(save_dir,
+                            "id_%s--attr_%s_ckp_%s_v_%s_mask.png"%(id_basename,
+                            attr_basename,ckp_step,version))
+                cv2.imwrite(save_filename,mask)
+
+            average_cos /= len(imgs_list)
+            elapsed = time.time() - start_time
+            elapsed = str(datetime.timedelta(seconds=elapsed))
+            print("Elapsed [{}]".format(elapsed))
+            print("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
+            self.reporter.writeInfo("Average cosine similarity between ID and results [{}]".format(average_cos.item()))
\ No newline at end of file
diff --git a/test_scripts/tester_video.py b/test_scripts/tester_video.py
index bfe6627..08f188e 100644
--- a/test_scripts/tester_video.py
+++ b/test_scripts/tester_video.py
@@ -5,7 +5,7 @@
 # Created Date: Saturday July 3rd 2021
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Friday, 21st January 2022 11:06:37 am
+# Last Modified: Friday, 22nd April 2022 11:20:19 am
 # Modified By: Chen Xuanhong
 # Copyright (c) 2021 Shanghai Jiao Tong University
 #############################################################
@@ -33,6 +33,8 @@
 from utilities.ImagenetNorm import ImagenetNorm
 from parsing_model.model import BiSeNet
 from insightface_func.face_detect_crop_single import Face_detect_crop
 from utilities.reverse2original import reverse2wholeimage
+from face_enhancer.gfpgan import GFPGANer
+from utilities.utilities import load_file_from_url
 
 class Tester(object):
     def __init__(self, config, reporter):
@@ -64,6 +66,7 @@
     def video_swap(
                 self,
                 video_path,
+                gfpgan,
                 id_vetor,
                 save_path,
                 temp_results_dir='./temp_results',
@@ -121,8 +124,11 @@
                     swap_result_list = []
                     frame_align_crop_tenor_list = []
                     for frame_align_crop in frame_align_crop_list:
+                        if gfpgan:
+                            _, _, frame_align_crop = gfpgan.enhance(
+                                frame_align_crop, has_aligned=False, only_center_face=True, paste_back=True)
                         frame_align_crop_tenor = self.cv2totensor(frame_align_crop)
-                        swap_result = self.network(frame_align_crop_tenor, id_vetor)[0]
+                        swap_result = self.network(frame_align_crop_tenor, id_vetor)[0][0]
                         swap_result = swap_result* self.imagenet_std + self.imagenet_mean
                         swap_result = torch.clip(swap_result,0.0,1.0)
                         cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
@@ -216,6 +222,39 @@
         # models
         self.__init_framework__()
 
+        if self.config["preprocess"]:
+            print("Employ GFPGAN to upsample detected face images!")
+            gfpgan_version = '1.2'
+            if gfpgan_version == '1':
+                arch = 'original'
+                channel_multiplier = 1
+                model_name = 'GFPGANv1'
+            elif gfpgan_version == '1.2':
+                arch = 'clean'
+                channel_multiplier = 2
+                model_name = 'GFPGANCleanv1-NoCE-C2'
+            elif gfpgan_version == '1.3':
+                arch = 'clean'
+                channel_multiplier = 2
+                model_name = 'GFPGANv1.3'
+
+            # determine model paths
+            model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
+            url_path = "https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth"
+            if not os.path.isfile(model_path):
+                # raise ValueError(f'Model {model_name} does not exist.')
+                print(f'Model {model_name} does not exist. Downloading it now......')
+                model_path = load_file_from_url(
+                    url=url_path, model_dir=model_path, progress=True, file_name=None)
+            restorer = GFPGANer(
+                model_path=model_path,
+                upscale=1,
+                arch=arch,
+                channel_multiplier=channel_multiplier,
+                bg_upsampler=None)
+        else:
+            restorer = None
+
         mode = None
@@ -239,7 +278,7 @@
         start_time = time.time()
         self.network.eval()
         with torch.no_grad():
-            self.video_swap(attr_files, latend_id, save_dir, temp_results_dir="./.temples",\
+            self.video_swap(attr_files, restorer, latend_id, save_dir, temp_results_dir="./.temples",\
                 use_mask=False,crop_size=512)
 
         elapsed = time.time() - start_time
diff --git a/test_scripts/tester_video_gfpgan.py b/test_scripts/tester_video_gfpgan.py
new file mode 100644
index 0000000..c941d70
--- /dev/null
+++ b/test_scripts/tester_video_gfpgan.py
@@ -0,0 +1,285 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: tester_video_gfpgan.py
+# Created Date: Saturday July 3rd 2021
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Thursday, 14th April 2022 11:40:45 am
+# Modified By: Chen Xuanhong
+# Copyright (c) 2021 Shanghai Jiao Tong University
+#############################################################
+
+
+
+import os
+import cv2
+import time
+import shutil
+
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+
+from moviepy.editor import AudioFileClip, VideoFileClip
+from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
+
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+import glob
+
+from utilities.ImagenetNorm import ImagenetNorm
+from parsing_model.model import BiSeNet
+from insightface_func.face_detect_crop_single import Face_detect_crop
+from utilities.reverse2original import reverse2wholeimage
+from face_enhancer.gfpgan import GFPGANer
+from utilities.utilities import load_file_from_url
+
+class Tester(object):
+    def __init__(self, config, reporter):
+
+        self.config = config
+        # logger
+        self.reporter = reporter
+
+        self.transformer_Arcface = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ])
+        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
+        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
+
+    def cv2totensor(self, cv2_img):
+        """
+        cv2_img: an image read by cv2, H*W*C
+        return: a 1*C*H*W tensor
+        """
+        cv2_img = cv2.cvtColor(cv2_img,cv2.COLOR_BGR2RGB)
+        cv2_img = torch.from_numpy(cv2_img)
+        cv2_img = cv2_img.permute(2,0,1).cuda()
+        temp = cv2_img / 255.0
+        temp -= self.imagenet_mean
+        temp /= self.imagenet_std
+        return temp.unsqueeze(0)
+
+    def video_swap(
+                self,
+                video_path,
+                gfpgan,
+                id_vetor,
+                save_path,
+                temp_results_dir='./temp_results',
+                crop_size=512,
+                use_mask =False
+                ):
+
+        video_forcheck = VideoFileClip(video_path)
+        if video_forcheck.audio is None:
+            no_audio = True
+        else:
+            no_audio = False
+
+        del video_forcheck
+
+        if not no_audio:
+            video_audio_clip = AudioFileClip(video_path)
+
+        video = cv2.VideoCapture(video_path)
+        ret = True
+        frame_index = 0
+
+        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+
+        # video_WIDTH = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
+
+        # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+        fps = video.get(cv2.CAP_PROP_FPS)
+        if os.path.exists(temp_results_dir):
+            shutil.rmtree(temp_results_dir)
+        spNorm =ImagenetNorm()
+        if use_mask:
+            n_classes = 19
+            net = BiSeNet(n_classes=n_classes)
+            net.cuda()
+            save_pth = os.path.join('./parsing_model', '79999_iter.pth')
+            net.load_state_dict(torch.load(save_pth))
+            net.eval()
+        else:
+            net =None
+
+        # while ret:
+        for frame_index in tqdm(range(frame_count)):
+            ret, frame = video.read()
+            if ret:
+                detect_results = self.detect.get(frame,crop_size)
+
+                if detect_results is not None:
+                    # print(frame_index)
+                    if not os.path.exists(temp_results_dir):
+                        os.mkdir(temp_results_dir)
+                    frame_align_crop_list = detect_results[0]
+                    frame_mat_list = detect_results[1]
+                    swap_result_list = []
+                    frame_align_crop_tenor_list = []
+                    for frame_align_crop in frame_align_crop_list:
+                        _, _, restored_face = gfpgan.enhance(
+                            frame_align_crop, has_aligned=False, only_center_face=True, paste_back=True)
+                        frame_align_crop_tenor = self.cv2totensor(restored_face)
+                        swap_result = self.network(frame_align_crop_tenor, id_vetor)[0][0]
+                        swap_result = swap_result* self.imagenet_std + self.imagenet_mean
+                        swap_result = torch.clip(swap_result,0.0,1.0)
+                        cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
+                        swap_result_list.append(swap_result)
+                        frame_align_crop_tenor_list.append(frame_align_crop_tenor)
+                    reverse2wholeimage(frame_align_crop_tenor_list,swap_result_list, frame_mat_list, crop_size, frame,\
+                        os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),pasring_model =net,use_mask=use_mask, norm = spNorm)
+
+                else:
+                    if not os.path.exists(temp_results_dir):
+                        os.mkdir(temp_results_dir)
+                    frame = frame.astype(np.uint8)
+                    cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
+            else:
+                break
+
+        video.release()
+
+        # image_filename_list = []
+        path = os.path.join(temp_results_dir,'*.jpg')
+        image_filenames = sorted(glob.glob(path))
+
+        clips = ImageSequenceClip(image_filenames,fps = fps)
+
+        if not no_audio:
+            clips = clips.set_audio(video_audio_clip)
+        basename = os.path.basename(video_path)
+        basename = os.path.splitext(basename)[0]
+        save_filename = os.path.join(save_path, basename+".mp4")
+        index = 0
+        while True:
+            if os.path.exists(save_filename):
+                save_filename = os.path.join(save_path, basename+"_%d.mp4"%index)
+                index += 1
+            else:
+                break
+        clips.write_videofile(save_filename,audio_codec='aac')
+
+
+    def __init_framework__(self):
+        '''
+        This function is designed to define the framework,
+        and print the framework information into the log file
+        '''
+        #===============build models================#
+        print("build models...")
+        # TODO [import models here]
+        model_config = self.config["model_configs"]
+        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
+        class_name = model_config["g_model"]["class_name"]
+        package = __import__(gscript_name, fromlist=True)
+        gen_class = getattr(package, class_name)
+
+        # TODO replace the lines below to define the model framework
+        self.network = gen_class(**model_config["g_model"]["module_params"])
+        self.network = self.network.eval()
+        # print and record model structure
+        self.reporter.writeInfo("Model structure:")
+        self.reporter.writeModel(self.network.__str__())
+
+        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
+        self.arcface = arcface1['model'].module
+        self.arcface.eval()
+        self.arcface.requires_grad_(False)
+
+        # run models on GPU
+        if self.config["cuda"] >=0:
+            self.network = self.network.cuda()
+            self.arcface = self.arcface.cuda()
+        # loader1 = torch.load(self.config["ckp_name"]["generator_name"])
+        # print(loader1.key())
+        # pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
+        model_path = os.path.join(self.config["project_checkpoints"],
+                                    "step%d_%s.pth"%(self.config["checkpoint_step"],
+                                    self.config["checkpoint_names"]["generator_name"]))
+        self.network.load_state_dict(torch.load(model_path))
+        # self.network.load_state_dict(torch.load(pathwocao))
+        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))
+
+    def test(self):
+
+        # save_result = self.config["saveTestResult"]
+        save_dir = self.config["test_samples_path"]
+        ckp_step = self.config["checkpoint_step"]
+        version = self.config["version"]
+        id_imgs = self.config["id_imgs"]
+        attr_files = self.config["attr_files"]
+        self.arcface_ckpt= self.config["arcface_ckpt"]
+
+        # models
+        self.__init_framework__()
+        gfpgan_version = '1.2'
+        if gfpgan_version == '1':
+            arch = 'original'
+            channel_multiplier = 1
+            model_name = 'GFPGANv1'
+        elif gfpgan_version == '1.2':
+            arch = 'clean'
+            channel_multiplier = 2
+            model_name = 'GFPGANCleanv1-NoCE-C2'
+        elif gfpgan_version == '1.3':
+            arch = 'clean'
+            channel_multiplier = 2
+            model_name = 'GFPGANv1.3'
+
+        # determine model paths
+        model_path = os.path.join('./face_enhancer/experiments/pretrained_models', model_name + '.pth')
+        url_path = "https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth"
+
+        if not os.path.isfile(model_path):
+            # raise ValueError(f'Model {model_name} does not exist.')
+            print(f'Model {model_name} does not exist. Downloading it now......')
+            model_path = load_file_from_url(
+                url=url_path, model_dir=model_path, progress=True, file_name=None)
+
+        restorer = GFPGANer(
+            model_path=model_path,
+            upscale=1,
+            arch=arch,
+            channel_multiplier=channel_multiplier,
+            bg_upsampler=None)
+
+
+
+        mode = None
+        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
+        self.detect.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
+
+        id_img = cv2.imread(id_imgs)
+        id_img_align_crop, _ = self.detect.get(id_img,512)
+        # _, _, restored_face = restorer.enhance(
+        #     id_img_align_crop[0], has_aligned=False, only_center_face=True, paste_back=True)
+        # cv2.imwrite("id_wocao.png",restored_face)
+        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
+        id_img = self.transformer_Arcface(id_img_align_crop_pil)
+        id_img = id_img.unsqueeze(0).cuda()
+
+        #create latent id
+        id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
+        latend_id = self.arcface(id_img)
+        latend_id = F.normalize(latend_id, p=2, dim=1)
+        # Start time
+        import datetime
+        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
+        print('Start =================================== test...')
+        start_time = time.time()
+        self.network.eval()
+        with torch.no_grad():
+            self.video_swap(attr_files, restorer, latend_id, save_dir, temp_results_dir="./.temples",\
+                use_mask=False,crop_size=512)
+
+        elapsed = time.time() - start_time
+        elapsed = str(datetime.timedelta(seconds=elapsed))
+        print("Elapsed [{}]".format(elapsed))
\ No newline at end of file
diff --git a/train_multigpu.py b/train_multigpu.py
index 792c435..5d98b15 100644
--- a/train_multigpu.py
+++ b/train_multigpu.py
@@ -5,7 +5,7 @@
 # Created Date: Tuesday April 28th 2020
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Thursday, 24th March 2022 2:14:07 pm
+# Last Modified: Tuesday, 19th April 2022 6:58:59 pm
 # Modified By: Chen Xuanhong
 # Copyright (c) 2020 Shanghai Jiao Tong University
 #############################################################
@@ -31,9 +31,9 @@
 def getParameters():
     parser = argparse.ArgumentParser()
     # general settings
-    parser.add_argument('-v', '--version', type=str, default='cycle_res3',
+    parser.add_argument('-v', '--version', type=str, default='2maskloss256_1',
                             help="version name for train, test, finetune")
-    parser.add_argument('-t', '--tag', type=str, default='cycle',
+    parser.add_argument('-t', '--tag', type=str, default='256',
                             help="tag for current experiment")
     parser.add_argument('-p', '--phase', type=str, default="train",
@@ -41,14 +41,16 @@
                             help="The phase of current project")
     parser.add_argument('-c', '--gpus', type=int, nargs='+', default=[0,1,2,3]) # <0 if it is set as -1, program will use CPU
-    parser.add_argument('-e', '--ckpt', type=int, default=74,
+    parser.add_argument('-e', '--ckpt', type=int, default=10000,
                             help="checkpoint epoch for test phase or finetune phase")
 
     # training
     parser.add_argument('--experiment_description', type=str,
-                            default="cycle loss with a residual decoder; switched to the StarGAN v2 generator structure")
+                            default="Use a 128x128 mask together with a 256x256 mask, and raise the mask loss weight to 100 so the predicted mask comes out reasonably complete. I still believe the mask head should branch off the encoder output: taking it from the decoder is logically problematic, because the mask would then share features with the decoder, and if the identity-swapped features change the face shape substantially, the mask would change along with them. The mask should reflect the target image fed into the encoder, so it ought to use encoder information. Previously, a head branched from the encoder produced masks with large holes; increasing the mask weight should alleviate that. Also switch the last two generator layers back to AdaIN injection.")
 
-    parser.add_argument('--train_yaml', type=str, default="train_cycleloss_res.yaml")
+    parser.add_argument('--train_yaml', type=str, default="train_2maskhead_256.yaml")
 
     # system logger
     parser.add_argument('--logger', type=str,
diff --git a/train_scripts/trainer_mgpu_2maskloss.py b/train_scripts/trainer_mgpu_2maskloss.py
new file mode 100644
index 0000000..b1c80c6
--- /dev/null
+++ b/train_scripts/trainer_mgpu_2maskloss.py
@@ -0,0 +1,588 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: trainer_mgpu_2maskloss.py
+# Created Date: Sunday January 9th 2022
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Tuesday, 12th April 2022 1:51:44 am
+# Modified By: Chen Xuanhong
+# Copyright (c) 2022 Shanghai Jiao Tong University
+#############################################################
+
+import os
+import time
+import random
+import shutil
+import tempfile
+
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+
+from torch_utils import misc
+from torch_utils import training_stats
+from torch_utils.ops import conv2d_gradfix
+from torch_utils.ops import grid_sample_gradfix
+
+from utilities.plot import plot_batch
+from train_scripts.trainer_multigpu_base import TrainerBase
+
+
+class Trainer(TrainerBase):
+
+    def __init__(self,
+                config,
+                reporter):
+        super(Trainer, self).__init__(config, reporter)
+
+        import inspect
+        print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
+
+    def train(self):
+        # Launch processes.
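+        # torch.multiprocessing.spawn runs one copy of train_loop per GPU and
+        # passes the worker's integer rank as the implicit first argument,
+        # i.e. effectively train_loop(i, config, reporter, temp_dir) for
+        # i in range(num_gpus).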
+        num_gpus = len(self.config["gpus"])
+        print('Launching processes...')
+        torch.multiprocessing.set_start_method('spawn')
+        with tempfile.TemporaryDirectory() as temp_dir:
+            torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus)
+
+# TODO modify this function to build your models
+def init_framework(config, reporter, device, rank):
+    '''
+    This function is designed to define the framework,
+    and print the framework information into the log file
+    '''
+    #===============build models================#
+    print("build models...")
+    # TODO [import models here]
+    torch.cuda.set_device(rank)
+    torch.cuda.empty_cache()
+    model_config = config["model_configs"]
+
+    if config["phase"] == "train":
+        gscript_name = "components." + model_config["g_model"]["script"]
+        file1 = os.path.join("components", model_config["g_model"]["script"]+".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py")
+        shutil.copyfile(file1,tgtfile1)
+        dscript_name = "components." + model_config["d_model"]["script"]
+        file1 = os.path.join("components", model_config["d_model"]["script"]+".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py")
+        shutil.copyfile(file1,tgtfile1)
+
+    elif config["phase"] == "finetune":
+        gscript_name = config["com_base"] + model_config["g_model"]["script"]
+        dscript_name = config["com_base"] + model_config["d_model"]["script"]
+
+    class_name = model_config["g_model"]["class_name"]
+    package = __import__(gscript_name, fromlist=True)
+    gen_class = getattr(package, class_name)
+    gen = gen_class(**model_config["g_model"]["module_params"])
+
+    # print and record model structure
+    reporter.writeInfo("Generator structure:")
+    reporter.writeModel(gen.__str__())
+
+    class_name = model_config["d_model"]["class_name"]
+    package = __import__(dscript_name, fromlist=True)
+    dis_class = getattr(package, class_name)
+    dis = dis_class(**model_config["d_model"]["module_params"])
+
+
+    # print and record model structure
+    reporter.writeInfo("Discriminator structure:")
+    reporter.writeModel(dis.__str__())
+
+    # arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
+    # arcface = arcface1['model'].module
+
+    # arcface = iresnet100(pretrained=False, fp16=False)
+    # arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
+    # arcface.eval()
+    arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
+    arcface = arcface1['model'].module
+
+    # run on GPU
+
+    # if in finetune phase, load the pretrained checkpoint
+    if config["phase"] == "finetune":
+        model_path = os.path.join(config["project_checkpoints"],
+                                "step%d_%s.pth"%(config["ckpt"],
+                                config["checkpoint_names"]["generator_name"]))
+        gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        model_path = os.path.join(config["project_checkpoints"],
+                                "step%d_%s.pth"%(config["ckpt"],
+                                config["checkpoint_names"]["discriminator_name"]))
+        dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        print('loaded trained backbone model step {}...!'.format(config["ckpt"]))
+
+    gen = gen.to(device)
+    dis = dis.to(device)
+    arcface= arcface.to(device)
+    arcface.requires_grad_(False)
+    arcface.eval()
+
+
+
+    return gen, dis, arcface
+
+# TODO modify this function to configurate the optimizer of your pipeline
+def setup_optimizers(config, reporter, gen, dis, rank):
+
+    torch.cuda.set_device(rank)
+    torch.cuda.empty_cache()
+    g_train_opt = config['g_optim_config']
+    d_train_opt = config['d_optim_config']
+
+    g_optim_params = []
+    d_optim_params = []
+    for k, v in gen.named_parameters():
+        if v.requires_grad:
+            g_optim_params.append(v)
+        else:
+            reporter.writeInfo(f'Params {k} will not be optimized.')
+            print(f'Params {k} will not be optimized.')
+
+    for k, v in dis.named_parameters():
+        if v.requires_grad:
+            d_optim_params.append(v)
+        else:
+            reporter.writeInfo(f'Params {k} will not be optimized.')
+            print(f'Params {k} will not be optimized.')
+
+    optim_type = config['optim_type']
+
+    if optim_type == 'Adam':
+        g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
+        d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
+    else:
+        raise NotImplementedError(
+            f'optimizer {optim_type} is not supported yet.')
+    # self.optimizers.append(self.optimizer_g)
+    if config["phase"] == "finetune":
+        opt_path = os.path.join(config["project_checkpoints"],
+                                "step%d_optim_%s.pth"%(config["ckpt"],
+                                config["optimizer_names"]["generator_name"]))
+        g_optimizer.load_state_dict(torch.load(opt_path))
+
+        opt_path = os.path.join(config["project_checkpoints"],
+                                "step%d_optim_%s.pth"%(config["ckpt"],
+                                config["optimizer_names"]["discriminator_name"]))
+        d_optimizer.load_state_dict(torch.load(opt_path))
+
+        print('loaded trained optimizer step {}...!'.format(config["ckpt"]))
+    return g_optimizer, d_optimizer
+
+def d_logistic_loss(real_pred, fake_pred):
+    real_loss = F.softplus(-real_pred)
+    fake_loss = F.softplus(fake_pred)
+
+    return real_loss.mean() + fake_loss.mean()
+
+
+def d_r1_loss(d_out, x_in):
+    # zero-centered gradient penalty for real images
+    batch_size = x_in.size(0)
+    grad_dout = torch.autograd.grad(
+        outputs=d_out.sum(), inputs=x_in,
+        create_graph=True, retain_graph=True, only_inputs=True
+    )[0]
+    grad_dout2 = grad_dout.pow(2)
+    assert(grad_dout2.size() == x_in.size())
+    reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
+    return reg
+
+
+def g_nonsaturating_loss(fake_pred):
+    loss = F.softplus(-fake_pred).mean()
+
+    return loss
+
+def requires_grad(model, flag=True):
+    for p in model.parameters():
+        p.requires_grad = flag
+
+
+# def r1_reg(d_out, x_in):
+#     # zero-centered gradient penalty for real images
+#     batch_size = x_in.size(0)
+#     grad_dout = torch.autograd.grad(
+#         outputs=d_out.sum(), inputs=x_in,
+#         create_graph=True, retain_graph=True, only_inputs=True
+#     )[0]
+#     grad_dout2 = grad_dout.pow(2)
+#     assert(grad_dout2.size() == x_in.size())
+#     reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
+#     return reg
+
+def train_loop(
+            rank,
+            config,
+            reporter,
+            temp_dir
+            ):
+
+    version = config["version"]
+
+    ckpt_dir = config["project_checkpoints"]
+    sample_dir = config["project_samples"]
+
+    log_freq = config["log_step"]
+    model_freq = config["model_save_step"]
+    sample_freq = config["sample_step"]
+    total_step = config["total_step"]
+    random_seed = config["dataset_params"]["random_seed"]
+    d_reg_freq = config["d_reg_freq"]
+
+    id_w = config["id_weight"]
+    rec_w = config["reconstruct_weight"]
+    rec_fm_w = config["rec_feature_match_weight"]
+    mask_w = config["mask_weight"]
+    cycle_w = config["cycle_weight"]
+    reg_w = config["reg_weight"]
+    num_gpus = len(config["gpus"])
+    batch_gpu = config["batch_size"] // num_gpus
+
+    init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+    if os.name == 'nt':
+        init_method = 'file:///' + init_file.replace('\\', '/')
+        torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus)
+    else:
+        init_method = f'file://{init_file}'
+        torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus)
+
+    # Init torch_utils.
+    sync_device = torch.device('cuda', rank)
+    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+
+
+
+    if rank == 0:
+        img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
+        img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)
+
+
+    # Initialize.
+    device = torch.device('cuda', rank)
+    np.random.seed(random_seed * num_gpus + rank)
+    torch.manual_seed(random_seed * num_gpus + rank)
+    torch.backends.cuda.matmul.allow_tf32 = False  # Improves numerical accuracy.
+    torch.backends.cudnn.allow_tf32 = False        # Improves numerical accuracy.
+    conv2d_gradfix.enabled = True                  # Improves training speed.
+    grid_sample_gradfix.enabled = True             # Avoids errors with the augmentation pipe.
+
+    # Create dataloader.
+    if rank == 0:
+        print('Loading training set...')
+
+    dataset = config["dataset_paths"][config["dataset_name"]]
+    #================================================#
+    print("Prepare the train dataloader...")
+    dlModulename = config["dataloader"]
+    package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
+    dataloaderClass = getattr(package, 'GetLoader')
+    dataloader_class= dataloaderClass
+    dataloader = dataloader_class(dataset,
+                                rank,
+                                num_gpus,
+                                batch_gpu,
+                                **config["dataset_params"])
+
+    # Construct networks.
+    if rank == 0:
+        print('Constructing networks...')
+    gen, dis, arcface = init_framework(config, reporter, device, rank)
+
+    # Check for existing checkpoint
+
+    # Print network summary tables.
+    # if rank == 0:
+    #     attr = torch.empty([batch_gpu, 3, 512, 512], device=device)
+    #     id = torch.empty([batch_gpu, 3, 112, 112], device=device)
+    #     latent = misc.print_module_summary(arcface, [id])
+    #     img = misc.print_module_summary(gen, [attr, latent])
+    #     misc.print_module_summary(dis, [img, None])
+    #     del attr
+    #     del id
+    #     del latent
+    #     del img
+    #     torch.cuda.empty_cache()
+
+
+    # Distribute across GPUs.
+    if rank == 0:
+        print(f'Distributing across {num_gpus} GPUs...')
+    for module in [gen, dis, arcface]:
+        if module is not None and num_gpus > 1:
+            for param in misc.params_and_buffers(module):
+                torch.distributed.broadcast(param, src=0)
+
+    # Setup training phases.
+    if rank == 0:
+        print('Setting up training phases...')
+    #===============build losses===================#
+    # TODO replace below lines to build your losses
+    # MSE_loss = torch.nn.MSELoss()
+    l1_loss = torch.nn.L1Loss()
+    cos_loss = torch.nn.CosineSimilarity()
+
+    g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank)
+
+    # Initialize logs.
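+    # Gradient synchronisation in the training loop below is done by hand:
+    # after each backward pass the per-parameter grads are flattened,
+    # all-reduced and averaged over the world size. A sketch of the built-in
+    # alternative (not what this file does) would wrap the models instead:
+    #     gen = torch.nn.parallel.DistributedDataParallel(gen, device_ids=[rank])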
+    if rank == 0:
+        print('Initializing logs...')
+        #==============build tensorboard=================#
+        if config["logger"] == "tensorboard":
+            import torch.utils.tensorboard as tensorboard
+            tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"])
+            logger = tensorboard_writer
+
+        elif config["logger"] == "wandb":
+            import wandb
+            wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
+                        tags=[config["tag"]], name=version)
+
+            wandb.config = {
+                "total_step": config["total_step"],
+                "batch_size": config["batch_size"]
+            }
+            logger = wandb
+
+
+    random.seed(random_seed)
+    randindex = [i for i in range(batch_gpu)]
+
+    # set the start point for training loop
+    if config["phase"] == "finetune":
+        start = config["ckpt"]
+    else:
+        start = 0
+    if rank == 0:
+        import datetime
+        start_time = time.time()
+
+        # Print the total step number
+        print("Total step = %d"%total_step)
+
+        print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
+
+        from utilities.logo_class import logo_class
+        logo_class.print_start_training()
+
+    for step in range(start, total_step):
+        gen.train()
+        dis.train()
+
+        for interval in range(2):
+
+            src_image1, src_image2, mask_label = dataloader.next()
+
+            if step%2 == 0:
+                img_id = src_image2
+            else:
+                random.shuffle(randindex)
+                img_id = src_image2[randindex]
+                mask_label = mask_label[randindex]
+
+            img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
+            latent_id = arcface(img_id_112)
+            latent_id = F.normalize(latent_id, p=2, dim=1)
+
+            if interval == 0:
+
+                requires_grad(dis, True)
+                requires_grad(gen, False)
+
+                d_regularize = step % d_reg_freq == 0
+                if d_regularize:
+                    src_image1.requires_grad_()
+
+                real_logits,_ = dis(src_image1)
+                with torch.no_grad():
+                    img_fake,_,_ = gen(src_image1, latent_id.detach())
+                fake_logits,_ = dis(img_fake.detach())
+
+                loss_D = d_logistic_loss(real_logits, fake_logits)
+
+                if d_regularize:
+                    loss_reg = d_r1_loss(real_logits, src_image1)
+                    loss_D += loss_reg * reg_w * d_reg_freq
+
+                d_optimizer.zero_grad(set_to_none=True)
+                loss_D.backward()
+
+                with torch.autograd.profiler.record_function('discriminator_opt'):
+                    # params = [param for param in dis.parameters() if param.grad is not None]
+                    # if len(params) > 0:
+                    #     flat = torch.cat([param.grad.flatten() for param in params])
+                    #     if num_gpus > 1:
+                    #         torch.distributed.all_reduce(flat)
+                    #         flat /= num_gpus
+                    #     misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
+                    #     grads = flat.split([param.numel() for param in params])
+                    #     for param, grad in zip(params, grads):
+                    #         param.grad = grad.reshape(param.shape)
+                    params = [param for param in dis.parameters() if param.grad is not None]
+                    flat = torch.cat([param.grad.flatten() for param in params])
+                    torch.distributed.all_reduce(flat)
+                    flat /= num_gpus
+                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
+                    grads = flat.split([param.numel() for param in params])
+                    for param, grad in zip(params, grads):
+                        param.grad = grad.reshape(param.shape)
+                d_optimizer.step()
+
+
+            #================================Generator interval======================================#
+            else:
+                requires_grad(dis, False)
+                requires_grad(gen, True)
+                # model.netD.requires_grad_(True)
+                img_fake,lr_mask,hr_mask= gen(src_image1, latent_id.detach())
+                # G loss
+                gen_logits,fake_feat= dis(img_fake)
+                # real_feat = dis.get_feature(src_image1)
+                loss_Gmain = g_nonsaturating_loss(gen_logits)
+                img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
+                latent_fake = arcface(img_fake_down)
+                latent_fake = F.normalize(latent_fake, p=2, dim=1)
+                loss_G_ID = (1 - cos_loss(latent_fake, latent_id.detach())).mean()
+                mask_label_lr = F.interpolate(mask_label, size=(128,128), mode='bicubic')
+                loss_mask = l1_loss(lr_mask, mask_label_lr) + l1_loss(hr_mask, mask_label)
+                loss_G = loss_Gmain + loss_G_ID * id_w + loss_mask * mask_w
+                if step%2 == 0:
+                    #G_Rec
+                    real_feat = dis.get_feature(src_image1)
+                    rec_fm = l1_loss(fake_feat, real_feat)
+                    loss_G_Rec = l1_loss(img_fake, src_image1)
+                    # lpips_loss = loss_fn_vgg(img_fake, src_image1).mean()
+                    loss_G += (loss_G_Rec * rec_w + rec_fm_w * rec_fm) #+ rec_fm * rec_fm_w
+                else:
+                    source1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic')
+                    latent_source1 = arcface(source1_down)
+                    latent_source1 = F.normalize(latent_source1, p=2, dim=1)
+                    cycle_src,_,_ = gen(img_fake, latent_source1)
+                    cycle_loss = l1_loss(src_image1,cycle_src)
+                    # cycle_feat = dis.get_feature(cycle_src)
+                    # cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) + l1_loss(real_feat["2"],cycle_feat["2"])
+                    loss_G += cycle_loss * cycle_w #+ cycle_fm * cycle_fm_w
+
+
+                g_optimizer.zero_grad(set_to_none=True)
+                loss_G.backward()
+                with torch.autograd.profiler.record_function('generator_opt'):
+                    params = [param for param in gen.parameters() if param.grad is not None]
+                    flat = torch.cat([param.grad.flatten() for param in params])
+                    torch.distributed.all_reduce(flat)
+                    flat /= num_gpus
+                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
+                    grads = flat.split([param.numel() for param in params])
+                    for param, grad in zip(params, grads):
+                        param.grad = grad.reshape(param.shape)
+                g_optimizer.step()
+
+        # Print out log info
+        if rank == 0 and (step + 1) % log_freq == 0:
+            elapsed = time.time() - start_time
+            elapsed = str(datetime.timedelta(seconds=elapsed))
+            # epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \
+            #     rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
+            #     format(version, elapsed, step, total_step, \
+            #     loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \
+            #     rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
+            epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_loss: {:.4f}, rec_fm: {:.4f}, loss_mask: {:.4f}, D_loss: {:.4f}, D_R1: {:.4f}". \
+                format(version, elapsed, step, total_step, \
+                loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_loss.item(), rec_fm.item(), loss_mask.item(), loss_D.item(), loss_reg.item())
+            print(epochinformation)
+            reporter.writeInfo(epochinformation)
+
+            if config["logger"] == "tensorboard":
+                logger.add_scalar('G/G_loss', loss_G.item(), step)
+                logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
+                logger.add_scalar('G/cycle_loss', cycle_loss.item(), step)
+                # logger.add_scalar('G/cycle_fm', cycle_fm.item(), step)
+                logger.add_scalar('G/rec_fm', rec_fm.item(), step)
+                logger.add_scalar('G/loss_mask', loss_mask.item(), step)
+                logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
+                logger.add_scalar('D/D_loss', loss_D.item(), step)
+                logger.add_scalar('D/D_reg', loss_reg.item(), step)
+            elif config["logger"] == "wandb":
+                logger.log({"G_Loss": loss_G.item()}, step = step)
+                logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
+                logger.log({"cycle_loss": cycle_loss.item()}, step = step)
+                # logger.log({"cycle_fm": cycle_fm.item()}, step = step)
+                logger.log({"rec_fm": rec_fm.item()}, step = step)
+                logger.log({"loss_mask": loss_mask.item()}, step = step)
+                logger.log({"G_ID": loss_G_ID.item()}, step = step)
+                logger.log({"D_loss": loss_D.item()}, step = step)
+                logger.log({"D_reg": loss_reg.item()}, step = step)
+        torch.cuda.empty_cache()
+
+        if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0):
+            gen.eval()
+            with torch.no_grad():
+                imgs = []
+                zero_img = (torch.zeros_like(src_image1[0,...]))
+                imgs.append(zero_img.cpu().numpy())
+                save_img = ((src_image1.cpu())* img_std + img_mean).numpy()
+                for r in range(batch_gpu):
+                    imgs.append(save_img[r,...])
+                arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
+                id_vector_src1 = arcface(arcface_112)
+                id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
+
+                for i in range(batch_gpu):
+
+                    imgs.append(save_img[i,...])
+                    image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
+                    img_fake,_,_ = gen(image_infer, id_vector_src1)
+
+                    img_fake = img_fake.cpu() * img_std
+                    img_fake = img_fake + img_mean
+                    img_fake = img_fake.numpy()
+                    # pred_mask = pred_mask.cpu().numpy() * 255
+                    for j in range(batch_gpu):
+                        imgs.append(img_fake[j,...])
+                print("Save test data")
+                imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
+                plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg'))
+            torch.cuda.empty_cache()
+
+
+
+        #===============adjust learning rate============#
+        # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
+        #     print("Learning rate decay")
+        #     for p in self.optimizer.param_groups:
+        #         p['lr'] *= self.config["lr_decay"]
+        #         print("Current learning rate is %f"%p['lr'])
+
+        #===============save checkpoints================#
+        if rank == 0 and (step+1) % model_freq==0:
+
+            torch.save(gen.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
+                                config["checkpoint_names"]["generator_name"])))
+            torch.save(dis.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
+                                config["checkpoint_names"]["discriminator_name"])))
+
+            # keep optimizer filenames in sync with the loader in setup_optimizers
+            torch.save(g_optimizer.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(step + 1,
+                                config["optimizer_names"]["generator_name"])))
+
+            torch.save(d_optimizer.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(step + 1,
+                                config["optimizer_names"]["discriminator_name"])))
+
+            print("Save step %d model checkpoint!"%(step+1))
+            torch.cuda.empty_cache()
+    print("Rank %d process done!"%rank)
+    torch.distributed.barrier()
\ No newline at end of file
diff --git a/train_scripts/trainer_mgpu_2maskloss_256.py b/train_scripts/trainer_mgpu_2maskloss_256.py
new file mode 100644
index 0000000..08be608
--- /dev/null
+++ b/train_scripts/trainer_mgpu_2maskloss_256.py
@@ -0,0 +1,592 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+#############################################################
+# File: trainer_mgpu_2maskloss_256.py
+# Created Date: Sunday January 9th 2022
+# Author: Chen Xuanhong
+# Email: chenxuanhongzju@outlook.com
+# Last Modified: Tuesday, 19th April 2022 6:57:10 pm
+# Modified By: Chen Xuanhong
+# Copyright (c) 2022 Shanghai Jiao Tong University
+#############################################################
+
+import os
+import time
+import random
+import shutil
+import tempfile
+
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+
+from torch_utils import misc
+from torch_utils import training_stats
+from torch_utils.ops import conv2d_gradfix
+from torch_utils.ops import grid_sample_gradfix
+
+from utilities.plot import plot_batch
+from train_scripts.trainer_multigpu_base import TrainerBase
+
+
+class Trainer(TrainerBase):
+
+    def __init__(self,
+                config,
+                reporter):
+        super(Trainer, self).__init__(config, reporter)
+
+        import inspect
+        print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))
+
+    def train(self):
+        # Launch processes.
+        num_gpus = len(self.config["gpus"])
+        print('Launching processes...')
+        torch.multiprocessing.set_start_method('spawn')
+        with tempfile.TemporaryDirectory() as temp_dir:
+            torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus)
+
+# TODO modify this function to build your models
+def init_framework(config, reporter, device, rank):
+    '''
+    This function is designed to define the framework,
+    and print the framework information into the log file
+    '''
+    #===============build models================#
+    print("build models...")
+    # TODO [import models here]
+    torch.cuda.set_device(rank)
+    torch.cuda.empty_cache()
+    model_config = config["model_configs"]
+
+    if config["phase"] == "train":
+        gscript_name = "components." + model_config["g_model"]["script"]
+        file1 = os.path.join("components", model_config["g_model"]["script"]+".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py")
+        shutil.copyfile(file1,tgtfile1)
+        dscript_name = "components." + model_config["d_model"]["script"]
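+        # Both component scripts are copied into the project folder, so every
+        # run keeps a frozen snapshot of the exact generator/discriminator
+        # code it was trained with.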
+        file1 = os.path.join("components", model_config["d_model"]["script"]+".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py")
+        shutil.copyfile(file1,tgtfile1)
+
+    elif config["phase"] == "finetune":
+        gscript_name = config["com_base"] + model_config["g_model"]["script"]
+        dscript_name = config["com_base"] + model_config["d_model"]["script"]
+
+    class_name = model_config["g_model"]["class_name"]
+    package = __import__(gscript_name, fromlist=True)
+    gen_class = getattr(package, class_name)
+    gen = gen_class(**model_config["g_model"]["module_params"])
+
+    # print and record model structure
+    reporter.writeInfo("Generator structure:")
+    reporter.writeModel(gen.__str__())
+
+    class_name = model_config["d_model"]["class_name"]
+    package = __import__(dscript_name, fromlist=True)
+    dis_class = getattr(package, class_name)
+    dis = dis_class(**model_config["d_model"]["module_params"])
+
+
+    # print and record model structure
+    reporter.writeInfo("Discriminator structure:")
+    reporter.writeModel(dis.__str__())
+
+    # arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
+    # arcface = arcface1['model'].module
+
+    # arcface = iresnet100(pretrained=False, fp16=False)
+    # arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
+    # arcface.eval()
+    arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
+    arcface = arcface1['model'].module
+
+    # run on GPU
+
+    # if in finetune phase, load the pretrained checkpoint
+    if config["phase"] == "finetune":
+        model_path = os.path.join(config["project_checkpoints"],
+                                "step%d_%s.pth"%(config["ckpt"],
+                                config["checkpoint_names"]["generator_name"]))
+        gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        model_path = os.path.join(config["project_checkpoints"],
+                                "step%d_%s.pth"%(config["ckpt"],
+                                config["checkpoint_names"]["discriminator_name"]))
+        dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        print('loaded trained backbone model step {}...!'.format(config["ckpt"]))
+
+    gen = gen.to(device)
+    dis = dis.to(device)
+    arcface= arcface.to(device)
+    arcface.requires_grad_(False)
+    arcface.eval()
+
+
+
+    return gen, dis, arcface
+
+# TODO modify this function to configurate the optimizer of your pipeline
+def setup_optimizers(config, reporter, gen, dis, rank):
+
+    torch.cuda.set_device(rank)
+    torch.cuda.empty_cache()
+    g_train_opt = config['g_optim_config']
+    d_train_opt = config['d_optim_config']
+
+    g_optim_params = []
+    d_optim_params = []
+    for k, v in gen.named_parameters():
+        if v.requires_grad:
+            g_optim_params.append(v)
+        else:
+            reporter.writeInfo(f'Params {k} will not be optimized.')
+            print(f'Params {k} will not be optimized.')
+
+    for k, v in dis.named_parameters():
+        if v.requires_grad:
+            d_optim_params.append(v)
+        else:
+            reporter.writeInfo(f'Params {k} will not be optimized.')
+            print(f'Params {k} will not be optimized.')
+
+    optim_type = config['optim_type']
+
+    if optim_type == 'Adam':
+        g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
+        d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
+    else:
+        raise NotImplementedError(
+            f'optimizer {optim_type} is not supported yet.')
+    # self.optimizers.append(self.optimizer_g)
+    if config["phase"] == "finetune":
+        opt_path = os.path.join(config["project_checkpoints"],
+                                "step%d_optim_%s.pth"%(config["ckpt"],
+                                config["optimizer_names"]["generator_name"]))
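+        # When finetuning, the Adam optimizer states are restored from the
+        # same step as the model weights so training resumes where it left off.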
+ g_optimizer.load_state_dict(torch.load(opt_path)) + + opt_path = os.path.join(config["project_checkpoints"], + "step%d_optim_%s.pth"%(config["ckpt"], + config["optimizer_names"]["discriminator_name"])) + d_optimizer.load_state_dict(torch.load(opt_path)) + + print('loaded trained optimizer step {}...!'.format(config["project_checkpoints"])) + return g_optimizer, d_optimizer + +def d_logistic_loss(real_pred, fake_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + +def d_r1_loss(d_out, x_in): + # zero-centered gradient penalty for real images + batch_size = x_in.size(0) + grad_dout = torch.autograd.grad( + outputs=d_out.sum(), inputs=x_in, + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_dout2 = grad_dout.pow(2) + assert(grad_dout2.size() == x_in.size()) + reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) + return reg + + +def g_nonsaturating_loss(fake_pred): + loss = F.softplus(-fake_pred).mean() + + return loss + +def requires_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +# def r1_reg(d_out, x_in): +# # zero-centered gradient penalty for real images +# batch_size = x_in.size(0) +# grad_dout = torch.autograd.grad( +# outputs=d_out.sum(), inputs=x_in, +# create_graph=True, retain_graph=True, only_inputs=True +# )[0] +# grad_dout2 = grad_dout.pow(2) +# assert(grad_dout2.size() == x_in.size()) +# reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) +# return reg + +def train_loop( + rank, + config, + reporter, + temp_dir + ): + + version = config["version"] + + ckpt_dir = config["project_checkpoints"] + sample_dir = config["project_samples"] + + log_freq = config["log_step"] + model_freq = config["model_save_step"] + sample_freq = config["sample_step"] + total_step = config["total_step"] + random_seed = config["dataset_params"]["random_seed"] + d_reg_freq = config["d_reg_freq"] + + id_w = config["id_weight"] + rec_w = config["reconstruct_weight"] + rec_fm_w = config["rec_feature_match_weight"] + mask_w = config["mask_weight"] + cycle_w = config["cycle_weight"] + reg_w = config["reg_weight"] + num_gpus = len(config["gpus"]) + batch_gpu = config["batch_size"] // num_gpus + + init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + + + + if rank == 0: + img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + + # Initialize. + device = torch.device('cuda', rank) + np.random.seed(random_seed * num_gpus + rank) + torch.manual_seed(random_seed * num_gpus + rank) + torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. + torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. + conv2d_gradfix.enabled = True # Improves training speed. + grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. + + # Create dataloader. 
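# Sketch (illustrative, not part of the patch): the three losses defined above
# are the StyleGAN2-style logistic GAN losses. Since softplus(-x) equals
# -log(sigmoid(x)), on hypothetical logits they behave as follows:
#
#   import torch
#   import torch.nn.functional as F
#
#   real_pred = torch.tensor([ 2.0, 1.5])   # D outputs on real images
#   fake_pred = torch.tensor([-1.0, 0.5])   # D outputs on generated images
#   loss_d = F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()
#   loss_g = F.softplus(-fake_pred).mean()  # non-saturating generator loss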
+ if rank == 0: + print('Loading training set...') + + dataset = config["dataset_paths"][config["dataset_name"]] + #================================================# + print("Prepare the train dataloader...") + dlModulename = config["dataloader"] + package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) + dataloaderClass = getattr(package, 'GetLoader') + dataloader_class= dataloaderClass + dataloader = dataloader_class(dataset, + rank, + num_gpus, + batch_gpu, + **config["dataset_params"]) + + # Construct networks. + if rank == 0: + print('Constructing networks...') + gen, dis, arcface = init_framework(config, reporter, device, rank) + + # Check for existing checkpoint + + # Print network summary tables. + # if rank == 0: + # attr = torch.empty([batch_gpu, 3, 512, 512], device=device) + # id = torch.empty([batch_gpu, 3, 112, 112], device=device) + # latent = misc.print_module_summary(arcface, [id]) + # img = misc.print_module_summary(gen, [attr, latent]) + # misc.print_module_summary(dis, [img, None]) + # del attr + # del id + # del latent + # del img + # torch.cuda.empty_cache() + + + # Distribute across GPUs. + if rank == 0: + print(f'Distributing across {num_gpus} GPUs...') + for module in [gen, dis, arcface]: + if module is not None and num_gpus > 1: + for param in misc.params_and_buffers(module): + torch.distributed.broadcast(param, src=0) + + # Setup training phases. + if rank == 0: + print('Setting up training phases...') + #===============build losses===================# + # TODO replace below lines to build your losses + # MSE_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + cos_loss = torch.nn.CosineSimilarity() + + g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank) + + # Initialize logs. 
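# Sketch (illustrative, not part of the patch): the __import__(..., fromlist=True)
# calls above resolve "data_tools.data_loader_<name>.GetLoader" at run time.
# importlib expresses the same lookup more idiomatically:
#
#   import importlib
#
#   def load_dataloader_class(name):
#       module = importlib.import_module(f"data_tools.data_loader_{name}")
#       return getattr(module, "GetLoader")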
+ if rank == 0: + print('Initializing logs...') + #==============build tensorboard=================# + if config["logger"] == "tensorboard": + import torch.utils.tensorboard as tensorboard + tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"]) + logger = tensorboard_writer + + elif config["logger"] == "wandb": + import wandb + wandb.init(project="Simswap_HQ", entity="xhchen", notes="512", + tags=[config["tag"]], name=version) + + wandb.config = { + "total_step": config["total_step"], + "batch_size": config["batch_size"] + } + logger = wandb + + + random.seed(random_seed) + randindex = [i for i in range(batch_gpu)] + + # set the start point for training loop + if config["phase"] == "finetune": + start = config["ckpt"] + else: + start = 0 + if rank == 0: + import datetime + start_time = time.time() + + # Caculate the epoch number + print("Total step = %d"%total_step) + + print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + + from utilities.logo_class import logo_class + logo_class.print_start_training() + + for step in range(start, total_step): + gen.train() + dis.train() + + for interval in range(2): + + src_image1, src_image2, mask_label = dataloader.next() + + src_image1 = F.interpolate(src_image1,size=(256,256), mode='bicubic') + src_image2 = F.interpolate(src_image2,size=(256,256), mode='bicubic') + + if step%2 == 0: + img_id = src_image2 + else: + random.shuffle(randindex) + img_id = src_image2[randindex] + mask_label = mask_label[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + + if interval == 0: + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + + requires_grad(dis, True) + requires_grad(gen, False) + + d_regularize = step % d_reg_freq == 0 + if d_regularize: + src_image1.requires_grad_() + + real_logits,_ = dis(src_image1) + with torch.no_grad(): + img_fake,_,_ = gen(src_image1, latent_id.detach()) + fake_logits,_ = dis(img_fake.detach()) + + loss_D = d_logistic_loss(real_logits, fake_logits) + + if d_regularize: + loss_reg = d_r1_loss(real_logits, src_image1) + loss_D += loss_reg * reg_w * d_reg_freq + + d_optimizer.zero_grad(set_to_none=True) + loss_D.backward() + + with torch.autograd.profiler.record_function('discriminator_opt'): + # params = [param for param in dis.parameters() if param.grad is not None] + # if len(params) > 0: + # flat = torch.cat([param.grad.flatten() for param in params]) + # if num_gpus > 1: + # torch.distributed.all_reduce(flat) + # flat /= num_gpus + # misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + # grads = flat.split([param.numel() for param in params]) + # for param, grad in zip(params, grads): + # param.grad = grad.reshape(param.shape) + params = [param for param in dis.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + d_optimizer.step() + + + #================================Generator interval======================================# + else: + requires_grad(dis, False) + requires_grad(gen, True) + # model.netD.requires_grad_(True) + 
img_fake,lr_mask,hr_mask= gen(src_image1, latent_id.detach()) + # G loss + gen_logits,fake_feat= dis(img_fake) + # real_feat = dis.get_feature(src_image1) + loss_Gmain = g_nonsaturating_loss(gen_logits) + img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id.detach())).mean() + mask_label_lr = F.interpolate(mask_label, size=(64,64), mode='bicubic') + mask_label = F.interpolate(mask_label, size=(256,256), mode='bicubic') + loss_mask = l1_loss(lr_mask, mask_label_lr) + l1_loss(hr_mask, mask_label) + loss_G = loss_Gmain + loss_G_ID * id_w + loss_mask * mask_w + if step%2 == 0: + #G_Rec + real_feat = dis.get_feature(src_image1) + rec_fm = l1_loss(fake_feat, real_feat) + loss_G_Rec = l1_loss(img_fake, src_image1) + # lpips_loss = loss_fn_vgg(img_fake, src_image1).mean() + loss_G += (loss_G_Rec * rec_w + rec_fm_w * rec_fm) #+ rec_fm * rec_fm_w + else: + source1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic') + latent_source1 = arcface(source1_down) + latent_source1 = F.normalize(latent_source1, p=2, dim=1) + cycle_src,_,_ = gen(img_fake, latent_source1) + cycle_loss = l1_loss(src_image1,cycle_src) + # cycle_feat = dis.get_feature(cycle_src) + # cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) + l1_loss(real_feat["2"],cycle_feat["2"]) + loss_G += cycle_loss * cycle_w #+ cycle_fm * cycle_fm_w + + + g_optimizer.zero_grad(set_to_none=True) + loss_G.backward() + with torch.autograd.profiler.record_function('generator_opt'): + params = [param for param in gen.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + g_optimizer.step() + + # Print out log info + if rank == 0 and (step + 1) % log_freq == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + # epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \ + # rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + # format(version, elapsed, step, total_step, \ + # loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \ + # rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_loss: {:.4f}, rec_fm: {:.4f}, loss_mask: {:.4f}, D_loss: {:.4f}, D_R1: {:.4f}". 
\
+                format(version, elapsed, step, total_step, \
+                loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_loss.item(), rec_fm.item(), loss_mask.item(), loss_D.item(), loss_reg.item())
+            print(epochinformation)
+            reporter.writeInfo(epochinformation)
+
+            if config["logger"] == "tensorboard":
+                logger.add_scalar('G/G_loss', loss_G.item(), step)
+                logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
+                logger.add_scalar('G/cycle_loss', cycle_loss.item(), step)
+                # logger.add_scalar('G/cycle_fm', cycle_fm.item(), step)
+                logger.add_scalar('G/rec_fm', rec_fm.item(), step)
+                logger.add_scalar('G/loss_mask', loss_mask.item(), step)
+                logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
+                logger.add_scalar('D/D_loss', loss_D.item(), step)
+                logger.add_scalar('D/D_reg', loss_reg.item(), step)
+            elif config["logger"] == "wandb":
+                logger.log({"G_Loss": loss_G.item()}, step = step)
+                logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
+                logger.log({"cycle_loss": cycle_loss.item()}, step = step)
+                # logger.log({"cycle_fm": cycle_fm.item()}, step = step)
+                logger.log({"rec_fm": rec_fm.item()}, step = step)
+                logger.log({"loss_mask": loss_mask.item()}, step = step)
+                logger.log({"G_ID": loss_G_ID.item()}, step = step)
+                logger.log({"D_loss": loss_D.item()}, step = step)
+                logger.log({"D_reg": loss_reg.item()}, step = step)
+            torch.cuda.empty_cache()
+
+        if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0):
+            gen.eval()
+            with torch.no_grad():
+                imgs = []
+                zero_img = (torch.zeros_like(src_image1[0,...]))
+                imgs.append(zero_img.cpu().numpy())
+                save_img = ((src_image1.cpu())* img_std + img_mean).numpy()
+                for r in range(batch_gpu):
+                    imgs.append(save_img[r,...])
+                arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
+                id_vector_src1 = arcface(arcface_112)
+                id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
+
+                for i in range(batch_gpu):
+                    imgs.append(save_img[i,...])
+                    image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
+                    img_fake,_,_ = gen(image_infer, id_vector_src1)
+
+                    img_fake = img_fake.cpu() * img_std
+                    img_fake = img_fake + img_mean
+                    img_fake = img_fake.numpy()
+                    # pred_mask = pred_mask.cpu().numpy() * 255
+                    for j in range(batch_gpu):
+                        imgs.append(img_fake[j,...])
+                print("Save test data")
+                imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
+                plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg'))
+            torch.cuda.empty_cache()
+
+        #===============adjust learning rate============#
+        # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
+        #     print("Learning rate decay")
+        #     for p in self.optimizer.param_groups:
+        #         p['lr'] *= self.config["lr_decay"]
+        #         print("Current learning rate is %f"%p['lr'])
+
+        #===============save checkpoints================#
+        # the ".pth" suffix keeps the optimizer file names consistent with the
+        # "step%d_optim_%s.pth" pattern loaded by setup_optimizers() on finetune
+        if rank == 0 and (step+1) % model_freq==0:
+
+            torch.save(gen.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["generator_name"])))
+            torch.save(dis.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["discriminator_name"])))
+
+            torch.save(g_optimizer.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["generator_name"])))
+
+            torch.save(d_optimizer.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["discriminator_name"])))
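# Sketch (illustrative, not part of the patch): an alternative layout that
# keeps everything needed for resuming in a single file, so the four files
# saved above can never drift out of sync; the file name is hypothetical.
#
#   torch.save({
#       "step":    step + 1,
#       "gen":     gen.state_dict(),
#       "dis":     dis.state_dict(),
#       "g_optim": g_optimizer.state_dict(),
#       "d_optim": d_optimizer.state_dict(),
#   }, os.path.join(ckpt_dir, "ckpt_step{}.pth".format(step + 1)))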
print("Save step %d model checkpoint!"%(step+1)) + torch.cuda.empty_cache() + print("Rank %d process done!"%rank) + torch.distributed.barrier() \ No newline at end of file diff --git a/train_scripts/trainer_mgpu_fm.py b/train_scripts/trainer_mgpu_fm.py new file mode 100644 index 0000000..8920688 --- /dev/null +++ b/train_scripts/trainer_mgpu_fm.py @@ -0,0 +1,582 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_naiv512.py +# Created Date: Sunday January 9th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Saturday, 2nd April 2022 1:48:32 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os +import time +import random +import shutil +import tempfile + +import numpy as np + +import torch +import torch.nn.functional as F + +from torch_utils import misc +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import grid_sample_gradfix + +from utilities.plot import plot_batch +from train_scripts.trainer_multigpu_base import TrainerBase + + +class Trainer(TrainerBase): + + def __init__(self, + config, + reporter): + super(Trainer, self).__init__(config, reporter) + + import inspect + print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe())) + + def train(self): + # Launch processes. + num_gpus = len(self.config["gpus"]) + print('Launching processes...') + torch.multiprocessing.set_start_method('spawn') + with tempfile.TemporaryDirectory() as temp_dir: + torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus) + +# TODO modify this function to build your models +def init_framework(config, reporter, device, rank): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + model_config = config["model_configs"] + + if config["phase"] == "train": + gscript_name = "components." + model_config["g_model"]["script"] + file1 = os.path.join("components", model_config["g_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + dscript_name = "components." 
+ model_config["d_model"]["script"]
+        file1    = os.path.join("components", model_config["d_model"]["script"]+".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py")
+        shutil.copyfile(file1, tgtfile1)
+
+    elif config["phase"] == "finetune":
+        gscript_name = config["com_base"] + model_config["g_model"]["script"]
+        dscript_name = config["com_base"] + model_config["d_model"]["script"]
+
+    class_name = model_config["g_model"]["class_name"]
+    package    = __import__(gscript_name, fromlist=True)
+    gen_class  = getattr(package, class_name)
+    gen        = gen_class(**model_config["g_model"]["module_params"])
+
+    # print and record the model structure
+    reporter.writeInfo("Generator structure:")
+    reporter.writeModel(gen.__str__())
+
+    class_name = model_config["d_model"]["class_name"]
+    package    = __import__(dscript_name, fromlist=True)
+    dis_class  = getattr(package, class_name)
+    dis        = dis_class(**model_config["d_model"]["module_params"])
+
+    # print and record the model structure
+    reporter.writeInfo("Discriminator structure:")
+    reporter.writeModel(dis.__str__())
+
+    # arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
+    # arcface  = arcface1['model'].module
+
+    # arcface = iresnet100(pretrained=False, fp16=False)
+    # arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
+    # arcface.eval()
+    arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
+    arcface  = arcface1['model'].module
+
+    # if in the finetune phase, load the pretrained checkpoints
+    # (map_location is an argument of torch.load, not of load_state_dict)
+    if config["phase"] == "finetune":
+        model_path = os.path.join(config["project_checkpoints"],
+                        "step%d_%s.pth"%(config["ckpt"],
+                        config["checkpoint_names"]["generator_name"]))
+        gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        model_path = os.path.join(config["project_checkpoints"],
+                        "step%d_%s.pth"%(config["ckpt"],
+                        config["checkpoint_names"]["discriminator_name"]))
+        dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        print('loaded trained backbone model step {}...!'.format(config["ckpt"]))
+
+    # move the models onto the training GPU
+    gen     = gen.to(device)
+    dis     = dis.to(device)
+    arcface = arcface.to(device)
+    arcface.requires_grad_(False)
+    arcface.eval()
+
+    return gen, dis, arcface
+
+# TODO modify this function to configure the optimizer of your pipeline
+def setup_optimizers(config, reporter, gen, dis, rank):
+
+    torch.cuda.set_device(rank)
+    torch.cuda.empty_cache()
+    g_train_opt = config['g_optim_config']
+    d_train_opt = config['d_optim_config']
+
+    g_optim_params = []
+    d_optim_params = []
+    for k, v in gen.named_parameters():
+        if v.requires_grad:
+            g_optim_params.append(v)
+        else:
+            reporter.writeInfo(f'Params {k} will not be optimized.')
+            print(f'Params {k} will not be optimized.')
+
+    for k, v in dis.named_parameters():
+        if v.requires_grad:
+            d_optim_params.append(v)
+        else:
+            reporter.writeInfo(f'Params {k} will not be optimized.')
+            print(f'Params {k} will not be optimized.')
+
+    optim_type = config['optim_type']
+
+    if optim_type == 'Adam':
+        g_optimizer = torch.optim.Adam(g_optim_params, **g_train_opt)
+        d_optimizer = torch.optim.Adam(d_optim_params, **d_train_opt)
+    else:
+        raise NotImplementedError(
+            f'optimizer {optim_type} is not supported yet.')
+    # self.optimizers.append(self.optimizer_g)
+    if config["phase"] == "finetune":
+        opt_path = os.path.join(config["project_checkpoints"],
+                        "step%d_optim_%s.pth"%(config["ckpt"],
+                        config["optimizer_names"]["generator_name"]))
+ g_optimizer.load_state_dict(torch.load(opt_path)) + + opt_path = os.path.join(config["project_checkpoints"], + "step%d_optim_%s.pth"%(config["ckpt"], + config["optimizer_names"]["discriminator_name"])) + d_optimizer.load_state_dict(torch.load(opt_path)) + + print('loaded trained optimizer step {}...!'.format(config["project_checkpoints"])) + return g_optimizer, d_optimizer + +def d_logistic_loss(real_pred, fake_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + +def d_r1_loss(d_out, x_in): + # zero-centered gradient penalty for real images + batch_size = x_in.size(0) + grad_dout = torch.autograd.grad( + outputs=d_out.sum(), inputs=x_in, + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_dout2 = grad_dout.pow(2) + assert(grad_dout2.size() == x_in.size()) + reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) + return reg + + +def g_nonsaturating_loss(fake_pred): + loss = F.softplus(-fake_pred).mean() + + return loss + +def requires_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +# def r1_reg(d_out, x_in): +# # zero-centered gradient penalty for real images +# batch_size = x_in.size(0) +# grad_dout = torch.autograd.grad( +# outputs=d_out.sum(), inputs=x_in, +# create_graph=True, retain_graph=True, only_inputs=True +# )[0] +# grad_dout2 = grad_dout.pow(2) +# assert(grad_dout2.size() == x_in.size()) +# reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) +# return reg + +def train_loop( + rank, + config, + reporter, + temp_dir + ): + + version = config["version"] + + ckpt_dir = config["project_checkpoints"] + sample_dir = config["project_samples"] + + log_freq = config["log_step"] + model_freq = config["model_save_step"] + sample_freq = config["sample_step"] + total_step = config["total_step"] + random_seed = config["dataset_params"]["random_seed"] + d_reg_freq = config["d_reg_freq"] + + id_w = config["id_weight"] + rec_w = config["reconstruct_weight"] + rec_fm_w = config["rec_feature_match_weight"] + cycle_fm_w = config["cycle_feature_match_weight"] + cycle_w = config["cycle_weight"] + reg_w = config["reg_weight"] + num_gpus = len(config["gpus"]) + batch_gpu = config["batch_size"] // num_gpus + + init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + + + + if rank == 0: + img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + + # Initialize. + device = torch.device('cuda', rank) + np.random.seed(random_seed * num_gpus + rank) + torch.manual_seed(random_seed * num_gpus + rank) + torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. + torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. + conv2d_gradfix.enabled = True # Improves training speed. + grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. + + # Create dataloader. 
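# Note (illustrative, not part of the patch): d_reg_freq above implements the
# "lazy regularization" schedule from StyleGAN2. In the loop below, the R1
# penalty d_r1_loss is evaluated only every d_reg_freq steps and is therefore
# accumulated as loss_reg * reg_w * d_reg_freq, which keeps the effective
# penalty strength comparable to applying it at every step.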
+ if rank == 0: + print('Loading training set...') + + dataset = config["dataset_paths"][config["dataset_name"]] + #================================================# + print("Prepare the train dataloader...") + dlModulename = config["dataloader"] + package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) + dataloaderClass = getattr(package, 'GetLoader') + dataloader_class= dataloaderClass + dataloader = dataloader_class(dataset, + rank, + num_gpus, + batch_gpu, + **config["dataset_params"]) + + # Construct networks. + if rank == 0: + print('Constructing networks...') + gen, dis, arcface = init_framework(config, reporter, device, rank) + + # Check for existing checkpoint + + # Print network summary tables. + # if rank == 0: + # attr = torch.empty([batch_gpu, 3, 512, 512], device=device) + # id = torch.empty([batch_gpu, 3, 112, 112], device=device) + # latent = misc.print_module_summary(arcface, [id]) + # img = misc.print_module_summary(gen, [attr, latent]) + # misc.print_module_summary(dis, [img, None]) + # del attr + # del id + # del latent + # del img + # torch.cuda.empty_cache() + + + # Distribute across GPUs. + if rank == 0: + print(f'Distributing across {num_gpus} GPUs...') + for module in [gen, dis, arcface]: + if module is not None and num_gpus > 1: + for param in misc.params_and_buffers(module): + torch.distributed.broadcast(param, src=0) + + # Setup training phases. + if rank == 0: + print('Setting up training phases...') + #===============build losses===================# + # TODO replace below lines to build your losses + # MSE_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + cos_loss = torch.nn.CosineSimilarity() + + g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank) + + # Initialize logs. 
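# Sketch (illustrative, not part of the patch): the broadcast above is what
# keeps the replicas consistent before manual gradient averaging is used.
# Assuming the process group is already initialized, the same synchronization
# can be written for any module as:
#
#   import torch.distributed as dist
#
#   def sync_module_from_rank0(module):
#       for tensor in list(module.parameters()) + list(module.buffers()):
#           dist.broadcast(tensor.data, src=0)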
+ if rank == 0: + print('Initializing logs...') + #==============build tensorboard=================# + if config["logger"] == "tensorboard": + import torch.utils.tensorboard as tensorboard + tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"]) + logger = tensorboard_writer + + elif config["logger"] == "wandb": + import wandb + wandb.init(project="Simswap_HQ", entity="xhchen", notes="512", + tags=[config["tag"]], name=version) + + wandb.config = { + "total_step": config["total_step"], + "batch_size": config["batch_size"] + } + logger = wandb + + + random.seed(random_seed) + randindex = [i for i in range(batch_gpu)] + + # set the start point for training loop + if config["phase"] == "finetune": + start = config["ckpt"] + else: + start = 0 + if rank == 0: + import datetime + start_time = time.time() + + # Caculate the epoch number + print("Total step = %d"%total_step) + + print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + + from utilities.logo_class import logo_class + logo_class.print_start_training() + + for step in range(start, total_step): + gen.train() + dis.train() + + for interval in range(2): + + src_image1, src_image2 = dataloader.next() + + if step%2 == 0: + img_id = src_image2 + else: + random.shuffle(randindex) + img_id = src_image2[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + + if interval == 0: + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + + requires_grad(dis, True) + requires_grad(gen, False) + + d_regularize = step % d_reg_freq == 0 + if d_regularize: + src_image1.requires_grad_() + + real_logits,_ = dis(src_image1) + with torch.no_grad(): + img_fake = gen(src_image1, latent_id.detach()) + fake_logits,_ = dis(img_fake.detach()) + + loss_D = d_logistic_loss(real_logits, fake_logits) + + if d_regularize: + loss_reg = d_r1_loss(real_logits, src_image1) + loss_D += loss_reg * reg_w * d_reg_freq + + d_optimizer.zero_grad(set_to_none=True) + loss_D.backward() + + with torch.autograd.profiler.record_function('discriminator_opt'): + # params = [param for param in dis.parameters() if param.grad is not None] + # if len(params) > 0: + # flat = torch.cat([param.grad.flatten() for param in params]) + # if num_gpus > 1: + # torch.distributed.all_reduce(flat) + # flat /= num_gpus + # misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + # grads = flat.split([param.numel() for param in params]) + # for param, grad in zip(params, grads): + # param.grad = grad.reshape(param.shape) + params = [param for param in dis.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + d_optimizer.step() + + + #================================Generator interval======================================# + else: + requires_grad(dis, False) + requires_grad(gen, True) + # model.netD.requires_grad_(True) + img_fake = gen(src_image1, latent_id.detach()) + # G loss + gen_logits,fake_feat= dis(img_fake) + # real_feat = dis.get_feature(src_image1) + loss_Gmain = g_nonsaturating_loss(gen_logits) + img_fake_down = 
F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id.detach())).mean() + loss_G = loss_Gmain + loss_G_ID * id_w + if step%2 == 0: + #G_Rec + real_feat = dis.get_feature(src_image1) + rec_fm = l1_loss(fake_feat, real_feat) + loss_G_Rec = l1_loss(img_fake, src_image1) + # lpips_loss = loss_fn_vgg(img_fake, src_image1).mean() + loss_G += (loss_G_Rec * rec_w + rec_fm_w * rec_fm) #+ rec_fm * rec_fm_w + else: + source1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic') + latent_source1 = arcface(source1_down) + latent_source1 = F.normalize(latent_source1, p=2, dim=1) + cycle_src = gen(img_fake, latent_source1) + cycle_loss = l1_loss(src_image1,cycle_src) + # cycle_feat = dis.get_feature(cycle_src) + # cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) + l1_loss(real_feat["2"],cycle_feat["2"]) + loss_G += cycle_loss * cycle_w #+ cycle_fm * cycle_fm_w + + + g_optimizer.zero_grad(set_to_none=True) + loss_G.backward() + with torch.autograd.profiler.record_function('generator_opt'): + params = [param for param in gen.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + g_optimizer.step() + + # Print out log info + if rank == 0 and (step + 1) % log_freq == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + # epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \ + # rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + # format(version, elapsed, step, total_step, \ + # loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \ + # rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_loss: {:.4f}, rec_fm: {:.4f}, D_loss: {:.4f}, D_R1: {:.4f}". 
\
+                format(version, elapsed, step, total_step, \
+                loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_loss.item(), rec_fm.item(), loss_D.item(), loss_reg.item())
+            print(epochinformation)
+            reporter.writeInfo(epochinformation)
+
+            if config["logger"] == "tensorboard":
+                logger.add_scalar('G/G_loss', loss_G.item(), step)
+                logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
+                logger.add_scalar('G/cycle_loss', cycle_loss.item(), step)
+                # logger.add_scalar('G/cycle_fm', cycle_fm.item(), step)
+                # the logged scalar is the feature-matching term rec_fm
+                logger.add_scalar('G/rec_fm', rec_fm.item(), step)
+                logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
+                logger.add_scalar('D/D_loss', loss_D.item(), step)
+                logger.add_scalar('D/D_reg', loss_reg.item(), step)
+            elif config["logger"] == "wandb":
+                logger.log({"G_Loss": loss_G.item()}, step = step)
+                logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
+                logger.log({"cycle_loss": cycle_loss.item()}, step = step)
+                # logger.log({"cycle_fm": cycle_fm.item()}, step = step)
+                logger.log({"rec_fm": rec_fm.item()}, step = step)
+                logger.log({"G_ID": loss_G_ID.item()}, step = step)
+                logger.log({"D_loss": loss_D.item()}, step = step)
+                logger.log({"D_reg": loss_reg.item()}, step = step)
+            torch.cuda.empty_cache()
+
+        if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0):
+            gen.eval()
+            with torch.no_grad():
+                imgs = []
+                zero_img = (torch.zeros_like(src_image1[0,...]))
+                imgs.append(zero_img.cpu().numpy())
+                save_img = ((src_image1.cpu())* img_std + img_mean).numpy()
+                for r in range(batch_gpu):
+                    imgs.append(save_img[r,...])
+                arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
+                id_vector_src1 = arcface(arcface_112)
+                id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
+
+                for i in range(batch_gpu):
+                    imgs.append(save_img[i,...])
+                    image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
+                    img_fake = gen(image_infer, id_vector_src1).cpu()
+
+                    img_fake = img_fake * img_std
+                    img_fake = img_fake + img_mean
+                    img_fake = img_fake.numpy()
+                    for j in range(batch_gpu):
+                        imgs.append(img_fake[j,...])
+                print("Save test data")
+                imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
+                plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg'))
+            torch.cuda.empty_cache()
+
+        #===============adjust learning rate============#
+        # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
+        #     print("Learning rate decay")
+        #     for p in self.optimizer.param_groups:
+        #         p['lr'] *= self.config["lr_decay"]
+        #         print("Current learning rate is %f"%p['lr'])
+
+        #===============save checkpoints================#
+        # the ".pth" suffix keeps the optimizer file names consistent with the
+        # "step%d_optim_%s.pth" pattern loaded by setup_optimizers() on finetune
+        if rank == 0 and (step+1) % model_freq==0:
+
+            torch.save(gen.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["generator_name"])))
+            torch.save(dis.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["discriminator_name"])))
+
+            torch.save(g_optimizer.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["generator_name"])))
+
+            torch.save(d_optimizer.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["discriminator_name"])))
+            print("Save step %d model checkpoint!"%(step+1))
+            torch.cuda.empty_cache()
+    print("Rank %d process done!"%rank)
+    torch.distributed.barrier()
\ No newline at end of file
diff --git
a/train_scripts/trainer_mgpu_fm_w_mask.py b/train_scripts/trainer_mgpu_fm_w_mask.py new file mode 100644 index 0000000..e4425c5 --- /dev/null +++ b/train_scripts/trainer_mgpu_fm_w_mask.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_naiv512.py +# Created Date: Sunday January 9th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Wednesday, 13th April 2022 5:37:02 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os +import time +import random +import shutil +import tempfile + +import numpy as np + +import torch +import torch.nn.functional as F + +from torch_utils import misc +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import grid_sample_gradfix + +from utilities.plot import plot_batch +from train_scripts.trainer_multigpu_base import TrainerBase + + +class Trainer(TrainerBase): + + def __init__(self, + config, + reporter): + super(Trainer, self).__init__(config, reporter) + + import inspect + print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe())) + + def train(self): + # Launch processes. + num_gpus = len(self.config["gpus"]) + print('Launching processes...') + torch.multiprocessing.set_start_method('spawn') + with tempfile.TemporaryDirectory() as temp_dir: + torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus) + +# TODO modify this function to build your models +def init_framework(config, reporter, device, rank): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + model_config = config["model_configs"] + + if config["phase"] == "train": + gscript_name = "components." + model_config["g_model"]["script"] + file1 = os.path.join("components", model_config["g_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + dscript_name = "components." 
+ model_config["d_model"]["script"]
+        file1    = os.path.join("components", model_config["d_model"]["script"]+".py")
+        tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py")
+        shutil.copyfile(file1, tgtfile1)
+
+    elif config["phase"] == "finetune":
+        gscript_name = config["com_base"] + model_config["g_model"]["script"]
+        dscript_name = config["com_base"] + model_config["d_model"]["script"]
+
+    class_name = model_config["g_model"]["class_name"]
+    package    = __import__(gscript_name, fromlist=True)
+    gen_class  = getattr(package, class_name)
+    gen        = gen_class(**model_config["g_model"]["module_params"])
+
+    # print and record the model structure
+    reporter.writeInfo("Generator structure:")
+    reporter.writeModel(gen.__str__())
+
+    class_name = model_config["d_model"]["class_name"]
+    package    = __import__(dscript_name, fromlist=True)
+    dis_class  = getattr(package, class_name)
+    dis        = dis_class(**model_config["d_model"]["module_params"])
+
+    # print and record the model structure
+    reporter.writeInfo("Discriminator structure:")
+    reporter.writeModel(dis.__str__())
+
+    # arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
+    # arcface  = arcface1['model'].module
+
+    # arcface = iresnet100(pretrained=False, fp16=False)
+    # arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu'))
+    # arcface.eval()
+    arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
+    arcface  = arcface1['model'].module
+
+    # if in the finetune phase, load the pretrained checkpoints
+    # (map_location is an argument of torch.load, not of load_state_dict)
+    if config["phase"] == "finetune":
+        model_path = os.path.join(config["project_checkpoints"],
+                        "step%d_%s.pth"%(config["ckpt"],
+                        config["checkpoint_names"]["generator_name"]))
+        gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        model_path = os.path.join(config["project_checkpoints"],
+                        "step%d_%s.pth"%(config["ckpt"],
+                        config["checkpoint_names"]["discriminator_name"]))
+        dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+
+        print('loaded trained backbone model step {}...!'.format(config["ckpt"]))
+
+    # move the models onto the training GPU
+    gen     = gen.to(device)
+    dis     = dis.to(device)
+    arcface = arcface.to(device)
+    arcface.requires_grad_(False)
+    arcface.eval()
+
+    return gen, dis, arcface
+
+# TODO modify this function to configure the optimizer of your pipeline
+def setup_optimizers(config, reporter, gen, dis, rank):
+
+    torch.cuda.set_device(rank)
+    torch.cuda.empty_cache()
+    g_train_opt = config['g_optim_config']
+    d_train_opt = config['d_optim_config']
+
+    g_optim_params = []
+    d_optim_params = []
+    for k, v in gen.named_parameters():
+        if v.requires_grad:
+            g_optim_params.append(v)
+        else:
+            reporter.writeInfo(f'Params {k} will not be optimized.')
+            print(f'Params {k} will not be optimized.')
+
+    for k, v in dis.named_parameters():
+        if v.requires_grad:
+            d_optim_params.append(v)
+        else:
+            reporter.writeInfo(f'Params {k} will not be optimized.')
+            print(f'Params {k} will not be optimized.')
+
+    optim_type = config['optim_type']
+
+    if optim_type == 'Adam':
+        g_optimizer = torch.optim.Adam(g_optim_params, **g_train_opt)
+        d_optimizer = torch.optim.Adam(d_optim_params, **d_train_opt)
+    else:
+        raise NotImplementedError(
+            f'optimizer {optim_type} is not supported yet.')
+    # self.optimizers.append(self.optimizer_g)
+    if config["phase"] == "finetune":
+        opt_path = os.path.join(config["project_checkpoints"],
+                        "step%d_optim_%s.pth"%(config["ckpt"],
+                        config["optimizer_names"]["generator_name"]))
+ g_optimizer.load_state_dict(torch.load(opt_path)) + + opt_path = os.path.join(config["project_checkpoints"], + "step%d_optim_%s.pth"%(config["ckpt"], + config["optimizer_names"]["discriminator_name"])) + d_optimizer.load_state_dict(torch.load(opt_path)) + + print('loaded trained optimizer step {}...!'.format(config["project_checkpoints"])) + return g_optimizer, d_optimizer + +def d_logistic_loss(real_pred, fake_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + +def d_r1_loss(d_out, x_in): + # zero-centered gradient penalty for real images + batch_size = x_in.size(0) + grad_dout = torch.autograd.grad( + outputs=d_out.sum(), inputs=x_in, + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_dout2 = grad_dout.pow(2) + assert(grad_dout2.size() == x_in.size()) + reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) + return reg + + +def g_nonsaturating_loss(fake_pred): + loss = F.softplus(-fake_pred).mean() + + return loss + +def requires_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +# def r1_reg(d_out, x_in): +# # zero-centered gradient penalty for real images +# batch_size = x_in.size(0) +# grad_dout = torch.autograd.grad( +# outputs=d_out.sum(), inputs=x_in, +# create_graph=True, retain_graph=True, only_inputs=True +# )[0] +# grad_dout2 = grad_dout.pow(2) +# assert(grad_dout2.size() == x_in.size()) +# reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) +# return reg + +def train_loop( + rank, + config, + reporter, + temp_dir + ): + + version = config["version"] + + ckpt_dir = config["project_checkpoints"] + sample_dir = config["project_samples"] + + log_freq = config["log_step"] + model_freq = config["model_save_step"] + sample_freq = config["sample_step"] + total_step = config["total_step"] + random_seed = config["dataset_params"]["random_seed"] + d_reg_freq = config["d_reg_freq"] + + id_w = config["id_weight"] + rec_w = config["reconstruct_weight"] + rec_fm_w = config["rec_feature_match_weight"] + cycle_w = config["cycle_weight"] + reg_w = config["reg_weight"] + num_gpus = len(config["gpus"]) + batch_gpu = config["batch_size"] // num_gpus + + init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + + + + if rank == 0: + img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + + # Initialize. + device = torch.device('cuda', rank) + np.random.seed(random_seed * num_gpus + rank) + torch.manual_seed(random_seed * num_gpus + rank) + torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. + torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. + conv2d_gradfix.enabled = True # Improves training speed. + grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. + + # Create dataloader. 
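# Sketch (illustrative, not part of the patch): both optimizer steps in the
# loop below synchronize gradients by hand; "model" stands for gen or dis and
# one process per GPU is assumed. Flattening into a single buffer amortizes
# the all_reduce launch cost, much like DistributedDataParallel's bucketed
# gradient synchronization:
#
#   params = [p for p in model.parameters() if p.grad is not None]
#   flat = torch.cat([p.grad.flatten() for p in params])
#   torch.distributed.all_reduce(flat)           # sum over all ranks
#   flat /= num_gpus                             # turn the sum into a mean
#   for p, g in zip(params, flat.split([p.numel() for p in params])):
#       p.grad = g.reshape(p.shape)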
+ if rank == 0: + print('Loading training set...') + + dataset = config["dataset_paths"][config["dataset_name"]] + #================================================# + print("Prepare the train dataloader...") + dlModulename = config["dataloader"] + package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) + dataloaderClass = getattr(package, 'GetLoader') + dataloader_class= dataloaderClass + dataloader = dataloader_class(dataset, + rank, + num_gpus, + batch_gpu, + **config["dataset_params"]) + + # Construct networks. + if rank == 0: + print('Constructing networks...') + gen, dis, arcface = init_framework(config, reporter, device, rank) + + # Check for existing checkpoint + + # Print network summary tables. + # if rank == 0: + # attr = torch.empty([batch_gpu, 3, 512, 512], device=device) + # id = torch.empty([batch_gpu, 3, 112, 112], device=device) + # latent = misc.print_module_summary(arcface, [id]) + # img = misc.print_module_summary(gen, [attr, latent]) + # misc.print_module_summary(dis, [img, None]) + # del attr + # del id + # del latent + # del img + # torch.cuda.empty_cache() + + + # Distribute across GPUs. + if rank == 0: + print(f'Distributing across {num_gpus} GPUs...') + for module in [gen, dis, arcface]: + if module is not None and num_gpus > 1: + for param in misc.params_and_buffers(module): + torch.distributed.broadcast(param, src=0) + + # Setup training phases. + if rank == 0: + print('Setting up training phases...') + #===============build losses===================# + # TODO replace below lines to build your losses + # MSE_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + cos_loss = torch.nn.CosineSimilarity() + + g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank) + + # Initialize logs. 
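# Sketch (illustrative, not part of the patch): cos_loss above feeds the
# identity term used below, loss_G_ID = (1 - cos(latent_fake, latent_id)).mean().
# With L2-normalized embeddings, 1 - cosine similarity equals half the squared
# Euclidean distance, so the term directly penalizes identity-embedding drift:
#
#   import torch
#   import torch.nn.functional as F
#
#   emb_fake = F.normalize(torch.randn(4, 512), p=2, dim=1)
#   emb_id   = F.normalize(torch.randn(4, 512), p=2, dim=1)
#   cos      = torch.nn.CosineSimilarity()
#   loss_id  = (1 - cos(emb_fake, emb_id)).mean()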
+ if rank == 0: + print('Initializing logs...') + #==============build tensorboard=================# + if config["logger"] == "tensorboard": + import torch.utils.tensorboard as tensorboard + tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"]) + logger = tensorboard_writer + + elif config["logger"] == "wandb": + import wandb + wandb.init(project="Simswap_HQ", entity="xhchen", notes="512", + tags=[config["tag"]], name=version) + + wandb.config = { + "total_step": config["total_step"], + "batch_size": config["batch_size"] + } + logger = wandb + + + random.seed(random_seed) + randindex = [i for i in range(batch_gpu)] + + # set the start point for training loop + if config["phase"] == "finetune": + start = config["ckpt"] + else: + start = 0 + if rank == 0: + import datetime + start_time = time.time() + + # Caculate the epoch number + print("Total step = %d"%total_step) + + print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + + from utilities.logo_class import logo_class + logo_class.print_start_training() + + for step in range(start, total_step): + gen.train() + dis.train() + + for interval in range(2): + + src_image1, src_image2 = dataloader.next() + + if step%2 == 0: + img_id = src_image2 + else: + random.shuffle(randindex) + img_id = src_image2[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + + if interval == 0: + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + + requires_grad(dis, True) + requires_grad(gen, False) + + d_regularize = step % d_reg_freq == 0 + if d_regularize: + src_image1.requires_grad_() + + real_logits,_ = dis(src_image1) + with torch.no_grad(): + img_fake,_ = gen(src_image1, latent_id.detach()) + fake_logits,_ = dis(img_fake.detach()) + + loss_D = d_logistic_loss(real_logits, fake_logits) + + if d_regularize: + loss_reg = d_r1_loss(real_logits, src_image1) + loss_D += loss_reg * reg_w * d_reg_freq + + d_optimizer.zero_grad(set_to_none=True) + loss_D.backward() + + with torch.autograd.profiler.record_function('discriminator_opt'): + # params = [param for param in dis.parameters() if param.grad is not None] + # if len(params) > 0: + # flat = torch.cat([param.grad.flatten() for param in params]) + # if num_gpus > 1: + # torch.distributed.all_reduce(flat) + # flat /= num_gpus + # misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + # grads = flat.split([param.numel() for param in params]) + # for param, grad in zip(params, grads): + # param.grad = grad.reshape(param.shape) + params = [param for param in dis.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + d_optimizer.step() + + + #================================Generator interval======================================# + else: + requires_grad(dis, False) + requires_grad(gen, True) + # model.netD.requires_grad_(True) + img_fake,_ = gen(src_image1, latent_id.detach()) + # G loss + gen_logits,fake_feat= dis(img_fake) + # real_feat = dis.get_feature(src_image1) + loss_Gmain = g_nonsaturating_loss(gen_logits) + img_fake_down = 
F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id.detach())).mean() + loss_G = loss_Gmain + loss_G_ID * id_w + if step%2 == 0: + #G_Rec + real_feat = dis.get_feature(src_image1) + rec_fm = l1_loss(fake_feat, real_feat) + loss_G_Rec = l1_loss(img_fake, src_image1) + # lpips_loss = loss_fn_vgg(img_fake, src_image1).mean() + loss_G += (loss_G_Rec * rec_w + rec_fm_w * rec_fm) #+ rec_fm * rec_fm_w + else: + source1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic') + latent_source1 = arcface(source1_down) + latent_source1 = F.normalize(latent_source1, p=2, dim=1) + cycle_src,_ = gen(img_fake, latent_source1) + cycle_loss = l1_loss(src_image1,cycle_src) + # cycle_feat = dis.get_feature(cycle_src) + # cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) + l1_loss(real_feat["2"],cycle_feat["2"]) + loss_G += cycle_loss * cycle_w #+ cycle_fm * cycle_fm_w + + + g_optimizer.zero_grad(set_to_none=True) + loss_G.backward() + with torch.autograd.profiler.record_function('generator_opt'): + params = [param for param in gen.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + g_optimizer.step() + + # Print out log info + if rank == 0 and (step + 1) % log_freq == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + # epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \ + # rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + # format(version, elapsed, step, total_step, \ + # loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \ + # rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_loss: {:.4f}, rec_fm: {:.4f}, D_loss: {:.4f}, D_R1: {:.4f}". 
\
+                format(version, elapsed, step, total_step, \
+                loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_loss.item(), rec_fm.item(), loss_D.item(), loss_reg.item())
+            print(epochinformation)
+            reporter.writeInfo(epochinformation)
+
+            if config["logger"] == "tensorboard":
+                logger.add_scalar('G/G_loss', loss_G.item(), step)
+                logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
+                logger.add_scalar('G/cycle_loss', cycle_loss.item(), step)
+                # logger.add_scalar('G/cycle_fm', cycle_fm.item(), step)
+                # the logged scalar is the feature-matching term rec_fm
+                logger.add_scalar('G/rec_fm', rec_fm.item(), step)
+                logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
+                logger.add_scalar('D/D_loss', loss_D.item(), step)
+                logger.add_scalar('D/D_reg', loss_reg.item(), step)
+            elif config["logger"] == "wandb":
+                logger.log({"G_Loss": loss_G.item()}, step = step)
+                logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
+                logger.log({"cycle_loss": cycle_loss.item()}, step = step)
+                # logger.log({"cycle_fm": cycle_fm.item()}, step = step)
+                logger.log({"rec_fm": rec_fm.item()}, step = step)
+                logger.log({"G_ID": loss_G_ID.item()}, step = step)
+                logger.log({"D_loss": loss_D.item()}, step = step)
+                logger.log({"D_reg": loss_reg.item()}, step = step)
+            torch.cuda.empty_cache()
+
+        if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0):
+            gen.eval()
+            with torch.no_grad():
+                imgs = []
+                zero_img = (torch.zeros_like(src_image1[0,...]))
+                imgs.append(zero_img.cpu().numpy())
+                save_img = ((src_image1.cpu())* img_std + img_mean).numpy()
+                for r in range(batch_gpu):
+                    imgs.append(save_img[r,...])
+                arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
+                id_vector_src1 = arcface(arcface_112)
+                id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)
+
+                for i in range(batch_gpu):
+                    imgs.append(save_img[i,...])
+                    image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
+                    img_fake,_ = gen(image_infer, id_vector_src1)
+
+                    img_fake = img_fake.cpu() * img_std
+                    img_fake = img_fake + img_mean
+                    img_fake = img_fake.numpy()
+                    for j in range(batch_gpu):
+                        imgs.append(img_fake[j,...])
+                print("Save test data")
+                imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
+                plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg'))
+            torch.cuda.empty_cache()
+
+        #===============adjust learning rate============#
+        # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
+        #     print("Learning rate decay")
+        #     for p in self.optimizer.param_groups:
+        #         p['lr'] *= self.config["lr_decay"]
+        #         print("Current learning rate is %f"%p['lr'])
+
+        #===============save checkpoints================#
+        # the ".pth" suffix keeps the optimizer file names consistent with the
+        # "step%d_optim_%s.pth" pattern loaded by setup_optimizers() on finetune
+        if rank == 0 and (step+1) % model_freq==0:
+
+            torch.save(gen.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["generator_name"])))
+            torch.save(dis.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["discriminator_name"])))
+
+            torch.save(g_optimizer.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["generator_name"])))
+
+            torch.save(d_optimizer.state_dict(),
+                    os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(step + 1,
+                    config["checkpoint_names"]["discriminator_name"])))
+            print("Save step %d model checkpoint!"%(step+1))
+            torch.cuda.empty_cache()
+    print("Rank %d process done!"%rank)
+    torch.distributed.barrier()
\ No newline at end of file
diff --git
a/train_scripts/trainer_mgpu_maskloss.py b/train_scripts/trainer_mgpu_maskloss.py new file mode 100644 index 0000000..ac9f005 --- /dev/null +++ b/train_scripts/trainer_mgpu_maskloss.py @@ -0,0 +1,588 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_naiv512.py +# Created Date: Sunday January 9th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Friday, 15th April 2022 2:21:12 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os +import time +import random +import shutil +import tempfile + +import numpy as np + +import torch +import torch.nn.functional as F + +from torch_utils import misc +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import grid_sample_gradfix + +from utilities.plot import plot_batch +from train_scripts.trainer_multigpu_base import TrainerBase + + +class Trainer(TrainerBase): + + def __init__(self, + config, + reporter): + super(Trainer, self).__init__(config, reporter) + + import inspect + print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe())) + + def train(self): + # Launch processes. + num_gpus = len(self.config["gpus"]) + print('Launching processes...') + torch.multiprocessing.set_start_method('spawn') + with tempfile.TemporaryDirectory() as temp_dir: + torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus) + +# TODO modify this function to build your models +def init_framework(config, reporter, device, rank): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + model_config = config["model_configs"] + + if config["phase"] == "train": + gscript_name = "components." + model_config["g_model"]["script"] + file1 = os.path.join("components", model_config["g_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + dscript_name = "components." 
+ model_config["d_model"]["script"] + file1 = os.path.join("components", model_config["d_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + + elif config["phase"] == "finetune": + gscript_name = config["com_base"] + model_config["g_model"]["script"] + dscript_name = config["com_base"] + model_config["d_model"]["script"] + + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + gen = gen_class(**model_config["g_model"]["module_params"]) + + # print and record the model structure + reporter.writeInfo("Generator structure:") + reporter.writeModel(gen.__str__()) + + class_name = model_config["d_model"]["class_name"] + package = __import__(dscript_name, fromlist=True) + dis_class = getattr(package, class_name) + dis = dis_class(**model_config["d_model"]["module_params"]) + + + # print and record the model structure + reporter.writeInfo("Discriminator structure:") + reporter.writeModel(dis.__str__()) + + # arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu")) + # arcface = arcface1['model'].module + + # arcface = iresnet100(pretrained=False, fp16=False) + # arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu')) + # arcface.eval() + arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu")) + arcface = arcface1['model'].module + + # train in GPU + + # if in finetune phase, load the pretrained checkpoint + if config["phase"] == "finetune": + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["generator_name"])) + gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["discriminator_name"])) + dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + + print('loaded trained backbone model step {}...!'.format(config["ckpt"])) + + gen = gen.to(device) + dis = dis.to(device) + arcface = arcface.to(device) + arcface.requires_grad_(False) + arcface.eval() + + + + return gen, dis, arcface + +# TODO modify this function to configure the optimizer of your pipeline +def setup_optimizers(config, reporter, gen, dis, rank): + + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + g_train_opt = config['g_optim_config'] + d_train_opt = config['d_optim_config'] + + g_optim_params = [] + d_optim_params = [] + for k, v in gen.named_parameters(): + if v.requires_grad: + g_optim_params.append(v) + else: + reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + for k, v in dis.named_parameters(): + if v.requires_grad: + d_optim_params.append(v) + else: + reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + optim_type = config['optim_type'] + + if optim_type == 'Adam': + g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) + d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt) + else: + raise NotImplementedError( + f'optimizer {optim_type} is not supported yet.') + # self.optimizers.append(self.optimizer_g) + if config["phase"] == "finetune": + # optimizer state is saved as step{N}_optim_{name}; mirror that pattern here + opt_path = os.path.join(config["project_checkpoints"], + "step%d_optim_%s"%(config["ckpt"], + config["checkpoint_names"]["generator_name"])) 
+ g_optimizer.load_state_dict(torch.load(opt_path, map_location=torch.device("cpu"))) + + opt_path = os.path.join(config["project_checkpoints"], + "step%d_optim_%s"%(config["ckpt"], + config["checkpoint_names"]["discriminator_name"])) + d_optimizer.load_state_dict(torch.load(opt_path, map_location=torch.device("cpu"))) + + print('loaded trained optimizer step {}...!'.format(config["ckpt"])) + return g_optimizer, d_optimizer + +def d_logistic_loss(real_pred, fake_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + +def d_r1_loss(d_out, x_in): + # zero-centered gradient penalty for real images + batch_size = x_in.size(0) + grad_dout = torch.autograd.grad( + outputs=d_out.sum(), inputs=x_in, + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_dout2 = grad_dout.pow(2) + assert(grad_dout2.size() == x_in.size()) + reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) + return reg + + +def g_nonsaturating_loss(fake_pred): + loss = F.softplus(-fake_pred).mean() + + return loss + +def requires_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +# def r1_reg(d_out, x_in): +# # zero-centered gradient penalty for real images +# batch_size = x_in.size(0) +# grad_dout = torch.autograd.grad( +# outputs=d_out.sum(), inputs=x_in, +# create_graph=True, retain_graph=True, only_inputs=True +# )[0] +# grad_dout2 = grad_dout.pow(2) +# assert(grad_dout2.size() == x_in.size()) +# reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) +# return reg + +def train_loop( + rank, + config, + reporter, + temp_dir + ): + + version = config["version"] + + ckpt_dir = config["project_checkpoints"] + sample_dir = config["project_samples"] + + log_freq = config["log_step"] + model_freq = config["model_save_step"] + sample_freq = config["sample_step"] + total_step = config["total_step"] + random_seed = config["dataset_params"]["random_seed"] + d_reg_freq = config["d_reg_freq"] + + id_w = config["id_weight"] + rec_w = config["reconstruct_weight"] + rec_fm_w = config["rec_feature_match_weight"] + mask_w = config["mask_weight"] + cycle_w = config["cycle_weight"] + reg_w = config["reg_weight"] + num_gpus = len(config["gpus"]) + batch_gpu = config["batch_size"] // num_gpus + + init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + + + + if rank == 0: + img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + + # Initialize. + device = torch.device('cuda', rank) + np.random.seed(random_seed * num_gpus + rank) + torch.manual_seed(random_seed * num_gpus + rank) + torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. + torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. + conv2d_gradfix.enabled = True # Improves training speed. + grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. + + # Create dataloader.
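+ # A note on GetLoader, resolved dynamically below (an assumption drawn from
+ # its call sites in this file; the dataloader sources are outside this diff):
+ # it is expected to behave roughly like
+ #
+ #     class GetLoader:
+ #         def __init__(self, dataset_root, rank, num_gpus, batch_size, **dataset_params):
+ #             ...  # shard the dataset across the participating ranks
+ #         def next(self):
+ #             return src_image1, src_image2, mask_label  # one batch_gpu-sized batch
+ #
+ # i.e. an endless, rank-sharded .next() iterator rather than a standard torch
+ # DataLoader.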
+ if rank == 0: + print('Loading training set...') + + dataset = config["dataset_paths"][config["dataset_name"]] + #================================================# + print("Prepare the train dataloader...") + dlModulename = config["dataloader"] + package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) + dataloaderClass = getattr(package, 'GetLoader') + dataloader_class= dataloaderClass + dataloader = dataloader_class(dataset, + rank, + num_gpus, + batch_gpu, + **config["dataset_params"]) + + # Construct networks. + if rank == 0: + print('Constructing networks...') + gen, dis, arcface = init_framework(config, reporter, device, rank) + + # Check for existing checkpoint + + # Print network summary tables. + # if rank == 0: + # attr = torch.empty([batch_gpu, 3, 512, 512], device=device) + # id = torch.empty([batch_gpu, 3, 112, 112], device=device) + # latent = misc.print_module_summary(arcface, [id]) + # img = misc.print_module_summary(gen, [attr, latent]) + # misc.print_module_summary(dis, [img, None]) + # del attr + # del id + # del latent + # del img + # torch.cuda.empty_cache() + + + # Distribute across GPUs. + if rank == 0: + print(f'Distributing across {num_gpus} GPUs...') + for module in [gen, dis, arcface]: + if module is not None and num_gpus > 1: + for param in misc.params_and_buffers(module): + torch.distributed.broadcast(param, src=0) + + # Setup training phases. + if rank == 0: + print('Setting up training phases...') + #===============build losses===================# + # TODO replace below lines to build your losses + # MSE_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + cos_loss = torch.nn.CosineSimilarity() + + g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank) + + # Initialize logs. 
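+ # Both logging backends are funneled into a single duck-typed `logger`
+ # handle: a tensorboard SummaryWriter is driven through
+ # logger.add_scalar(tag, value, step), while the wandb module is driven
+ # through logger.log({name: value}, step=step) in the logging branch of the
+ # training loop below.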
+ if rank == 0: + print('Initializing logs...') + #==============build tensorboard=================# + if config["logger"] == "tensorboard": + import torch.utils.tensorboard as tensorboard + tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"]) + logger = tensorboard_writer + + elif config["logger"] == "wandb": + import wandb + wandb.init(project="Simswap_HQ", entity="xhchen", notes="512", + tags=[config["tag"]], name=version) + + wandb.config = { + "total_step": config["total_step"], + "batch_size": config["batch_size"] + } + logger = wandb + + + random.seed(random_seed) + randindex = [i for i in range(batch_gpu)] + + # set the start point for training loop + if config["phase"] == "finetune": + start = config["ckpt"] + else: + start = 0 + if rank == 0: + import datetime + start_time = time.time() + + # Print the total number of training steps + print("Total step = %d"%total_step) + + print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + + from utilities.logo_class import logo_class + logo_class.print_start_training() + + for step in range(start, total_step): + gen.train() + dis.train() + + for interval in range(2): + + src_image1, src_image2, mask_label = dataloader.next() + + if step%2 == 0: + img_id = src_image2 + else: + random.shuffle(randindex) + img_id = src_image2[randindex] + mask_label = mask_label[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + + if interval == 0: + # reuse the latent_id computed above; the discriminator is updated first + requires_grad(dis, True) + requires_grad(gen, False) + + d_regularize = step % d_reg_freq == 0 + if d_regularize: + src_image1.requires_grad_() + + real_logits, _ = dis(src_image1) + with torch.no_grad(): + img_fake, _ = gen(src_image1, latent_id.detach()) + fake_logits, _ = dis(img_fake.detach()) + + loss_D = d_logistic_loss(real_logits, fake_logits) + + if d_regularize: + loss_reg = d_r1_loss(real_logits, src_image1) + loss_D += loss_reg * reg_w * d_reg_freq # lazy R1: applied every d_reg_freq steps, rescaled to keep the effective strength + + d_optimizer.zero_grad(set_to_none=True) + loss_D.backward() + + with torch.autograd.profiler.record_function('discriminator_opt'): + # params = [param for param in dis.parameters() if param.grad is not None] + # if len(params) > 0: + # flat = torch.cat([param.grad.flatten() for param in params]) + # if num_gpus > 1: + # torch.distributed.all_reduce(flat) + # flat /= num_gpus + # misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + # grads = flat.split([param.numel() for param in params]) + # for param, grad in zip(params, grads): + # param.grad = grad.reshape(param.shape) + # average gradients across ranks by hand; the models are not wrapped in DDP + params = [param for param in dis.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + d_optimizer.step() + + + #================================Generator interval======================================# + else: + requires_grad(dis, False) + requires_grad(gen, True) + # model.netD.requires_grad_(True) + img_fake, pred_mask = gen(src_image1, latent_id.detach()) + # G loss + gen_logits, fake_feat = dis(img_fake) + # real_feat = dis.get_feature(src_image1) + loss_Gmain = 
g_nonsaturating_loss(gen_logits) + img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic') + latent_fake = arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id.detach())).mean() + mask_label = F.interpolate(mask_label, size=(128,128), mode='bilinear') + loss_mask = l1_loss(pred_mask, mask_label) + loss_G = loss_Gmain + loss_G_ID * id_w + loss_mask * mask_w + if step%2 == 0: + #G_Rec + real_feat = dis.get_feature(src_image1) + rec_fm = l1_loss(fake_feat, real_feat) + loss_G_Rec = l1_loss(img_fake, src_image1) + # lpips_loss = loss_fn_vgg(img_fake, src_image1).mean() + loss_G += (loss_G_Rec * rec_w + rec_fm_w * rec_fm) #+ rec_fm * rec_fm_w + else: + source1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic') + latent_source1 = arcface(source1_down) + latent_source1 = F.normalize(latent_source1, p=2, dim=1) + cycle_src,_ = gen(img_fake, latent_source1) + cycle_loss = l1_loss(src_image1,cycle_src) + # cycle_feat = dis.get_feature(cycle_src) + # cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) + l1_loss(real_feat["2"],cycle_feat["2"]) + loss_G += cycle_loss * cycle_w #+ cycle_fm * cycle_fm_w + + + g_optimizer.zero_grad(set_to_none=True) + loss_G.backward() + with torch.autograd.profiler.record_function('generator_opt'): + params = [param for param in gen.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + g_optimizer.step() + + # Print out log info + if rank == 0 and (step + 1) % log_freq == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + # epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \ + # rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + # format(version, elapsed, step, total_step, \ + # loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \ + # rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_loss: {:.4f}, rec_fm: {:.4f}, loss_mask: {:.4f}, D_loss: {:.4f}, D_R1: {:.4f}". 
\ + format(version, elapsed, step, total_step, \ + loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_loss.item(), rec_fm.item(), loss_mask.item(), loss_D.item(), loss_reg.item()) + print(epochinformation) + reporter.writeInfo(epochinformation) + + if config["logger"] == "tensorboard": + logger.add_scalar('G/G_loss', loss_G.item(), step) + logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step) + logger.add_scalar('G/cycle_loss', cycle_loss.item(), step) + # logger.add_scalar('G/cycle_fm', cycle_fm.item(), step) + # logger.add_scalar('G/rec_fm', rec_fm.item(), step) + logger.add_scalar('G/rec_fm', rec_fm.item(), step) + logger.add_scalar('G/loss_mask', loss_mask.item(), step) + logger.add_scalar('G/G_ID', loss_G_ID.item(), step) + logger.add_scalar('D/D_loss', loss_D.item(), step) + logger.add_scalar('D/D_reg', loss_reg.item(), step) + elif config["logger"] == "wandb": + logger.log({"G_Loss": loss_G.item()}, step = step) + logger.log({"G_Rec": loss_G_Rec.item()}, step = step) + logger.log({"cycle_loss": cycle_loss.item()}, step = step) + # logger.log({"cycle_fm": cycle_fm.item()}, step = step) + # logger.log({"rec_fm": rec_fm.item()}, step = step) + logger.log({"rec_fm": rec_fm.item()}, step = step) + logger.log({"loss_mask": loss_mask.item()}, step = step) + logger.log({"G_ID": loss_G_ID.item()}, step = step) + logger.log({"D_loss": loss_D.item()}, step = step) + logger.log({"D_reg": loss_reg.item()}, step = step) + torch.cuda.empty_cache() + + if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0): + gen.eval() + with torch.no_grad(): + imgs = [] + zero_img = (torch.zeros_like(src_image1[0,...])) + imgs.append(zero_img.cpu().numpy()) + save_img = ((src_image1.cpu())* img_std + img_mean).numpy() + for r in range(batch_gpu): + imgs.append(save_img[r,...]) + arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic') + id_vector_src1 = arcface(arcface_112) + id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1) + + for i in range(batch_gpu): + + imgs.append(save_img[i,...]) + image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1) + img_fake,pred_mask = gen(image_infer, id_vector_src1) + + img_fake = img_fake.cpu() * img_std + img_fake = img_fake + img_mean + img_fake = img_fake.numpy() + pred_mask = pred_mask.cpu().numpy() * 255 + for j in range(batch_gpu): + imgs.append(img_fake[j,...]) + print("Save test data") + imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1) + plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg')) + torch.cuda.empty_cache() + + + + #===============adjust learning rate============# + # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: + # print("Learning rate decay") + # for p in self.optimizer.param_groups: + # p['lr'] *= self.config["lr_decay"] + # print("Current learning rate is %f"%p['lr']) + + #===============save checkpoints================# + if rank == 0 and (step+1) % model_freq==0: + + torch.save(gen.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + config["checkpoint_names"]["generator_name"]))) + torch.save(dis.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + config["checkpoint_names"]["discriminator_name"]))) + + torch.save(g_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + config["checkpoint_names"]["generator_name"]))) + + torch.save(d_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + config["checkpoint_names"]["discriminator_name"]))) + 
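# Checkpoint layout produced above: network weights are written as
+ # step{N}_{name}.pth and optimizer state as step{N}_optim_{name}, the same
+ # patterns the finetune branch of setup_optimizers reads back.
+ 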
print("Save step %d model checkpoint!"%(step+1)) + torch.cuda.empty_cache() + print("Rank %d process done!"%rank) + torch.distributed.barrier() \ No newline at end of file diff --git a/train_scripts/trainer_multi_gpu.py b/train_scripts/trainer_multi_gpu.py index 2763ac4..73b5607 100644 --- a/train_scripts/trainer_multi_gpu.py +++ b/train_scripts/trainer_multi_gpu.py @@ -5,7 +5,7 @@ # Created Date: Sunday January 9th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 17th March 2022 1:01:52 am +# Last Modified: Saturday, 26th March 2022 4:58:52 pm # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -433,9 +433,6 @@ def train_loop( if rank == 0 and (step + 1) % log_freq == 0: elapsed = time.time() - start_time elapsed = str(datetime.timedelta(seconds=elapsed)) - # print("ready to report losses") - # ID_Total= loss_G_ID - # torch.distributed.all_reduce(ID_Total) epochinformation="[{}], Elapsed [{}], Step [{}/{}], \ G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \ diff --git a/train_scripts/trainer_multi_gpu_cycle.py b/train_scripts/trainer_multi_gpu_cycle.py index 3997816..9abccb7 100644 --- a/train_scripts/trainer_multi_gpu_cycle.py +++ b/train_scripts/trainer_multi_gpu_cycle.py @@ -5,7 +5,7 @@ # Created Date: Sunday January 9th 2022 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Thursday, 17th March 2022 1:01:52 am +# Last Modified: Sunday, 27th March 2022 12:58:54 am # Modified By: Chen Xuanhong # Copyright (c) 2022 Shanghai Jiao Tong University ############################################################# @@ -397,8 +397,9 @@ def train_loop( # model.netD.requires_grad_(True) img_fake = gen(src_image1, latent_id) # G loss - gen_logits,feat = dis(img_fake, None) - real_feat = dis.get_feature(src_image1) + # gen_logits,feat = dis(img_fake, None) + gen_logits,_ = dis(img_fake, None) + # real_feat = dis.get_feature(src_image1) loss_Gmain = (-gen_logits).mean() img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic') latent_fake = arcface(img_fake_down) @@ -407,18 +408,18 @@ def train_loop( loss_G = loss_Gmain + loss_G_ID * id_w if step%2 == 0: #G_Rec - rec_fm = l1_loss(feat["3"],real_feat["3"]) + # rec_fm = l1_loss(feat["3"],real_feat["3"]) + l1_loss(feat["2"],real_feat["2"]) loss_G_Rec = l1_loss(img_fake, src_image1) - loss_G += loss_G_Rec * rec_w + rec_fm * rec_fm_w + loss_G += loss_G_Rec * rec_w #+ rec_fm * rec_fm_w else: source1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic') latent_source1 = arcface(source1_down) latent_source1 = F.normalize(latent_source1, p=2, dim=1) cycle_src = gen(img_fake, latent_source1) cycle_loss = l1_loss(src_image1,cycle_src) - cycle_feat = dis.get_feature(cycle_src) - cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) - loss_G += cycle_loss * cycle_w + cycle_fm * cycle_fm_w + # cycle_feat = dis.get_feature(cycle_src) + # cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) + l1_loss(real_feat["2"],cycle_feat["2"]) + loss_G += cycle_loss * cycle_w #+ cycle_fm * cycle_fm_w g_optimizer.zero_grad(set_to_none=True) @@ -448,12 +449,14 @@ def train_loop( # ID_Total= loss_G_ID # torch.distributed.all_reduce(ID_Total) - epochinformation="[{}], Elapsed [{}], Step [{}/{}], \ - G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \ - rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". 
\ + # epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \ + # rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + # format(version, elapsed, step, total_step, \ + # loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \ + # rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ format(version, elapsed, step, total_step, \ - loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \ - rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) print(epochinformation) reporter.writeInfo(epochinformation) @@ -461,8 +464,8 @@ def train_loop( logger.add_scalar('G/G_loss', loss_G.item(), step) logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step) logger.add_scalar('G/cycle_loss', cycle_loss.item(), step) - logger.add_scalar('G/cycle_fm', cycle_fm.item(), step) - logger.add_scalar('G/rec_fm', rec_fm.item(), step) + # logger.add_scalar('G/cycle_fm', cycle_fm.item(), step) + # logger.add_scalar('G/rec_fm', rec_fm.item(), step) logger.add_scalar('G/G_ID', loss_G_ID.item(), step) logger.add_scalar('D/D_loss', loss_D.item(), step) logger.add_scalar('D/D_fake', loss_Dgen.item(), step) @@ -471,8 +474,8 @@ def train_loop( logger.log({"G_Loss": loss_G.item()}, step = step) logger.log({"G_Rec": loss_G_Rec.item()}, step = step) logger.log({"cycle_loss": cycle_loss.item()}, step = step) - logger.log({"cycle_fm": cycle_fm.item()}, step = step) - logger.log({"rec_fm": rec_fm.item()}, step = step) + # logger.log({"cycle_fm": cycle_fm.item()}, step = step) + # logger.log({"rec_fm": rec_fm.item()}, step = step) logger.log({"G_ID": loss_G_ID.item()}, step = step) logger.log({"D_loss": loss_D.item()}, step = step) logger.log({"D_fake": loss_Dgen.item()}, step = step) diff --git a/train_scripts/trainer_multi_gpu_cycle_nonstatue_dis.py b/train_scripts/trainer_multi_gpu_cycle_nonstatue_dis.py new file mode 100644 index 0000000..8c80e15 --- /dev/null +++ b/train_scripts/trainer_multi_gpu_cycle_nonstatue_dis.py @@ -0,0 +1,586 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: trainer_multi_gpu_cycle_nonstatue_dis.py +# Created Date: Sunday January 9th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Tuesday, 29th March 2022 9:16:41 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import os +import time +import random +import shutil +import tempfile + +import lpips + +import numpy as np + +import torch +import torch.nn.functional as F + +from torch_utils import misc +from torch_utils import training_stats +from torch_utils.ops import conv2d_gradfix +from torch_utils.ops import grid_sample_gradfix + +from arcface_torch.backbones.iresnet import iresnet100 + +from utilities.plot import plot_batch +from losses.cos import cosin_metric +from train_scripts.trainer_multigpu_base import TrainerBase + + +class Trainer(TrainerBase): + + def __init__(self, + config, + reporter): + super(Trainer, self).__init__(config, reporter) + + import inspect + print("Current 
training script -----------> %s"%inspect.getfile(inspect.currentframe())) + + def train(self): + # Launch processes. + num_gpus = len(self.config["gpus"]) + print('Launching processes...') + torch.multiprocessing.set_start_method('spawn') + with tempfile.TemporaryDirectory() as temp_dir: + torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus) + +# TODO modify this function to build your models +def init_framework(config, reporter, device, rank): + ''' + This function is designed to define the framework, + and print the framework information into the log file + ''' + #===============build models================# + print("build models...") + # TODO [import models here] + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + model_config = config["model_configs"] + + if config["phase"] == "train": + gscript_name = "components." + model_config["g_model"]["script"] + file1 = os.path.join("components", model_config["g_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["g_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + dscript_name = "components." + model_config["d_model"]["script"] + file1 = os.path.join("components", model_config["d_model"]["script"]+".py") + tgtfile1 = os.path.join(config["project_scripts"], model_config["d_model"]["script"]+".py") + shutil.copyfile(file1,tgtfile1) + + elif config["phase"] == "finetune": + gscript_name = config["com_base"] + model_config["g_model"]["script"] + dscript_name = config["com_base"] + model_config["d_model"]["script"] + + class_name = model_config["g_model"]["class_name"] + package = __import__(gscript_name, fromlist=True) + gen_class = getattr(package, class_name) + gen = gen_class(**model_config["g_model"]["module_params"]) + + # print and record the model structure + reporter.writeInfo("Generator structure:") + reporter.writeModel(gen.__str__()) + + class_name = model_config["d_model"]["class_name"] + package = __import__(dscript_name, fromlist=True) + dis_class = getattr(package, class_name) + dis = dis_class(**model_config["d_model"]["module_params"]) + + + # print and record the model structure + reporter.writeInfo("Discriminator structure:") + reporter.writeModel(dis.__str__()) + + # arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu")) + # arcface = arcface1['model'].module + + # arcface = iresnet100(pretrained=False, fp16=False) + # arcface.load_state_dict(torch.load(config["arcface_ckpt"], map_location='cpu')) + # arcface.eval() + arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu")) + arcface = arcface1['model'].module + + # train in GPU + + # if in finetune phase, load the pretrained checkpoint + if config["phase"] == "finetune": + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["generator_name"])) + gen.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + + model_path = os.path.join(config["project_checkpoints"], + "step%d_%s.pth"%(config["ckpt"], + config["checkpoint_names"]["discriminator_name"])) + dis.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) + + print('loaded trained backbone model step {}...!'.format(config["ckpt"])) + + gen = gen.to(device) + dis = dis.to(device) + arcface = arcface.to(device) + arcface.requires_grad_(False) + arcface.eval() + + + + return gen, dis, arcface + +# TODO modify this function to configure the optimizer of your 
pipeline +def setup_optimizers(config, reporter, gen, dis, rank): + + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + g_train_opt = config['g_optim_config'] + d_train_opt = config['d_optim_config'] + + g_optim_params = [] + d_optim_params = [] + for k, v in gen.named_parameters(): + if v.requires_grad: + g_optim_params.append(v) + else: + reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + for k, v in dis.named_parameters(): + if v.requires_grad: + d_optim_params.append(v) + else: + reporter.writeInfo(f'Params {k} will not be optimized.') + print(f'Params {k} will not be optimized.') + + optim_type = config['optim_type'] + + if optim_type == 'Adam': + g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt) + d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt) + else: + raise NotImplementedError( + f'optimizer {optim_type} is not supported yet.') + # self.optimizers.append(self.optimizer_g) + if config["phase"] == "finetune": + # optimizer state is saved as step{N}_optim_{name}; mirror that pattern here + opt_path = os.path.join(config["project_checkpoints"], + "step%d_optim_%s"%(config["ckpt"], + config["checkpoint_names"]["generator_name"])) + g_optimizer.load_state_dict(torch.load(opt_path, map_location=torch.device("cpu"))) + + opt_path = os.path.join(config["project_checkpoints"], + "step%d_optim_%s"%(config["ckpt"], + config["checkpoint_names"]["discriminator_name"])) + d_optimizer.load_state_dict(torch.load(opt_path, map_location=torch.device("cpu"))) + + print('loaded trained optimizer step {}...!'.format(config["ckpt"])) + return g_optimizer, d_optimizer + +def d_logistic_loss(real_pred, fake_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + +def d_r1_loss(d_out, x_in): + # zero-centered gradient penalty for real images + batch_size = x_in.size(0) + grad_dout = torch.autograd.grad( + outputs=d_out.sum(), inputs=x_in, + create_graph=True, retain_graph=True, only_inputs=True + )[0] + grad_dout2 = grad_dout.pow(2) + assert(grad_dout2.size() == x_in.size()) + reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) + return reg + + +def g_nonsaturating_loss(fake_pred): + loss = F.softplus(-fake_pred).mean() + + return loss + +def requires_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +# def r1_reg(d_out, x_in): +# # zero-centered gradient penalty for real images +# batch_size = x_in.size(0) +# grad_dout = torch.autograd.grad( +# outputs=d_out.sum(), inputs=x_in, +# create_graph=True, retain_graph=True, only_inputs=True +# )[0] +# grad_dout2 = grad_dout.pow(2) +# assert(grad_dout2.size() == x_in.size()) +# reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) +# return reg + +def train_loop( + rank, + config, + reporter, + temp_dir + ): + + version = config["version"] + + ckpt_dir = config["project_checkpoints"] + sample_dir = config["project_samples"] + + log_freq = config["log_step"] + model_freq = config["model_save_step"] + sample_freq = config["sample_step"] + total_step = config["total_step"] + random_seed = config["dataset_params"]["random_seed"] + d_reg_freq = config["d_reg_freq"] + + id_w = config["id_weight"] + rec_w = config["reconstruct_weight"] + rec_fm_w = config["rec_feature_match_weight"] + cycle_fm_w = config["cycle_feature_match_weight"] + cycle_w = config["cycle_weight"] + reg_w = config["reg_weight"] + lpips_w = config["lpips_weight"] + num_gpus = len(config["gpus"]) + batch_gpu = config["batch_size"] // num_gpus + + init_file = os.path.abspath(os.path.join(temp_dir, 
'.torch_distributed_init')) + if os.name == 'nt': + init_method = 'file:///' + init_file.replace('\\', '/') + torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus) + else: + init_method = f'file://{init_file}' + torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus) + + # Init torch_utils. + sync_device = torch.device('cuda', rank) + training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) + + + + if rank == 0: + img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1) + img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1) + + + # Initialize. + device = torch.device('cuda', rank) + np.random.seed(random_seed * num_gpus + rank) + torch.manual_seed(random_seed * num_gpus + rank) + torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. + torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. + conv2d_gradfix.enabled = True # Improves training speed. + grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. + + # Create dataloader. + if rank == 0: + print('Loading training set...') + + dataset = config["dataset_paths"][config["dataset_name"]] + #================================================# + print("Prepare the train dataloader...") + dlModulename = config["dataloader"] + package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True) + dataloaderClass = getattr(package, 'GetLoader') + dataloader_class = dataloaderClass + dataloader = dataloader_class(dataset, + rank, + num_gpus, + batch_gpu, + **config["dataset_params"]) + + # Construct networks. + if rank == 0: + print('Constructing networks...') + gen, dis, arcface = init_framework(config, reporter, device, rank) + + # Check for existing checkpoint + + # Print network summary tables. + # if rank == 0: + # attr = torch.empty([batch_gpu, 3, 512, 512], device=device) + # id = torch.empty([batch_gpu, 3, 112, 112], device=device) + # latent = misc.print_module_summary(arcface, [id]) + # img = misc.print_module_summary(gen, [attr, latent]) + # misc.print_module_summary(dis, [img, None]) + # del attr + # del id + # del latent + # del img + # torch.cuda.empty_cache() + + + # Distribute across GPUs. + if rank == 0: + print(f'Distributing across {num_gpus} GPUs...') + for module in [gen, dis, arcface]: + if module is not None and num_gpus > 1: + for param in misc.params_and_buffers(module): + torch.distributed.broadcast(param, src=0) + + # Setup training phases. + if rank == 0: + print('Setting up training phases...') + #===============build losses===================# + # TODO replace below lines to build your losses + # MSE_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + cos_loss = torch.nn.CosineSimilarity() + loss_fn_vgg = lpips.LPIPS(net='vgg').to(device) # LPIPS weights must sit on the same device as the images it scores + + g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank) + + # Initialize logs. 
+ if rank == 0: + print('Initializing logs...') + #==============build tensorboard=================# + if config["logger"] == "tensorboard": + import torch.utils.tensorboard as tensorboard + tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"]) + logger = tensorboard_writer + + elif config["logger"] == "wandb": + import wandb + wandb.init(project="Simswap_HQ", entity="xhchen", notes="512", + tags=[config["tag"]], name=version) + + wandb.config = { + "total_step": config["total_step"], + "batch_size": config["batch_size"] + } + logger = wandb + + + random.seed(random_seed) + randindex = [i for i in range(batch_gpu)] + + # set the start point for training loop + if config["phase"] == "finetune": + start = config["ckpt"] + else: + start = 0 + if rank == 0: + import datetime + start_time = time.time() + + # Print the total number of training steps + print("Total step = %d"%total_step) + + print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))) + + from utilities.logo_class import logo_class + logo_class.print_start_training() + + for step in range(start, total_step): + gen.train() + dis.train() + + for interval in range(2): + + src_image1, src_image2 = dataloader.next() + + if step%2 == 0: + img_id = src_image2 + else: + random.shuffle(randindex) + img_id = src_image2[randindex] + + img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic') + latent_id = arcface(img_id_112) + latent_id = F.normalize(latent_id, p=2, dim=1) + + if interval == 0: + # reuse the latent_id computed above; the discriminator is updated first + requires_grad(dis, True) + requires_grad(gen, False) + + d_regularize = step % d_reg_freq == 0 + if d_regularize: + src_image1.requires_grad_() + + real_logits = dis(src_image1) + with torch.no_grad(): + img_fake = gen(src_image1, latent_id.detach()) + fake_logits = dis(img_fake.detach()) + + loss_D = d_logistic_loss(real_logits, fake_logits) + + if d_regularize: + loss_reg = d_r1_loss(real_logits, src_image1) + loss_D += loss_reg * reg_w * d_reg_freq # lazy R1: applied every d_reg_freq steps, rescaled to keep the effective strength + + d_optimizer.zero_grad(set_to_none=True) + loss_D.backward() + + with torch.autograd.profiler.record_function('discriminator_opt'): + # params = [param for param in dis.parameters() if param.grad is not None] + # if len(params) > 0: + # flat = torch.cat([param.grad.flatten() for param in params]) + # if num_gpus > 1: + # torch.distributed.all_reduce(flat) + # flat /= num_gpus + # misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + # grads = flat.split([param.numel() for param in params]) + # for param, grad in zip(params, grads): + # param.grad = grad.reshape(param.shape) + # average gradients across ranks by hand; the models are not wrapped in DDP + params = [param for param in dis.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + d_optimizer.step() + + + #================================Generator interval======================================# + else: + requires_grad(dis, False) + requires_grad(gen, True) + # model.netD.requires_grad_(True) + img_fake = gen(src_image1, latent_id.detach()) + # G loss + gen_logits = dis(img_fake) + # real_feat = dis.get_feature(src_image1) + loss_Gmain = g_nonsaturating_loss(gen_logits) + img_fake_down = F.interpolate(img_fake, 
size=(112,112), mode='bicubic') + latent_fake = arcface(img_fake_down) + latent_fake = F.normalize(latent_fake, p=2, dim=1) + loss_G_ID = (1 - cos_loss(latent_fake, latent_id.detach())).mean() + loss_G = loss_Gmain + loss_G_ID * id_w + if step%2 == 0: + #G_Rec + # rec_fm = l1_loss(feat["3"],real_feat["3"]) + l1_loss(feat["2"],real_feat["2"]) + loss_G_Rec = l1_loss(img_fake, src_image1) + lpips_loss = loss_fn_vgg(img_fake, src_image1).mean() + loss_G += loss_G_Rec * rec_w + lpips_w * lpips_loss #+ rec_fm * rec_fm_w + else: + source1_down = F.interpolate(src_image1, size=(112,112), mode='bicubic') + latent_source1 = arcface(source1_down) + latent_source1 = F.normalize(latent_source1, p=2, dim=1) + cycle_src = gen(img_fake, latent_source1) + cycle_loss = l1_loss(src_image1,cycle_src) + # cycle_feat = dis.get_feature(cycle_src) + # cycle_fm = l1_loss(real_feat["3"],cycle_feat["3"]) + l1_loss(real_feat["2"],cycle_feat["2"]) + loss_G += cycle_loss * cycle_w #+ cycle_fm * cycle_fm_w + + + g_optimizer.zero_grad(set_to_none=True) + loss_G.backward() + with torch.autograd.profiler.record_function('generator_opt'): + params = [param for param in gen.parameters() if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + torch.distributed.all_reduce(flat) + flat /= num_gpus + misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + g_optimizer.step() + + # Print out log info + if rank == 0 and (step + 1) % log_freq == 0: + elapsed = time.time() - start_time + elapsed = str(datetime.timedelta(seconds=elapsed)) + # epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_fm: {:.4f}, \ + # rec_fm: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \ + # format(version, elapsed, step, total_step, \ + # loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_fm.item(), \ + # rec_fm.item(), cycle_loss.item(), loss_D.item(), loss_Dgen.item(), loss_Dreal.item()) + epochinformation="[{}], Elapsed [{}], Step [{}/{}], G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, cycle_loss: {:.4f}, D_loss: {:.4f}, D_R1: {:.4f}". 
\ + format(version, elapsed, step, total_step, \ + loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), cycle_loss.item(), loss_D.item(), loss_reg.item()) + print(epochinformation) + reporter.writeInfo(epochinformation) + + if config["logger"] == "tensorboard": + logger.add_scalar('G/G_loss', loss_G.item(), step) + logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step) + logger.add_scalar('G/cycle_loss', cycle_loss.item(), step) + # logger.add_scalar('G/cycle_fm', cycle_fm.item(), step) + # logger.add_scalar('G/rec_fm', rec_fm.item(), step) + logger.add_scalar('G/G_ID', loss_G_ID.item(), step) + logger.add_scalar('D/D_loss', loss_D.item(), step) + logger.add_scalar('D/D_reg', loss_reg.item(), step) + elif config["logger"] == "wandb": + logger.log({"G_Loss": loss_G.item()}, step = step) + logger.log({"G_Rec": loss_G_Rec.item()}, step = step) + logger.log({"cycle_loss": cycle_loss.item()}, step = step) + # logger.log({"cycle_fm": cycle_fm.item()}, step = step) + # logger.log({"rec_fm": rec_fm.item()}, step = step) + logger.log({"G_ID": loss_G_ID.item()}, step = step) + logger.log({"D_loss": loss_D.item()}, step = step) + logger.log({"D_reg": loss_reg.item()}, step = step) + torch.cuda.empty_cache() + + if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0): + gen.eval() + with torch.no_grad(): + imgs = [] + zero_img = (torch.zeros_like(src_image1[0,...])) + imgs.append(zero_img.cpu().numpy()) + save_img = ((src_image1.cpu())* img_std + img_mean).numpy() + for r in range(batch_gpu): + imgs.append(save_img[r,...]) + arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic') + id_vector_src1 = arcface(arcface_112) + id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1) + + for i in range(batch_gpu): + + imgs.append(save_img[i,...]) + image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1) + img_fake = gen(image_infer, id_vector_src1).cpu() + + img_fake = img_fake * img_std + img_fake = img_fake + img_mean + img_fake = img_fake.numpy() + for j in range(batch_gpu): + imgs.append(img_fake[j,...]) + print("Save test data") + imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1) + plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg')) + torch.cuda.empty_cache() + + + + #===============adjust learning rate============# + # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]: + # print("Learning rate decay") + # for p in self.optimizer.param_groups: + # p['lr'] *= self.config["lr_decay"] + # print("Current learning rate is %f"%p['lr']) + + #===============save checkpoints================# + if rank == 0 and (step+1) % model_freq==0: + + torch.save(gen.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + config["checkpoint_names"]["generator_name"]))) + torch.save(dis.state_dict(), + os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1, + config["checkpoint_names"]["discriminator_name"]))) + + torch.save(g_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + config["checkpoint_names"]["generator_name"]))) + + torch.save(d_optimizer.state_dict(), + os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1, + config["checkpoint_names"]["discriminator_name"]))) + print("Save step %d model checkpoint!"%(step+1)) + torch.cuda.empty_cache() + print("Rank %d process done!"%rank) + torch.distributed.barrier() \ No newline at end of file diff --git a/train_yamls/train_1maskhead.yaml b/train_yamls/train_1maskhead.yaml new file mode 100644 index 0000000..2c6d1cb --- /dev/null 
+++ b/train_yamls/train_1maskhead.yaml @@ -0,0 +1,75 @@ +# Related scripts +train_script_name: mgpu_maskloss + +# models' scripts +model_configs: + g_model: + script: Generator_starganv2 + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 16 + +# Dataset +dataloader: VGGFace2HQ_multigpu_w_mask +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 8 +id_weight: 30.0 +reconstruct_weight: 10.0 +rec_feature_match_weight: 3.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 +mask_weight: 100.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_2maskhead.yaml b/train_yamls/train_2maskhead.yaml new file mode 100644 index 0000000..e0658d1 --- /dev/null +++ b/train_yamls/train_2maskhead.yaml @@ -0,0 +1,75 @@ +# Related scripts +train_script_name: mgpu_2maskloss + +# models' scripts +model_configs: + g_model: + script: Generator_2mask + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 8 + +# Dataset +dataloader: VGGFace2HQ_multigpu_w_mask +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 30.0 +reconstruct_weight: 10.0 +rec_feature_match_weight: 3.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 +mask_weight: 100.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_2maskhead2.yaml b/train_yamls/train_2maskhead2.yaml new file mode 100644 index 0000000..2cd61c4 --- /dev/null +++ b/train_yamls/train_2maskhead2.yaml @@ -0,0 +1,75 @@ +# Related scripts +train_script_name: mgpu_2maskloss + +# models' scripts +model_configs: + g_model: + script: Generator_2mask2 + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + 
aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 16 + +# Dataset +dataloader: VGGFace2HQ_multigpu_w_mask +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 35.0 +reconstruct_weight: 10.0 +rec_feature_match_weight: 3.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 +mask_weight: 100.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_2maskhead_256.yaml b/train_yamls/train_2maskhead_256.yaml new file mode 100644 index 0000000..9a02489 --- /dev/null +++ b/train_yamls/train_2maskhead_256.yaml @@ -0,0 +1,75 @@ +# Related scripts +train_script_name: mgpu_2maskloss_256 + +# models' scripts +model_configs: + g_model: + script: Generator_256 + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 256 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 32 + +# Dataset +dataloader: VGGFace2HQ_multigpu_w_mask +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 35.0 +reconstruct_weight: 10.0 +rec_feature_match_weight: 3.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 +mask_weight: 100.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_2maskhead_DWConv.yaml b/train_yamls/train_2maskhead_DWConv.yaml new file mode 100644 index 0000000..47136b7 --- /dev/null +++ b/train_yamls/train_2maskhead_DWConv.yaml @@ -0,0 +1,75 @@ +# Related scripts +train_script_name: mgpu_2maskloss + +# models' scripts +model_configs: + g_model: + script: Generator_2mask_DWConv + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth 
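+# The .tar checkpoint below is assumed (from init_framework in the trainers,
+# which unwrap it via torch.load(...)['model'].module) to be a pickled dict
+# whose 'model' entry holds a DataParallel-wrapped ArcFace backbone.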
+arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 16 + +# Dataset +dataloader: VGGFace2HQ_multigpu_w_mask +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 35.0 +reconstruct_weight: 10.0 +rec_feature_match_weight: 3.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 +mask_weight: 100.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_cycleloss_fm_nonstatu.yaml b/train_yamls/train_cycleloss_fm_nonstatu.yaml new file mode 100644 index 0000000..cc6d6f3 --- /dev/null +++ b/train_yamls/train_cycleloss_fm_nonstatu.yaml @@ -0,0 +1,76 @@ +# Related scripts +train_script_name: mgpu_fm + +# models' scripts +model_configs: + g_model: + script: Generator_featout_config + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + lstu_script: LSTU_Config + lstu_class: LSTU + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 16 + +# Dataset +dataloader: VGGFace2HQ_multigpu +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0006 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0006 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 30.0 +reconstruct_weight: 1.0 +rec_feature_match_weight: 1.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 6.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_cycleloss_resskip.yaml b/train_yamls/train_cycleloss_resskip.yaml new file mode 100644 index 0000000..4fc1f31 --- /dev/null +++ b/train_yamls/train_cycleloss_resskip.yaml @@ -0,0 +1,70 @@ +# Related scripts +train_script_name: multi_gpu_cycle + +# models' scripts +model_configs: + g_model: + script: Generator_ResSkip_config + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + + d_model: + script: projected_discriminator + class_name: ProjectedDiscriminator + module_params: + diffaug: False + interp224: False + backbone_kwargs: {} + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 8 + +# Dataset +dataloader: VGGFace2HQ_multigpu +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + 
+eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +id_weight: 20.0 +reconstruct_weight: 5.0 +rec_feature_match_weight: 1.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 5.0 + +# Log +log_step: 400 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_cycleloss_resskip_nonstatu.yaml b/train_yamls/train_cycleloss_resskip_nonstatu.yaml new file mode 100644 index 0000000..2f404dc --- /dev/null +++ b/train_yamls/train_cycleloss_resskip_nonstatu.yaml @@ -0,0 +1,76 @@ +# Related scripts +train_script_name: multi_gpu_cycle_nonstatue_dis + +# models' scripts +model_configs: + g_model: + script: Generator_ResSkip_config1 + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + lstu_script: LSTU_Config + lstu_class: LSTU + + d_model: + script: Nonstau_Discriminator + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 8 + +# Dataset +dataloader: VGGFace2HQ_multigpu +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 20.0 +reconstruct_weight: 1.0 +rec_feature_match_weight: 1.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 5.0 + +# Log +log_step: 400 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_maskhead_fm_nonstatu.yaml b/train_yamls/train_maskhead_fm_nonstatu.yaml new file mode 100644 index 0000000..015661a --- /dev/null +++ b/train_yamls/train_maskhead_fm_nonstatu.yaml @@ -0,0 +1,74 @@ +# Related scripts +train_script_name: mgpu_fm_w_mask + +# models' scripts +model_configs: + g_model: + script: Generator_maskhead_config + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 8 + +# Dataset +dataloader: VGGFace2HQ_multigpu +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0006 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0006 + betas: [ 0, 0.99] + 
eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 30.0 +reconstruct_weight: 1.0 +rec_feature_match_weight: 1.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_maskhead_fm_vggstyle.yaml b/train_yamls/train_maskhead_fm_vggstyle.yaml new file mode 100644 index 0000000..179567c --- /dev/null +++ b/train_yamls/train_maskhead_fm_vggstyle.yaml @@ -0,0 +1,74 @@ +# Related scripts +train_script_name: mgpu_fm_w_mask + +# models' scripts +model_configs: + g_model: + script: Generator_VGGStyle_maskhead_config + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 16 + +# Dataset +dataloader: VGGFace2HQ_multigpu +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0006 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0006 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 30.0 +reconstruct_weight: 1.0 +rec_feature_match_weight: 1.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_maskhead_hififace.yaml b/train_yamls/train_maskhead_hififace.yaml new file mode 100644 index 0000000..16560b8 --- /dev/null +++ b/train_yamls/train_maskhead_hififace.yaml @@ -0,0 +1,75 @@ +# Related scripts +train_script_name: mgpu_maskloss + +# models' scripts +model_configs: + g_model: + script: Generator_maskhead_config1 + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 8 + +# Dataset +dataloader: VGGFace2HQ_multigpu_w_mask +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 25.0 +reconstruct_weight: 3.0 +rec_feature_match_weight: 1.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 +mask_weight: 30.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 
+checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_maskhead_hififace1.yaml b/train_yamls/train_maskhead_hififace1.yaml new file mode 100644 index 0000000..1c0ab9a --- /dev/null +++ b/train_yamls/train_maskhead_hififace1.yaml @@ -0,0 +1,75 @@ +# Related scripts +train_script_name: mgpu_maskloss + +# models' scripts +model_configs: + g_model: + script: Generator_maskhead_config2 + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 16 + +# Dataset +dataloader: VGGFace2HQ_multigpu_w_mask +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0004 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 25.0 +reconstruct_weight: 3.0 +rec_feature_match_weight: 1.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 +mask_weight: 30.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/train_yamls/train_maskloss.yaml b/train_yamls/train_maskloss.yaml new file mode 100644 index 0000000..022197b --- /dev/null +++ b/train_yamls/train_maskloss.yaml @@ -0,0 +1,75 @@ +# Related scripts +train_script_name: mgpu_maskloss + +# models' scripts +model_configs: + g_model: + script: Generator_maskhead_config + class_name: Generator + module_params: + id_dim: 512 + g_kernel_size: 3 + in_channel: 64 + res_num: 3 + up_mode: bilinear + aggregator: "conv" + res_mode: "conv" + norm: "bn" + + d_model: + script: Nonstau_Discriminator_FM + class_name: Discriminator + module_params: + img_size: 512 + max_conv_dim: 512 + norm: "bn" + +# arcface_ckpt: arcface_torch/checkpoints/glint360k_cosface_r100_fp16_backbone.pth +arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar + +# Training information +batch_size: 16 + +# Dataset +dataloader: VGGFace2HQ_multigpu_w_mask +dataset_name: vggface2_hq +dataset_params: + random_seed: 1234 + dataloader_workers: 6 + +eval_dataloader: DIV2K_hdf5 +eval_dataset_name: DF2K_H5_Eval +eval_batch_size: 2 + +# Dataset + +# Optimizer +optim_type: Adam +g_optim_config: + lr: 0.0006 + betas: [ 0, 0.99] + eps: !!float 1e-8 + +d_optim_config: + lr: 0.0006 + betas: [ 0, 0.99] + eps: !!float 1e-8 + + +d_reg_freq: 16 +id_weight: 30.0 +reconstruct_weight: 1.0 +rec_feature_match_weight: 1.0 +cycle_feature_match_weight: 1.0 +cycle_weight: 1.0 +reg_weight: 8.0 +mask_weight: 10.0 + +# Log +log_step: 500 +model_save_step: 10000 +total_step: 1000000 +sample_step: 1000 +checkpoint_names: + generator_name: Generator + discriminator_name: Discriminator \ No newline at end of file diff --git a/utilities/reverse2original.py b/utilities/reverse2original.py index 8b7fbc9..acd9c80 100644 --- a/utilities/reverse2original.py +++ 
b/utilities/reverse2original.py @@ -86,7 +86,7 @@ def reverse2wholeimage(b_align_crop_tenor_list,swaped_imgs, mats, crop_size, ori # print(mats) # print(len(b_align_crop_tenor_list)) for swaped_img, mat ,source_img in zip(swaped_imgs, mats,b_align_crop_tenor_list): - swaped_img = swaped_img.cpu().detach().numpy().transpose((1, 2, 0)) + swaped_img = swaped_img.cpu().numpy().transpose((1, 2, 0)) img_white = np.full((crop_size,crop_size), 255, dtype=float) # inverse the Affine transformation matrix diff --git a/utilities/sshupload.py b/utilities/sshupload.py index acdc996..18f3c91 100644 --- a/utilities/sshupload.py +++ b/utilities/sshupload.py @@ -5,12 +5,18 @@ # Created Date: Tuesday September 24th 2019 # Author: Lcx # Email: chenxuanhongzju@outlook.com -# Last Modified: Friday, 18th February 2022 3:20:14 pm +# Last Modified: Thursday, 14th April 2022 12:33:07 pm # Modified By: Chen Xuanhong # Copyright (c) 2019 Shanghai Jiao Tong University ############################################################# -import paramiko,os +try: + import paramiko +except ImportError: + from pip._internal import main + main(['install', 'paramiko']) + import paramiko +import os from pathlib import Path # SSH transfer class: diff --git a/utilities/utilities.py b/utilities/utilities.py index 503b190..3f942cf 100644 --- a/utilities/utilities.py +++ b/utilities/utilities.py @@ -5,16 +5,50 @@ # Created Date: Monday April 6th 2020 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Tuesday, 12th October 2021 2:18:05 pm +# Last Modified: Thursday, 14th April 2022 11:34:54 am # Modified By: Chen Xuanhong # Copyright (c) 2020 Shanghai Jiao Tong University ############################################################# +import os import cv2 import torch from PIL import Image import numpy as np from torchvision import transforms +from torch.hub import download_url_to_file, get_dir +from urllib.parse import urlparse + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file from http url, downloading models if necessary. + + Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file.
+ """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file # Gram Matrix def Gram(tensor: torch.Tensor): diff --git a/vggface2hq_failed.txt b/vggface2hq_failed.txt index caa9ec3..865e8a4 100644 --- a/vggface2hq_failed.txt +++ b/vggface2hq_failed.txt @@ -1,109 +1,1798 @@ n000002/0054_01.jpg +n000002/0055_01.jpg +n000002/0138_01.jpg +n000002/0150_02.jpg +n000002/0208_01.jpg +n000002/0252_01.jpg +n000002/0273_01.jpg +n000002/0276_01.jpg +n000003/0024_01.jpg +n000003/0098_01.jpg +n000003/0219_01.jpg n000004/0026_01.jpg n000004/0084_01.jpg +n000004/0103_02.jpg +n000004/0118_01.jpg +n000004/0144_02.jpg +n000004/0155_01.jpg +n000004/0180_01.jpg +n000004/0231_01.jpg +n000004/0237_01.jpg +n000004/0239_01.jpg +n000004/0258_01.jpg n000005/0138_01.jpg +n000005/0144_01.jpg +n000005/0287_01.jpg +n000006/0007_01.jpg n000006/0014_01.jpg n000006/0036_02.jpg n000006/0091_01.jpg +n000006/0103_01.jpg +n000006/0281_01.jpg n000006/0300_01.jpg -n000006/0519_01.jpg n000006/0351_01.jpg +n000006/0430_01.jpg +n000006/0519_01.jpg +n000007/0021_01.jpg +n000007/0042_01.jpg +n000007/0045_01.jpg +n000007/0050_02.jpg +n000007/0080_01.jpg +n000007/0086_01.jpg n000007/0106_02.jpg n000007/0115_01.jpg +n000007/0116_03.jpg n000007/0119_01.jpg -n000007/0181_01.jpg -n000007/0174_01.jpg -n000007/0148_02.jpg +n000007/0137_01.jpg n000007/0140_02.jpg +n000007/0148_02.jpg +n000007/0174_01.jpg +n000007/0181_01.jpg +n000007/0182_02.jpg +n000007/0213_02.jpg +n000007/0226_02.jpg +n000007/0229_01.jpg +n000007/0432_01.jpg n000008/0072_01.jpg -n000009/0150_02.jpg -n000009/0096_01.jpg -n000009/0068_01.jpg +n000008/0297_01.jpg +n000010/0068_01.jpg +n000010/0069_01.jpg +n000010/0096_01.jpg +n000010/0150_02.jpg +n000010/0155_02.jpg +n000010/0223_01.jpg +n000011/0112_01.jpg +n000011/0142_02.jpg +n000011/0200_01.jpg +n000011/0217_01.jpg +n000011/0229_02.jpg +n000011/0291_02.jpg +n000012/0173_01.jpg +n000012/0180_01.jpg +n000012/0198_01.jpg +n000012/0282_01.jpg +n000012/0294_01.jpg +n000012/0307_01.jpg +n000012/0338_01.jpg +n000013/0029_06.jpg +n000013/0128_01.jpg +n000013/0132_01.jpg +n000013/0148_01.jpg +n000013/0190_02.jpg +n000013/0225_01.jpg +n000013/0277_01.jpg +n000013/0335_01.jpg +n000013/0337_01.jpg +n000013/0341_02.jpg n000014/0163_01.jpg -n000015/0402_01.jpg +n000015/0029_02.jpg +n000015/0059_01.jpg +n000015/0133_01.jpg +n000015/0243_02.jpg n000015/0392_02.jpg -n000016/0047_03.jpg +n000015/0393_01.jpg +n000015/0402_01.jpg +n000016/0189_01.jpg +n000016/0237_01.jpg n000016/0266_01.jpg +n000016/0385_04.jpg +n000016/0391_01.jpg +n000016/0405_01.jpg +n000016/0477_02.jpg n000016/0500_01.jpg n000016/0503_01.jpg -n000016/0405_01.jpg +n000016/0503_01.jpg n000017/0123_02.jpg +n000017/0124_01.jpg n000017/0163_01.jpg -n000018/0163_01.jpg -n000018/0212_01.jpg -n000018/0216_01.jpg -n000018/0189_01.jpg -n000018/0293_01.jpg -n000018/0280_01.jpg -n000018/0317_01.jpg +n000017/0262_01.jpg n000019/0038_01.jpg -n000019/0061_01.jpg n000019/0055_01.jpg +n000019/0061_01.jpg n000019/0114_01.jpg -n000019/0182_01.jpg n000019/0130_02.jpg +n000019/0149_02.jpg +n000019/0170_01.jpg 
+n000019/0182_01.jpg +n000019/0219_01.jpg +n000019/0221_02.jpg +n000019/0234_02.jpg +n000019/0249_01.jpg n000019/0259_01.jpg +n000019/0273_01.jpg +n000019/0306_01.jpg +n000019/0313_01.jpg +n000019/0333_01.jpg +n000019/0350_02.jpg n000020/0006_01.jpg n000020/0071_01.jpg n000020/0074_02.jpg n000020/0099_02.jpg -n000020/0367_01.jpg n000020/0379_01.jpg +n000020/0400_01.jpg n000021/0120_02.jpg n000021/0221_01.jpg n000022/0051_01.jpg n000022/0071_01.jpg n000022/0146_02.jpg +n000022/0146_02.jpg n000022/0236_01.jpg -n000023/0133_01.jpg +n000023/0008_01.jpg +n000023/0078_01.jpg n000023/0093_01.jpg -n000023/0318_01.jpg +n000023/0133_01.jpg +n000023/0162_01.jpg +n000023/0198_01.jpg +n000023/0207_03.jpg +n000023/0269_02.jpg n000023/0265_01.jpg -n000024/0073_01.jpg +n000023/0280_01.jpg +n000023/0366_01.jpg +n000023/0389_01.jpg n000024/0062_01.jpg -n000024/0409_01.jpg +n000024/0073_01.jpg n000024/0354_04.jpg -n000025/0274_02.jpg +n000024/0409_01.jpg n000025/0100_02.jpg -n000026/0059_01.jpg +n000025/0274_02.jpg +n000026/0038_01.jpg n000026/0041_01.jpg +n000026/0059_01.jpg n000026/0062_01.jpg n000026/0065_01.jpg +n000026/0082_02.jpg +n000026/0103_01.jpg +n000026/0137_01.jpg +n000026/0060_01.jpg n000026/0179_03.jpg -n000026/0273_01.jpg -n000026/0255_01.jpg +n000026/0196_01.jpg n000026/0248_01.jpg -n000026/0182_02.jpg -n000026/0157_02.jpg -n000026/0211_02.jpg n000026/0255_01.jpg -n000026/0442_01.jpg +n000026/0273_01.jpg +n000026/0280_01.jpg +n000027/0023_02.jpg +n000027/0023_05.jpg +n000027/0115_01.jpg +n000027/0157_02.jpg +n000027/0171_01.jpg +n000027/0182_02.jpg +n000027/0211_02.jpg +n000027/0255_01.jpg +n000027/0274_04.jpg +n000027/0318_04.jpg +n000027/0326_01.jpg +n000027/0401_01.jpg +n000027/0402_01.jpg +n000027/0438_01.jpg +n000027/0442_01.jpg +n000027/0493_01.jpg +n000028/0040_04.jpg +n000028/0056_01.jpg n000028/0134_01.jpg n000028/0136_03.jpg -n000028/0168_01.jpg +n000028/0138_01.jpg +n000028/0144_02.jpg +n000028/0156_01.jpg n000028/0162_01.jpg -n000028/0384_01.jpg +n000028/0168_01.jpg +n000028/0205_01.jpg n000028/0220_01.jpg +n000028/0249_01.jpg +n000028/0300_01.jpg +n000028/0324_02.jpg +n000028/0343_01.jpg n000028/0352_01.jpg +n000028/0384_01.jpg +n000028/0392_01.jpg +n000028/0408_02.jpg +n000028/0412_02.jpg n000030/0112_01.jpg -n000030/0195_01.jpg +n000030/0119_01.jpg +n000030/0156_01.jpg n000030/0192_01.jpg +n000030/0195_01.jpg +n000030/0203_01.jpg +n000030/0218_02.jpg n000030/0305_01.jpg n000031/0025_01.jpg +n000031/0080_02.jpg +n000031/0141_01.jpg +n000031/0196_01.jpg n000031/0215_01.jpg n000031/0286_02.jpg n000032/0085_01.jpg +n000032/0100_01.jpg +n000032/0100_02.jpg +n000032/0233_01.jpg n000032/0261_01.jpg -n000032/0428_01.jpg +n000032/0350_01.jpg +n000032/0374_01.jpg n000032/0393_02.jpg +n000032/0428_01.jpg +n000032/0443_01.jpg +n000032/0459_01.jpg +n000032/0465_02.jpg n000033/0031_01.jpg n000033/0032_02.jpg n000033/0034_01.jpg +n000033/0034_02.jpg n000033/0080_01.jpg +n000033/0100_01.jpg +n000033/0100_02.jpg n000033/0122_01.jpg n000033/0164_02.jpg +n000033/0166_01.jpg +n000033/0250_02.jpg +n000033/0327_01.jpg +n000033/0337_01.jpg n000034/0327_01.jpg +n000035/0072_02.jpg +n000035/0099_01.jpg +n000035/0132_03.jpg +n000035/0134_01.jpg +n000035/0150_01.jpg +n000035/0158_01.jpg +n000035/0159_02.jpg +n000035/0167_01.jpg n000035/0170_01.jpg +n000035/0200_01.jpg +n000002/0013_01.jpg +n000002/0018_01.jpg +n000002/0023_01.jpg +n000002/0027_01.jpg +n000002/0031_06.jpg +n000002/0031_08.jpg +n000002/0042_01.jpg +n000002/0058_01.jpg +n000002/0068_01.jpg +n000002/0075_01.jpg +n000002/0078_01.jpg 
+n000002/0094_01.jpg +n000002/0095_01.jpg +n000002/0110_03.jpg +n000002/0125_01.jpg +n000002/0141_01.jpg +n000002/0142_01.jpg +n000002/0152_02.jpg +n000002/0170_01.jpg +n000002/0171_01.jpg +n000002/0179_01.jpg +n000002/0180_01.jpg +n000002/0184_01.jpg +n000002/0193_01.jpg +n000002/0197_01.jpg +n000002/0199_01.jpg +n000002/0201_01.jpg +n000002/0209_01.jpg +n000002/0210_01.jpg +n000002/0216_01.jpg +n000002/0217_01.jpg +n000002/0218_01.jpg +n000002/0227_02.jpg +n000002/0231_01.jpg +n000002/0233_01.jpg +n000002/0237_01.jpg +n000002/0240_01.jpg +n000002/0239_01.jpg +n000002/0245_01.jpg +n000002/0249_01.jpg +n000002/0257_01.jpg +n000002/0259_01.jpg +n000002/0261_01.jpg +n000002/0262_01.jpg +n000002/0265_01.jpg +n000002/0268_01.jpg +n000002/0270_01.jpg +n000002/0275_01.jpg +n000002/0277_01.jpg +n000002/0279_01.jpg +n000002/0284_02.jpg +n000002/0298_01.jpg +n000002/0304_01.jpg +n000002/0305_01.jpg +n000002/0311_01.jpg +n000002/0312_01.jpg +n000002/0316_01.jpg +n000002/0317_01.jpg +n000002/0321_01.jpg +n000002/0323_01.jpg +n000003/0006_01.jpg +n000003/0010_01.jpg +n000003/0011_02.jpg +n000003/0013_01.jpg +n000003/0013_01.jpg +n000003/0021_01.jpg +n000003/0026_01.jpg +n000003/0027_02.jpg +n000003/0036_01.jpg +n000003/0038_01.jpg +n000003/0044_02.jpg +n000003/0054_01.jpg +n000003/0055_01.jpg +n000003/0064_02.jpg +n000003/0073_01.jpg +n000003/0074_02.jpg +n000003/0083_01.jpg +n000003/0085_01.jpg +n000003/0086_01.jpg +n000003/0099_03.jpg +n000003/0097_01.jpg +n000003/0100_03.jpg +n000003/0101_01.jpg +n000003/0102_01.jpg +n000003/0103_01.jpg +n000003/0104_01.jpg +n000003/0108_01.jpg +n000003/0115_01.jpg +n000003/0116_02.jpg +n000003/0118_01.jpg +n000003/0120_01.jpg +n000003/0122_06.jpg +n000003/0124_03.jpg +n000003/0125_01.jpg +n000003/0129_01.jpg +n000003/0130_01.jpg +n000003/0131_01.jpg +n000003/0133_01.jpg +n000003/0136_01.jpg +n000003/0137_01.jpg +n000003/0143_01.jpg +n000003/0144_01.jpg +n000003/0149_01.jpg +n000003/0155_01.jpg +n000003/0157_02.jpg +n000003/0161_01.jpg +n000003/0162_01.jpg +n000003/0163_02.jpg +n000003/0164_02.jpg +n000003/0165_01.jpg +n000003/0167_02.jpg +n000003/0168_02.jpg +n000003/0170_01.jpg +n000003/0172_02.jpg +n000003/0173_01.jpg +n000003/0177_02.jpg +n000003/0181_01.jpg +n000003/0183_02.jpg +n000003/0200_02.jpg +n000003/0201_01.jpg +n000003/0202_01.jpg +n000003/0206_01.jpg +n000003/0207_02.jpg +n000003/0222_01.jpg +n000003/0226_01.jpg +n000003/0240_01.jpg +n000003/0241_01.jpg +n000003/0244_02.jpg +n000003/0245_01.jpg +n000003/0246_01.jpg +n000003/0249_01.jpg +n000003/0253_03.jpg +n000004/0018_01.jpg +n000004/0040_01.jpg +n000004/0041_01.jpg +n000004/0057_03.jpg +n000004/0060_01.jpg +n000004/0073_01.jpg +n000004/0090_01.jpg +n000004/0097_01.jpg +n000004/0124_01.jpg +n000004/0131_01.jpg +n000004/0165_01.jpg +n000004/0171_03.jpg +n000004/0175_01.jpg +n000004/0178_01.jpg +n000004/0182_02.jpg +n000004/0184_02.jpg +n000004/0224_02.jpg +n000004/0225_01.jpg +n000004/0228_01.jpg +n000004/0235_01.jpg +n000004/0241_01.jpg +n000004/0243_01.jpg +n000004/0248_01.jpg +n000004/0252_01.jpg +n000004/0251_01.jpg +n000004/0253_02.jpg +n000004/0255_02.jpg +n000004/0260_04.jpg +n000004/0268_02.jpg +n000004/0272_01.jpg +n000004/0274_01.jpg +n000004/0276_01.jpg +n000004/0277_01.jpg +n000004/0279_01.jpg +n000004/0290_02.jpg +n000004/0296_02.jpg +n000004/0315_01.jpg +n000004/0324_02.jpg +n000004/0328_01.jpg +n000004/0334_01.jpg +n000004/0340_01.jpg +n000004/0343_01.jpg +n000004/0350_01.jpg +n000004/0354_01.jpg +n000004/0391_01.jpg +n000004/0393_01.jpg +n000004/0396_01.jpg +n000004/0402_01.jpg 
+n000004/0420_01.jpg +n000005/0025_01.jpg +n000005/0045_01.jpg +n000005/0052_01.jpg +n000005/0063_01.jpg +n000005/0078_01.jpg +n000005/0080_01.jpg +n000005/0087_01.jpg +n000005/0101_01.jpg +n000005/0102_01.jpg +n000005/0104_01.jpg +n000005/0105_01.jpg +n000005/0106_01.jpg +n000005/0108_01.jpg +n000005/0117_01.jpg +n000005/0124_01.jpg +n000005/0130_01.jpg +n000005/0136_01.jpg +n000005/0142_01.jpg +n000005/0143_01.jpg +n000005/0146_01.jpg +n000005/0148_01.jpg +n000005/0150_02.jpg +n000005/0156_01.jpg +n000005/0160_02.jpg +n000005/0163_02.jpg +n000005/0164_01.jpg +n000005/0165_01.jpg +n000005/0167_02.jpg +n000005/0174_01.jpg +n000005/0175_01.jpg +n000005/0180_01.jpg +n000005/0181_02.jpg +n000005/0182_01.jpg +n000005/0185_01.jpg +n000005/0190_01.jpg +n000005/0192_01.jpg +n000005/0194_03.jpg +n000005/0195_01.jpg +n000005/0197_02.jpg +n000005/0203_01.jpg +n000005/0205_01.jpg +n000005/0210_02.jpg +n000005/0213_01.jpg +n000005/0219_01.jpg +n000005/0220_01.jpg +n000005/0221_01.jpg +n000005/0222_02.jpg +n000005/0226_01.jpg +n000005/0229_01.jpg +n000005/0233_01.jpg +n000005/0241_01.jpg +n000005/0284_02.jpg +n000005/0306_01.jpg +n000005/0350_01.jpg +n000005/0406_01.jpg +n000005/0413_01.jpg +n000005/0424_01.jpg +n000005/0430_02.jpg +n000005/0431_01.jpg +n000006/0001_01.jpg +n000006/0004_04.jpg +n000006/0051_01.jpg +n000006/0101_01.jpg +n000006/0146_01.jpg +n000006/0156_01.jpg +n000006/0165_02.jpg +n000006/0174_01.jpg +n000006/0183_01.jpg +n000006/0185_02.jpg +n000006/0187_03.jpg +n000006/0187_04.jpg +n000006/0189_01.jpg +n000006/0198_01.jpg +n000006/0206_01.jpg +n000006/0225_01.jpg +n000006/0231_01.jpg +n000006/0235_01.jpg +n000006/0242_01.jpg +n000006/0248_01.jpg +n000006/0249_01.jpg +n000006/0252_01.jpg +n000006/0257_01.jpg +n000006/0258_03.jpg +n000006/0262_01.jpg +n000006/0264_01.jpg +n000006/0268_01.jpg +n000006/0275_04.jpg +n000006/0279_01.jpg +n000006/0283_01.jpg +n000006/0284_01.jpg +n000006/0314_01.jpg +n000006/0316_01.jpg +n000006/0319_01.jpg +n000006/0323_01.jpg +n000006/0324_01.jpg +n000006/0325_01.jpg +n000006/0326_01.jpg +n000006/0328_02.jpg +n000006/0329_01.jpg +n000006/0332_01.jpg +n000006/0333_01.jpg +n000006/0334_01.jpg +n000006/0335_01.jpg +n000006/0336_01.jpg +n000006/0337_01.jpg +n000006/0338_01.jpg +n000006/0341_01.jpg +n000006/0347_01.jpg +n000006/0349_02.jpg +n000006/0284_01.jpg +n000006/0283_01.jpg +n000006/0314_01.jpg +n000006/0315_01.jpg +n000006/0316_01.jpg +n000006/0319_01.jpg +n000006/0324_01.jpg +n000006/0333_01.jpg +n000006/0332_01.jpg +n000006/0334_01.jpg +n000006/0335_01.jpg +n000006/0336_01.jpg +n000006/0340_01.jpg +n000006/0341_01.jpg +n000006/0343_01.jpg +n000006/0349_02.jpg +n000006/0350_01.jpg +n000006/0352_01.jpg +n000006/0353_01.jpg +n000006/0354_01.jpg +n000006/0356_01.jpg +n000006/0358_01.jpg +n000006/0359_03.jpg +n000006/0360_01.jpg +n000006/0361_01.jpg +n000006/0362_01.jpg +n000006/0363_01.jpg +n000006/0367_01.jpg +n000006/0368_02.jpg +n000006/0369_01.jpg +n000006/0372_01.jpg +n000006/0374_01.jpg +n000006/0377_01.jpg +n000006/0380_01.jpg +n000006/0384_01.jpg +n000006/0388_01.jpg +n000006/0389_01.jpg +n000006/0396_01.jpg +n000006/0397_01.jpg +n000006/0399_04.jpg +n000006/0400_04.jpg +n000006/0404_01.jpg +n000006/0406_01.jpg +n000006/0411_05.jpg +n000006/0413_01.jpg +n000006/0418_01.jpg +n000006/0419_01.jpg +n000006/0420_01.jpg +n000006/0426_01.jpg +n000006/0432_01.jpg +n000006/0433_03.jpg +n000006/0457_03.jpg +n000006/0467_01.jpg +n000006/0475_01.jpg +n000006/0480_02.jpg +n000006/0521_01.jpg +n000006/0523_02.jpg +n000006/0524_01.jpg +n000006/0526_01.jpg 
+n000006/0528_01.jpg +n000006/0532_01.jpg +n000006/0533_01.jpg +n000006/0536_01.jpg +n000006/0538_01.jpg +n000006/0542_01.jpg +n000006/0543_01.jpg +n000006/0544_02.jpg +n000006/0545_02.jpg +n000006/0548_03.jpg +n000006/0549_02.jpg +n000006/0552_01.jpg +n000006/0554_01.jpg +n000007/0002_01.jpg +n000007/0006_02.jpg +n000007/0007_01.jpg +n000007/0011_01.jpg +n000007/0012_01.jpg +n000007/0017_01.jpg +n000007/0018_01.jpg +n000007/0022_02.jpg +n000007/0023_01.jpg +n000007/0028_01.jpg +n000007/0033_02.jpg +n000007/0039_01.jpg +n000007/0040_02.jpg +n000007/0044_01.jpg +n000007/0052_01.jpg +n000007/0053_01.jpg +n000007/0054_02.jpg +n000007/0055_01.jpg +n000007/0057_01.jpg +n000007/0058_02.jpg +n000007/0060_01.jpg +n000007/0070_01.jpg +n000007/0071_01.jpg +n000007/0081_01.jpg +n000007/0085_03.jpg +n000007/0096_01.jpg +n000007/0099_01.jpg +n000007/0113_02.jpg +n000007/0118_04.jpg +n000007/0121_01.jpg +n000007/0122_01.jpg +n000007/0123_01.jpg +n000007/0124_01.jpg +n000007/0138_03.jpg +n000007/0141_03.jpg +n000007/0142_02.jpg +n000007/0145_04.jpg +n000007/0146_02.jpg +n000007/0151_03.jpg +n000007/0152_01.jpg +n000007/0153_01.jpg +n000007/0160_03.jpg +n000007/0160_04.jpg +n000007/0160_05.jpg +n000007/0160_05.jpg +n000007/0165_02.jpg +n000007/0166_01.jpg +n000007/0168_01.jpg +n000007/0169_01.jpg +n000007/0170_01.jpg +n000007/0171_02.jpg +n000007/0171_04.jpg +n000007/0172_01.jpg +n000007/0175_01.jpg +n000007/0176_01.jpg +n000007/0177_04.jpg +n000007/0185_01.jpg +n000007/0187_01.jpg +n000007/0188_01.jpg +n000007/0189_01.jpg +n000007/0195_01.jpg +n000007/0196_01.jpg +n000007/0197_02.jpg +n000007/0198_03.jpg +n000007/0200_02.jpg +n000007/0201_01.jpg +n000007/0205_01.jpg +n000007/0208_01.jpg +n000007/0209_01.jpg +n000007/0215_02.jpg +n000007/0218_01.jpg +n000007/0221_02.jpg +n000007/0227_03.jpg +n000007/0233_02.jpg +n000007/0239_01.jpg +n000007/0241_01.jpg +n000007/0246_01.jpg +n000007/0246_02.jpg +n000007/0247_01.jpg +n000007/0271_02.jpg +n000007/0280_01.jpg +n000007/0283_02.jpg +n000007/0311_02.jpg +n000007/0327_05.jpg +n000007/0379_02.jpg +n000007/0381_01.jpg +n000007/0391_01.jpg +n000007/0411_02.jpg +n000007/0419_01.jpg +n000007/0428_01.jpg +n000007/0430_01.jpg +n000008/0003_01.jpg +n000008/0005_02.jpg +n000008/0020_01.jpg +n000008/0079_01.jpg +n000008/0080_01.jpg +n000008/0091_01.jpg +n000008/0094_01.jpg +n000008/0095_01.jpg +n000008/0096_01.jpg +n000008/0098_01.jpg +n000008/0101_01.jpg +n000008/0102_01.jpg +n000008/0111_01.jpg +n000008/0112_01.jpg +n000008/0118_01.jpg +n000008/0121_01.jpg +n000008/0124_01.jpg +n000008/0127_01.jpg +n000008/0143_01.jpg +n000008/0153_01.jpg +n000008/0166_02.jpg +n000008/0174_01.jpg +n000008/0177_01.jpg +n000008/0193_01.jpg +n000008/0195_01.jpg +n000008/0196_01.jpg +n000008/0197_01.jpg +n000008/0199_01.jpg +n000008/0201_01.jpg +n000008/0204_01.jpg +n000008/0205_01.jpg +n000008/0207_01.jpg +n000008/0208_01.jpg +n000008/0212_02.jpg +n000008/0213_01.jpg +n000008/0218_02.jpg +n000008/0227_01.jpg +n000008/0239_01.jpg +n000008/0250_02.jpg +n000008/0251_01.jpg +n000008/0259_01.jpg +n000008/0276_01.jpg +n000008/0277_01.jpg +n000008/0278_01.jpg +n000008/0285_01.jpg +n000008/0288_03.jpg +n000008/0302_01.jpg +n000008/0303_01.jpg +n000008/0308_01.jpg +n000008/0309_01.jpg +n000008/0327_01.jpg +n000008/0347_01.jpg +n000010/0063_01.jpg +n000010/0079_10.jpg +n000010/0080_01.jpg +n000010/0085_01.jpg +n000010/0102_01.jpg +n000010/0105_05.jpg +n000010/0127_01.jpg +n000010/0128_01.jpg +n000010/0130_02.jpg +n000010/0130_04.jpg +n000010/0130_06.jpg +n000010/0131_02.jpg +n000010/0137_01.jpg 
+n000010/0138_01.jpg +n000010/0138_02.jpg +n000010/0147_01.jpg +n000010/0154_01.jpg +n000010/0157_02.jpg +n000010/0158_01.jpg +n000010/0166_01.jpg +n000010/0167_01.jpg +n000010/0214_01.jpg +n000010/0280_01.jpg +n000011/0021_01.jpg +n000011/0099_02.jpg +n000011/0114_02.jpg +n000011/0128_02.jpg +n000011/0166_01.jpg +n000011/0186_06.jpg +n000011/0210_01.jpg +n000011/0215_01.jpg +n000011/0219_01.jpg +n000011/0221_01.jpg +n000011/0224_02.jpg +n000011/0227_02.jpg +n000011/0228_01.jpg +n000011/0234_01.jpg +n000011/0238_01.jpg +n000011/0247_01.jpg +n000011/0248_01.jpg +n000011/0249_01.jpg +n000011/0267_01.jpg +n000011/0270_01.jpg +n000011/0271_02.jpg +n000011/0273_01.jpg +n000011/0279_02.jpg +n000011/0280_04.jpg +n000011/0281_01.jpg +n000011/0285_06.jpg +n000011/0293_01.jpg +n000011/0296_01.jpg +n000011/0299_01.jpg +n000011/0306_01.jpg +n000011/0306_07.jpg +n000011/0312_02.jpg +n000011/0313_01.jpg +n000011/0316_01.jpg +n000011/0317_02.jpg +n000011/0318_01.jpg +n000011/0324_01.jpg +n000011/0329_02.jpg +n000011/0334_02.jpg +n000011/0382_01.jpg +n000011/0385_01.jpg +n000011/0387_01.jpg +n000011/0397_01.jpg +n000011/0407_06.jpg +n000011/0408_01.jpg +n000011/0417_01.jpg +n000011/0424_01.jpg +n000011/0426_03.jpg +n000012/0012_02.jpg +n000012/0029_03.jpg +n000012/0032_01.jpg +n000012/0046_01.jpg +n000012/0056_02.jpg +n000012/0067_01.jpg +n000012/0068_01.jpg +n000012/0069_01.jpg +n000012/0076_01.jpg +n000012/0078_01.jpg +n000012/0100_02.jpg +n000012/0101_01.jpg +n000012/0102_01.jpg +n000012/0103_01.jpg +n000012/0109_01.jpg +n000012/0109_02.jpg +n000012/0109_03.jpg +n000012/0110_01.jpg +n000012/0112_01.jpg +n000012/0114_01.jpg +n000012/0116_01.jpg +n000012/0122_01.jpg +n000012/0141_01.jpg +n000012/0179_01.jpg +n000012/0181_01.jpg +n000012/0194_01.jpg +n000012/0208_01.jpg +n000012/0210_01.jpg +n000012/0210_02.jpg +n000012/0211_01.jpg +n000012/0243_01.jpg +n000012/0253_01.jpg +n000012/0254_01.jpg +n000012/0257_01.jpg +n000012/0263_03.jpg +n000012/0266_01.jpg +n000012/0273_02.jpg +n000012/0277_02.jpg +n000012/0279_01.jpg +n000012/0285_02.jpg +n000012/0288_01.jpg +n000012/0288_02.jpg +n000012/0289_01.jpg +n000012/0291_03.jpg +n000012/0299_01.jpg +n000012/0301_02.jpg +n000012/0304_01.jpg +n000012/0306_02.jpg +n000012/0309_03.jpg +n000012/0309_01.jpg +n000012/0315_02.jpg +n000012/0320_01.jpg +n000012/0320_02.jpg +n000012/0335_02.jpg +n000012/0340_01.jpg +n000012/0350_01.jpg +n000012/0350_02.jpg +n000012/0358_01.jpg +n000012/0360_01.jpg +n000012/0375_01.jpg +n000012/0406_01.jpg +n000012/0406_02.jpg +n000012/0410_01.jpg +n000012/0412_01.jpg +n000012/0414_02.jpg +n000012/0422_01.jpg +n000012/0426_01.jpg +n000012/0426_02.jpg +n000012/0430_01.jpg +n000013/0013_01.jpg +n000013/0014_01.jpg +n000013/0023_01.jpg +n000013/0029_04.jpg +n000013/0030_01.jpg +n000013/0041_01.jpg +n000013/0048_01.jpg +n000013/0057_01.jpg +n000013/0105_01.jpg +n000013/0106_01.jpg +n000013/0112_01.jpg +n000013/0117_01.jpg +n000013/0118_01.jpg +n000013/0123_03.jpg +n000013/0124_01.jpg +n000013/0127_01.jpg +n000013/0131_02.jpg +n000013/0131_03.jpg +n000013/0134_01.jpg +n000013/0141_01.jpg +n000013/0149_01.jpg +n000013/0157_01.jpg +n000013/0160_01.jpg +n000013/0163_01.jpg +n000013/0164_01.jpg +n000013/0165_01.jpg +n000013/0166_01.jpg +n000013/0168_01.jpg +n000013/0175_01.jpg +n000013/0176_02.jpg +n000013/0177_01.jpg +n000013/0181_02.jpg +n000013/0182_01.jpg +n000013/0186_01.jpg +n000013/0192_01.jpg +n000013/0193_02.jpg +n000013/0193_04.jpg +n000013/0196_01.jpg +n000013/0198_01.jpg +n000013/0201_01.jpg +n000013/0203_01.jpg +n000013/0204_01.jpg 
+n000013/0205_01.jpg +n000013/0209_01.jpg +n000013/0210_01.jpg +n000013/0211_01.jpg +n000013/0212_01.jpg +n000013/0213_01.jpg +n000013/0215_01.jpg +n000013/0220_01.jpg +n000013/0227_01.jpg +n000013/0230_01.jpg +n000013/0233_01.jpg +n000013/0237_01.jpg +n000013/0236_01.jpg +n000013/0238_01.jpg +n000013/0242_01.jpg +n000013/0245_01.jpg +n000013/0246_03.jpg +n000013/0247_01.jpg +n000013/0248_01.jpg +n000013/0249_02.jpg +n000013/0252_01.jpg +n000013/0253_01.jpg +n000013/0254_01.jpg +n000013/0258_01.jpg +n000013/0259_01.jpg +n000013/0261_01.jpg +n000013/0266_01.jpg +n000013/0268_02.jpg +n000013/0273_01.jpg +n000013/0274_02.jpg +n000013/0283_01.jpg +n000013/0293_01.jpg +n000013/0305_02.jpg +n000013/0316_01.jpg +n000013/0320_01.jpg +n000013/0323_01.jpg +n000013/0330_01.jpg +n000013/0331_01.jpg +n000013/0332_01.jpg +n000013/0340_01.jpg +n000014/0049_01.jpg +n000014/0067_06.jpg +n000014/0130_08.jpg +n000014/0130_09.jpg +n000014/0130_10.jpg +n000014/0130_12.jpg +n000014/0130_13.jpg +n000014/0130_14.jpg +n000014/0130_15.jpg +n000014/0130_19.jpg +n000014/0130_20.jpg +n000014/0130_21.jpg +n000014/0130_22.jpg +n000014/0130_25.jpg +n000014/0130_28.jpg +n000014/0130_30.jpg +n000014/0130_31.jpg +n000014/0130_32.jpg +n000014/0130_33.jpg +n000014/0130_34.jpg +n000014/0130_35.jpg +n000014/0132_01.jpg +n000014/0134_01.jpg +n000014/0158_01.jpg +n000014/0177_01.jpg +n000014/0200_01.jpg +n000014/0201_01.jpg +n000014/0203_01.jpg +n000014/0206_01.jpg +n000014/0208_01.jpg +n000014/0209_01.jpg +n000014/0213_01.jpg +n000014/0214_01.jpg +n000014/0215_01.jpg +n000014/0216_01.jpg +n000014/0217_01.jpg +n000014/0222_01.jpg +n000014/0232_02.jpg +n000014/0233_01.jpg +n000014/0244_02.jpg +n000014/0255_01.jpg +n000014/0289_02.jpg +n000014/0283_01.jpg +n000015/0020_01.jpg +n000015/0021_01.jpg +n000015/0023_01.jpg +n000015/0031_01.jpg +n000015/0034_02.jpg +n000015/0040_01.jpg +n000015/0050_01.jpg +n000015/0050_02.jpg +n000015/0052_02.jpg +n000015/0055_01.jpg +n000015/0056_01.jpg +n000015/0066_01.jpg +n000015/0067_01.jpg +n000015/0068_01.jpg +n000015/0075_01.jpg +n000015/0076_01.jpg +n000015/0078_02.jpg +n000015/0081_02.jpg +n000015/0087_03.jpg +n000015/0088_01.jpg +n000015/0096_01.jpg +n000015/0100_01.jpg +n000015/0101_04.jpg +n000015/0102_01.jpg +n000015/0103_04.jpg +n000015/0104_03.jpg +n000015/0110_01.jpg +n000015/0111_01.jpg +n000015/0112_01.jpg +n000015/0113_03.jpg +n000015/0115_01.jpg +n000015/0116_01.jpg +n000015/0117_01.jpg +n000015/0118_01.jpg +n000015/0119_03.jpg +n000015/0122_01.jpg +n000015/0123_01.jpg +n000015/0126_01.jpg +n000015/0130_01.jpg +n000015/0131_01.jpg +n000015/0134_01.jpg +n000015/0138_02.jpg +n000015/0139_02.jpg +n000015/0140_01.jpg +n000015/0142_01.jpg +n000015/0147_01.jpg +n000015/0151_01.jpg +n000015/0153_02.jpg +n000015/0155_03.jpg +n000015/0161_01.jpg +n000015/0163_03.jpg +n000015/0167_04.jpg +n000015/0169_01.jpg +n000015/0173_01.jpg +n000015/0174_05.jpg +n000015/0175_03.jpg +n000015/0181_03.jpg +n000015/0185_02.jpg +n000015/0186_01.jpg +n000015/0190_02.jpg +n000015/0192_02.jpg +n000015/0194_01.jpg +n000015/0201_01.jpg +n000015/0201_03.jpg +n000015/0206_01.jpg +n000015/0288_03.jpg +n000015/0314_01.jpg +n000015/0344_06.jpg +n000015/0356_01.jpg +n000015/0372_01.jpg +n000015/0393_04.jpg +n000015/0391_01.jpg +n000015/0395_01.jpg +n000015/0415_01.jpg +n000015/0434_02.jpg +n000015/0438_01.jpg +n000015/0438_02.jpg +n000017/0036_01.jpg +n000017/0047_01.jpg +n000017/0236_01.jpg +n000017/0237_01.jpg +n000017/0262_01.jpg +n000017/0269_01.jpg +n000018/0108_01.jpg +n000018/0173_01.jpg +n000018/0206_02.jpg 
+n000018/0304_01.jpg +n000019/0085_01.jpg +n000019/0089_01.jpg +n000019/0106_03.jpg +n000019/0170_01.jpg +n000019/0234_02.jpg +n000019/0249_01.jpg +n000019/0273_01.jpg +n000019/0275_01.jpg +n000019/0276_01.jpg +n000019/0306_01.jpg +n000019/0309_01.jpg +n000019/0313_01.jpg +n000019/0328_01.jpg +n000019/0331_01.jpg +n000019/0333_01.jpg +n000019/0334_01.jpg +n000019/0337_01.jpg +n000019/0347_01.jpg +n000019/0350_02.jpg +n000020/0243_01.jpg +n000020/0290_01.jpg +n000020/0334_01.jpg +n000020/0400_01.jpg +n000020/0384_01.jpg +n000020/0409_01.jpg +n000020/0418_01.jpg +n000021/0046_01.jpg +n000021/0052_01.jpg +n000021/0087_01.jpg +n000021/0117_01.jpg +n000021/0143_01.jpg +n000021/0184_01.jpg +n000022/0347_01.jpg +n000022/0415_01.jpg +n000023/0008_01.jpg +n000023/0012_01.jpg +n000023/0156_01.jpg +n000023/0162_01.jpg +n000023/0198_01.jpg +n000023/0207_03.jpg +n000023/0256_01.jpg +n000023/0257_01.jpg +n000023/0269_02.jpg +n000023/0285_01.jpg +n000023/0280_01.jpg +n000023/0294_01.jpg +n000023/0319_01.jpg +n000023/0343_01.jpg +n000023/0352_01.jpg +n000023/0359_01.jpg +n000023/0366_01.jpg +n000023/0389_01.jpg +n000024/0046_02.jpg +n000024/0056_01.jpg +n000024/0188_01.jpg +n000024/0258_01.jpg +n000024/0311_01.jpg +n000024/0325_01.jpg +n000024/0327_01.jpg +n000025/0245_01.jpg +n000026/0038_01.jpg +n000026/0060_01.jpg +n000026/0075_01.jpg +n000026/0078_01.jpg +n000026/0082_02.jpg +n000026/0103_01.jpg +n000026/0104_01.jpg +n000026/0125_01.jpg +n000026/0137_01.jpg +n000026/0196_01.jpg +n000026/0280_01.jpg +n000027/0023_02.jpg +n000027/0023_05.jpg +n000027/0097_01.jpg +n000027/0099_01.jpg +n000027/0108_02.jpg +n000027/0115_01.jpg +n000027/0157_02.jpg +n000027/0171_01.jpg +n000027/0182_02.jpg +n000027/0211_02.jpg +n000027/0255_01.jpg +n000027/0256_03.jpg +n000027/0257_01.jpg +n000027/0274_04.jpg +n000027/0318_04.jpg +n000027/0401_01.jpg +n000027/0402_01.jpg +n000027/0438_01.jpg +n000027/0438_02.jpg +n000027/0440_01.jpg +n000027/0442_01.jpg +n000027/0443_01.jpg +n000027/0446_01.jpg +n000027/0456_01.jpg +n000027/0458_01.jpg +n000027/0469_02.jpg +n000027/0493_01.jpg +n000028/0040_04.jpg +n000028/0044_01.jpg +n000028/0056_01.jpg +n000028/0080_01.jpg +n000028/0083_01.jpg +n000028/0088_01.jpg +n000028/0113_01.jpg +n000028/0120_01.jpg +n000028/0138_01.jpg +n000028/0140_02.jpg +n000028/0141_02.jpg +n000028/0144_02.jpg +n000028/0147_01.jpg +n000028/0149_01.jpg +n000028/0156_01.jpg +n000028/0155_01.jpg +n000028/0161_01.jpg +n000028/0175_02.jpg +n000028/0179_01.jpg +n000028/0180_01.jpg +n000028/0205_01.jpg +n000028/0208_02.jpg +n000028/0249_01.jpg +n000028/0300_01.jpg +n000028/0324_02.jpg +n000028/0343_01.jpg +n000028/0392_01.jpg +n000028/0412_02.jpg +n000030/0155_01.jpg +n000030/0157_01.jpg +n000030/0186_01.jpg +n000030/0193_01.jpg +n000030/0203_01.jpg +n000030/0204_01.jpg +n000030/0214_01.jpg +n000030/0218_02.jpg +n000030/0220_01.jpg +n000030/0244_01.jpg +n000031/0080_02.jpg +n000031/0092_01.jpg +n000031/0174_01.jpg +n000031/0180_01.jpg +n000031/0196_01.jpg +n000031/0248_01.jpg +n000031/0319_01.jpg +n000031/0320_03.jpg +n000032/0100_01.jpg +n000032/0100_02.jpg +n000032/0209_01.jpg +n000032/0233_01.jpg +n000032/0236_01.jpg +n000032/0237_01.jpg +n000032/0238_01.jpg +n000032/0309_01.jpg +n000032/0374_01.jpg +n000032/0393_01.jpg +n000032/0401_01.jpg +n000032/0409_01.jpg +n000032/0410_01.jpg +n000032/0420_01.jpg +n000032/0422_01.jpg +n000032/0459_01.jpg +n000032/0465_02.jpg +n000032/0531_01.jpg +n000032/0540_01.jpg +n000032/0556_01.jpg +n000032/0566_01.jpg +n000032/0578_01.jpg +n000032/0580_01.jpg +n000032/0582_01.jpg 
+n000032/0591_01.jpg +n000032/0605_01.jpg +n000033/0034_02.jpg +n000033/0095_01.jpg +n000033/0100_01.jpg +n000033/0100_02.jpg +n000033/0107_01.jpg +n000033/0170_01.jpg +n000033/0171_01.jpg +n000033/0179_01.jpg +n000033/0207_02.jpg +n000033/0224_01.jpg +n000033/0228_01.jpg +n000033/0231_01.jpg +n000033/0232_01.jpg +n000033/0233_02.jpg +n000033/0234_01.jpg +n000033/0235_01.jpg +n000033/0247_01.jpg +n000033/0250_02.jpg +n000033/0327_01.jpg +n000033/0337_01.jpg +n000033/0344_01.jpg +n000033/0435_01.jpg +n000034/0171_01.jpg +n000035/0069_01.jpg +n000035/0072_01.jpg +n000035/0072_02.jpg +n000035/0072_04.jpg +n000035/0098_03.jpg +n000035/0099_01.jpg +n000035/0132_02.jpg +n000035/0132_03.jpg +n000035/0134_01.jpg +n000035/0150_01.jpg +n000035/0159_02.jpg +n000035/0158_01.jpg +n000035/0161_01.jpg +n000035/0167_01.jpg +n000035/0171_01.jpg +n000035/0180_01.jpg +n000035/0200_01.jpg +n000036/0003_01.jpg +n000036/0066_02.jpg +n000036/0069_01.jpg +n000036/0083_01.jpg +n000036/0117_01.jpg +n000036/0178_02.jpg +n000036/0278_01.jpg +n000036/0279_01.jpg +n000036/0280_04.jpg +n000036/0302_02.jpg +n000036/0303_03.jpg +n000036/0304_02.jpg +n000036/0335_02.jpg +n000036/0558_01.jpg +n000036/0603_02.jpg +n000037/0007_02.jpg +n000037/0016_02.jpg +n000037/0146_03.jpg +n000037/0184_01.jpg +n000037/0166_01.jpg +n000038/0016_02.jpg +n000038/0068_01.jpg +n000038/0110_01.jpg +n000038/0114_01.jpg +n000038/0118_01.jpg +n000038/0155_01.jpg +n000038/0167_02.jpg +n000038/0169_01.jpg +n000038/0171_01.jpg +n000038/0172_01.jpg +n000038/0176_01.jpg +n000038/0178_01.jpg +n000038/0210_01.jpg +n000038/0212_01.jpg +n000038/0227_01.jpg +n000038/0236_01.jpg +n000038/0237_01.jpg +n000038/0241_01.jpg +n000038/0249_01.jpg +n000038/0260_01.jpg +n000038/0265_01.jpg +n000038/0275_01.jpg +n000038/0283_01.jpg +n000038/0286_01.jpg +n000038/0290_01.jpg +n000038/0308_01.jpg +n000038/0309_01.jpg +n000038/0336_02.jpg +n000038/0343_01.jpg +n000038/0355_01.jpg +n000038/0357_01.jpg +n000038/0366_01.jpg +n000038/0429_01.jpg +n000039/0174_01.jpg +n000039/0195_03.jpg +n000039/0310_01.jpg +n000039/0311_01.jpg +n000039/0313_01.jpg +n000039/0358_02.jpg +n000041/0089_04.jpg +n000041/0119_01.jpg +n000043/0159_02.jpg +n000043/0169_01.jpg +n000043/0369_01.jpg +n000043/0389_01.jpg +n000043/0391_01.jpg +n000043/0436_01.jpg +n000043/0457_01.jpg +n000043/0458_01.jpg +n000044/0007_01.jpg +n000044/0009_02.jpg +n000044/0078_01.jpg +n000044/0099_01.jpg +n000044/0117_01.jpg +n000044/0258_01.jpg +n000044/0275_01.jpg +n000044/0325_01.jpg +n000044/0350_01.jpg +n000044/0353_02.jpg +n000044/0364_01.jpg +n000044/0374_01.jpg +n000044/0379_01.jpg +n000045/0048_01.jpg +n000045/0048_02.jpg +n000045/0054_03.jpg +n000045/0120_03.jpg +n000045/0120_03.jpg +n000045/0120_02.jpg +n000045/0128_01.jpg +n000045/0128_02.jpg +n000045/0150_02.jpg +n000045/0156_01.jpg +n000045/0170_02.jpg +n000045/0226_02.jpg +n000045/0230_01.jpg +n000045/0254_03.jpg +n000045/0256_01.jpg +n000045/0269_01.jpg +n000045/0270_01.jpg +n000046/0145_01.jpg +n000047/0102_01.jpg +n000047/0191_03.jpg +n000047/0232_02.jpg +n000047/0292_01.jpg +n000047/0324_01.jpg +n000047/0464_01.jpg +n000047/0492_02.jpg +n000047/0484_03.jpg +n000047/0496_02.jpg +n000048/0050_01.jpg +n000048/0199_01.jpg +n000048/0197_01.jpg +n000048/0232_01.jpg +n000049/0046_01.jpg +n000049/0085_01.jpg +n000049/0136_01.jpg +n000049/0155_01.jpg +n000049/0277_01.jpg +n000049/0339_01.jpg +n000049/0342_01.jpg +n000049/0345_01.jpg +n000049/0372_01.jpg +n000049/0397_02.jpg +n000049/0417_01.jpg +n000049/0418_01.jpg +n000049/0472_02.jpg +n000049/0469_01.jpg 
+n000049/0474_01.jpg +n000050/0098_01.jpg +n000050/0115_01.jpg +n000050/0130_01.jpg +n000050/0158_02.jpg +n000050/0189_01.jpg +n000050/0228_01.jpg +n000050/0321_01.jpg +n000050/0321_02.jpg +n000050/0323_01.jpg +n000050/0332_02.jpg +n000050/0368_01.jpg +n000050/0369_01.jpg +n000050/0444_01.jpg +n000051/0243_02.jpg +n000051/0249_01.jpg +n000051/0250_01.jpg +n000051/0258_01.jpg +n000051/0274_01.jpg +n000051/0342_01.jpg +n000051/0366_01.jpg +n000052/0233_02.jpg +n000052/0290_01.jpg +n000052/0288_02.jpg +n000052/0373_02.jpg +n000052/0387_02.jpg +n000052/0451_01.jpg +n000052/0514_01.jpg +n000053/0136_01.jpg +n000053/0280_01.jpg +n000053/0283_01.jpg +n000053/0287_01.jpg +n000053/0287_02.jpg +n000053/0288_01.jpg +n000053/0291_01.jpg +n000053/0299_02.jpg +n000053/0314_01.jpg +n000053/0329_01.jpg +n000053/0399_01.jpg +n000054/0111_01.jpg +n000054/0258_01.jpg +n000054/0261_01.jpg +n000054/0263_01.jpg +n000054/0273_03.jpg +n000054/0275_01.jpg +n000054/0319_01.jpg +n000054/0322_01.jpg +n000054/0361_01.jpg +n000054/0451_01.jpg +n000054/0453_01.jpg +n000054/0455_01.jpg +n000055/0043_01.jpg +n000055/0167_01.jpg +n000055/0172_01.jpg +n000055/0175_01.jpg +n000055/0181_01.jpg +n000055/0251_01.jpg +n000055/0255_01.jpg +n000056/0158_02.jpg +n000056/0254_01.jpg +n000057/0200_03.jpg +n000057/0293_01.jpg +n000057/0300_01.jpg +n000057/0337_01.jpg +n000057/0341_02.jpg +n000057/0344_06.jpg +n000057/0348_01.jpg +n000057/0353_01.jpg +n000057/0351_01.jpg +n000057/0356_01.jpg +n000057/0357_01.jpg +n000057/0368_01.jpg +n000057/0373_01.jpg +n000058/0266_01.jpg +n000058/0467_03.jpg +n000058/0468_01.jpg +n000059/0005_01.jpg +n000059/0013_01.jpg +n000059/0046_01.jpg +n000059/0124_01.jpg +n000059/0177_01.jpg +n000059/0177_02.jpg +n000059/0182_01.jpg +n000059/0222_01.jpg +n000002/0054_01.jpg +n000002/0055_01.jpg +n000002/0138_01.jpg +n000002/0150_02.jpg +n000002/0208_01.jpg +n000002/0252_01.jpg +n000002/0273_01.jpg +n000002/0276_01.jpg +n000003/0024_01.jpg +n000003/0098_01.jpg +n000003/0219_01.jpg +n000004/0026_01.jpg +n000004/0084_01.jpg +n000004/0103_02.jpg +n000004/0118_01.jpg +n000004/0144_02.jpg +n000004/0155_01.jpg +n000004/0180_01.jpg +n000004/0231_01.jpg +n000004/0237_01.jpg +n000004/0239_01.jpg +n000004/0258_01.jpg +n000005/0138_01.jpg +n000005/0144_01.jpg +n000005/0287_01.jpg +n000006/0007_01.jpg +n000006/0014_01.jpg +n000006/0036_02.jpg +n000006/0091_01.jpg +n000006/0103_01.jpg +n000006/0281_01.jpg +n000006/0300_01.jpg +n000006/0351_01.jpg +n000006/0430_01.jpg +n000006/0519_01.jpg +n000006/0549_02.jpg +n000007/0021_01.jpg +n000007/0042_01.jpg +n000007/0045_01.jpg +n000007/0050_02.jpg +n000007/0080_01.jpg +n000007/0086_01.jpg +n000007/0106_02.jpg +n000007/0115_01.jpg +n000007/0116_03.jpg +n000007/0119_01.jpg +n000007/0137_01.jpg +n000007/0140_02.jpg +n000007/0148_02.jpg +n000007/0174_01.jpg +n000007/0181_01.jpg +n000007/0182_02.jpg +n000007/0213_02.jpg +n000007/0226_02.jpg +n000007/0229_01.jpg +n000007/0432_01.jpg +n000008/0072_01.jpg +n000008/0297_01.jpg +n000010/0068_01.jpg +n000010/0069_01.jpg +n000010/0096_01.jpg +n000010/0150_02.jpg +n000010/0155_02.jpg +n000010/0223_01.jpg +n000011/0112_01.jpg +n000011/0142_02.jpg +n000011/0200_01.jpg +n000011/0217_01.jpg +n000011/0229_02.jpg +n000011/0291_02.jpg +n000012/0173_01.jpg +n000012/0180_01.jpg +n000012/0198_01.jpg +n000012/0282_01.jpg +n000012/0294_01.jpg +n000012/0307_01.jpg +n000012/0338_01.jpg +n000013/0029_06.jpg +n000013/0128_01.jpg +n000013/0132_01.jpg +n000013/0148_01.jpg +n000013/0190_02.jpg +n000013/0225_01.jpg +n000013/0277_01.jpg +n000013/0335_01.jpg 
+n000013/0337_01.jpg +n000013/0341_02.jpg +n000014/0163_01.jpg +n000015/0029_02.jpg +n000015/0059_01.jpg +n000015/0133_01.jpg +n000015/0243_02.jpg +n000015/0392_02.jpg +n000015/0393_01.jpg +n000015/0402_01.jpg +n000016/0189_01.jpg +n000016/0237_01.jpg +n000016/0266_01.jpg +n000016/0385_04.jpg +n000016/0391_01.jpg +n000016/0405_01.jpg +n000016/0477_02.jpg +n000016/0500_01.jpg +n000016/0503_01.jpg +n000016/0503_01.jpg +n000017/0123_02.jpg +n000017/0124_01.jpg +n000017/0163_01.jpg +n000017/0262_01.jpg +n000019/0038_01.jpg +n000019/0055_01.jpg +n000019/0061_01.jpg +n000019/0114_01.jpg +n000019/0130_02.jpg +n000019/0149_02.jpg +n000019/0170_01.jpg +n000019/0182_01.jpg +n000019/0219_01.jpg +n000019/0221_02.jpg +n000019/0234_02.jpg +n000019/0249_01.jpg +n000019/0259_01.jpg +n000019/0273_01.jpg +n000019/0306_01.jpg +n000019/0313_01.jpg +n000019/0333_01.jpg +n000019/0350_02.jpg +n000020/0006_01.jpg +n000020/0071_01.jpg +n000020/0074_02.jpg +n000020/0099_02.jpg +n000020/0379_01.jpg +n000020/0400_01.jpg +n000021/0120_02.jpg +n000021/0221_01.jpg +n000022/0051_01.jpg +n000022/0071_01.jpg +n000022/0146_02.jpg +n000022/0146_02.jpg +n000022/0236_01.jpg +n000023/0008_01.jpg +n000023/0078_01.jpg +n000023/0093_01.jpg +n000023/0133_01.jpg +n000023/0162_01.jpg +n000023/0198_01.jpg +n000023/0207_03.jpg +n000023/0269_02.jpg +n000023/0265_01.jpg +n000023/0280_01.jpg +n000023/0366_01.jpg +n000023/0389_01.jpg +n000024/0062_01.jpg +n000024/0073_01.jpg +n000024/0354_04.jpg +n000024/0409_01.jpg +n000025/0100_02.jpg +n000025/0274_02.jpg +n000026/0038_01.jpg +n000026/0041_01.jpg +n000026/0059_01.jpg +n000026/0062_01.jpg +n000026/0065_01.jpg +n000026/0082_02.jpg +n000026/0103_01.jpg +n000026/0137_01.jpg +n000026/0060_01.jpg +n000026/0179_03.jpg +n000026/0196_01.jpg +n000026/0248_01.jpg +n000026/0255_01.jpg +n000026/0273_01.jpg +n000026/0280_01.jpg +n000027/0023_02.jpg +n000027/0023_05.jpg +n000027/0115_01.jpg +n000027/0157_02.jpg +n000027/0171_01.jpg +n000027/0182_02.jpg +n000027/0211_02.jpg +n000027/0255_01.jpg +n000027/0274_04.jpg +n000027/0318_04.jpg +n000027/0326_01.jpg +n000027/0401_01.jpg +n000027/0402_01.jpg +n000027/0438_01.jpg +n000027/0442_01.jpg +n000027/0493_01.jpg +n000028/0040_04.jpg +n000028/0056_01.jpg +n000028/0134_01.jpg +n000028/0136_03.jpg +n000028/0138_01.jpg +n000028/0144_02.jpg +n000028/0156_01.jpg +n000028/0162_01.jpg +n000028/0168_01.jpg +n000028/0205_01.jpg +n000028/0220_01.jpg +n000028/0249_01.jpg +n000028/0300_01.jpg +n000028/0324_02.jpg +n000028/0343_01.jpg +n000028/0352_01.jpg +n000028/0384_01.jpg +n000028/0392_01.jpg +n000028/0408_02.jpg +n000028/0412_02.jpg +n000030/0112_01.jpg +n000030/0119_01.jpg +n000030/0156_01.jpg +n000030/0192_01.jpg +n000030/0195_01.jpg +n000030/0203_01.jpg +n000030/0218_02.jpg +n000030/0305_01.jpg +n000031/0025_01.jpg +n000031/0080_02.jpg +n000031/0141_01.jpg +n000031/0196_01.jpg +n000031/0215_01.jpg +n000031/0286_02.jpg +n000032/0085_01.jpg +n000032/0100_01.jpg +n000032/0100_02.jpg +n000032/0233_01.jpg +n000032/0261_01.jpg +n000032/0350_01.jpg +n000032/0374_01.jpg +n000032/0393_02.jpg +n000032/0428_01.jpg +n000032/0443_01.jpg +n000032/0459_01.jpg +n000032/0465_02.jpg +n000033/0031_01.jpg +n000033/0032_02.jpg +n000033/0034_01.jpg +n000033/0034_02.jpg +n000033/0080_01.jpg +n000033/0100_01.jpg +n000033/0100_02.jpg +n000033/0122_01.jpg +n000033/0164_02.jpg +n000033/0166_01.jpg +n000033/0250_02.jpg +n000033/0327_01.jpg +n000033/0337_01.jpg +n000034/0327_01.jpg +n000035/0072_02.jpg +n000035/0099_01.jpg +n000035/0132_03.jpg +n000035/0134_01.jpg +n000035/0150_01.jpg 
+n000035/0158_01.jpg +n000035/0159_02.jpg +n000035/0167_01.jpg +n000035/0170_01.jpg +n000035/0200_01.jpg n000036/0236_02.jpg n000036/0257_02.jpg n000036/0476_01.jpg @@ -207,7 +1896,9 @@ n000054/0103_02.jpg n000055/0345_01.jpg n000056/0196_01.jpg n000056/0238_03.jpg +n000056/0254_01.jpg n000056/0251_01.jpg +n000056/0158_02.jpg n000057/0058_01.jpg n000057/0049_01.jpg n000057/0323_02.jpg @@ -232,6 +1923,7 @@ n000061/0347_01.jpg n000061/0365_01.jpg n000061/0334_01.jpg n000061/0393_02.jpg +n000061/0209_01.jpg n000062/0073_01.jpg n000062/0075_01.jpg n000062/0089_01.jpg @@ -253,11 +1945,21 @@ n000065/0126_01.jpg n000065/0200_01.jpg n000065/0209_01.jpg n000065/0225_02.jpg +n000065/0082_01.jpg n000066/0040_01.jpg n000066/0109_01.jpg n000066/0267_01.jpg n000066/0276_01.jpg +n000066/0096_01.jpg n000066/0262_01.jpg +n000066/0056_01.jpg +n000066/0037_01.jpg +n000066/0110_03.jpg +n000066/0141_01.jpg +n000066/0085_02.jpg +n000066/0100_01.jpg +n000066/0160_02.jpg +n000066/0366_04.jpg n000067/0526_01.jpg n000067/0521_01.jpg n000067/0457_01.jpg @@ -274,7 +1976,22 @@ n000067/0307_05.jpg n000067/0343_01.jpg n000067/0334_01.jpg n000067/0457_01.jpg +n000067/0056_01.jpg +n000067/0110_03.jpg +n000067/0141_01.jpg +n000067/0295_02.jpg +n000067/0296_01.jpg +n000067/0320_03.jpg +n000067/0323_02.jpg +n000067/0328_01.jpg +n000069/0151_01.jpg +n000069/0191_04.jpg +n000069/0323_02.jpg +n000069/0159_02.jpg n000069/0283_01.jpg +n000069/0186_01.jpg +n000069/0129_02.jpg +n000069/0153_02.jpg n000069/0475_01.jpg n000069/0323_02.jpg n000069/0282_01.jpg @@ -1281,6 +2998,7 @@ n000212/0051_01.jpg n000212/0075_01.jpg n000212/0065_01.jpg n000212/0109_02.jpg +n000212/0216_01.jpg n000213/0068_01.jpg n000213/0158_01.jpg n000213/0300_01.jpg @@ -2190,6 +3908,7 @@ n000364/0208_02.jpg n000364/0239_01.jpg n000364/0368_01.jpg n000364/0674_01.jpg +n000364/0097_03.jpg n000365/0049_02.jpg n000365/0150_02.jpg n000365/0210_02.jpg @@ -51772,4 +53491,15 @@ n009273/0407_01.jpg n009274/0047_01.jpg n009275/0050_02.jpg n009275/0073_01.jpg -n009278/0061_01.jpg \ No newline at end of file +n009278/0061_01.jpg +n000018/0189_01.jpg +n000018/0293_01.jpg +n000018/0280_01.jpg +n000018/0163_01.jpg +n000018/0317_01.jpg +n000018/0216_01.jpg +n000018/0212_01.jpg +n000016/0047_03.jpg +n000020/0367_01.jpg +n000023/0318_01.jpg +n001251/0164_01.jpg
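
Note on the new train_yamls/*.yaml configs added above: they all share one schema (train_script_name, model_configs for g_model/d_model, Adam settings, loss weights, and log/checkpoint steps). The optimizer eps values are written as "!!float 1e-8" on purpose: PyYAML implements YAML 1.1, where a plain 1e-8 (no dot before the exponent) parses as the string "1e-8" rather than a float. A minimal loading sketch, assuming PyYAML is installed; the file name is one of the configs added in this diff:

    # Minimal sketch: load one of the new training configs with PyYAML.
    import yaml

    with open("train_yamls/train_maskloss.yaml", "r") as f:
        config = yaml.safe_load(f)

    # The explicit !!float tag is what makes eps a float; a plain "1e-8"
    # (no dot before the exponent) is a string under PyYAML's YAML 1.1 rules.
    assert isinstance(config["g_optim_config"]["eps"], float)

    print(config["train_script_name"])                       # mgpu_maskloss
    print(config["model_configs"]["g_model"]["class_name"])  # Generator
    print(config["g_optim_config"]["lr"])                    # 0.0006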
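
The load_file_from_url helper added to utilities/utilities.py resolves a cache location (torch hub's checkpoints directory unless model_dir is given), downloads only when the file is not already cached, and returns the absolute path. A usage sketch, where the URL is a hypothetical placeholder rather than a checkpoint this repo publishes:

    # Usage sketch for the new utilities.load_file_from_url helper.
    from utilities.utilities import load_file_from_url

    ckpt_path = load_file_from_url(
        url="https://example.com/models/arcface_checkpoint.tar",  # hypothetical URL
        model_dir="arcface_ckpt",            # matches the arcface_ckpt path the configs expect
        progress=True,
        file_name="arcface_checkpoint.tar",
    )
    # A second call finds the cached file and returns immediately.
    print(ckpt_path)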