This commit is contained in:
chenxuanhong
2022-04-24 15:44:47 +08:00
parent 99ed65aaa3
commit 29d8914c0a
138 changed files with 24864 additions and 353 deletions
+212 -36
View File
@@ -5,14 +5,17 @@
# Created Date: Wednesday December 22nd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 17th February 2022 10:47:35 pm
# Last Modified: Friday, 22nd April 2022 11:23:17 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
from glob import glob
from ipaddress import ip_address
import os
import re
import sys
import time
import json
@@ -37,6 +40,7 @@ import tkinter.ttk as ttk
import subprocess
from pathlib import Path
from tkinter.filedialog import askopenfilename
@@ -163,6 +167,27 @@ class fileUploaderClass(object):
self.__ssh__.close()
return roots
def sshScpGetRNamesBySuffix(self, remoteDir, suffix):
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
sftp = self.__ssh__.open_sftp()
wocao = sftp.listdir(remoteDir)
# print(wocao.st_mtime)
roots = {}
for item in wocao:
wocao = sftp.stat(remoteDir+"/"+item)
roots[item] = {
"t":wocao.st_mtime,
"p":remoteDir+"/"+item
}
# temp= remoteDir+ "/"+item
# child_dirs = sftp.listdir(temp)
# child_dirs = ["save\\" +item + "\\" + i for i in child_dirs]
# list_name += child_dirs
sftp.close()
self.__ssh__.close()
return roots
def sshScpGet(self, remoteFile, localFile, showProgress=False):
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
@@ -308,16 +333,16 @@ class Application(tk.Frame):
"config_json_name":"model_config.json"
}
machine_text = {
"ip": "0.0.0.0",
"ip": "localhost",
"user": "username",
"port": 22,
"passwd": "12345678",
"path": "/path/to/remote_host",
"path": ".",
"ckp_path":"save",
"logfilename": "filestate_machine0.json"
"logfilename": "filestate_machine_localhost.json"
}
current_log = {}
current_ckpt = {}
def __init__(self, master=None):
tk.Frame.__init__(self, master,bg='black')
@@ -435,7 +460,7 @@ class Application(tk.Frame):
config_frame.pack(fill="both", padx=5,pady=5)
config_frame.columnconfigure(0, weight=1)
config_frame.columnconfigure(1, weight=1)
# config_frame.columnconfigure(2, weight=1)
config_frame.columnconfigure(2, weight=1)
machine_btn = tk.Button(config_frame, text = "Ignore Conf",
font=font_list, command = self.IgnoreConfig, bg='#660099', fg='#F5F5F5')
@@ -445,12 +470,17 @@ class Application(tk.Frame):
font=font_list, command = self.EnvConfig, bg='#660099', fg='#F5F5F5')
machine_btn2.grid(row=0,column=1,sticky=tk.EW)
machine_btn2 = tk.Button(config_frame, text = "Test Conf",
font=font_list, command = self.TestConfig, bg='#660099', fg='#F5F5F5')
machine_btn2.grid(row=0,column=2,sticky=tk.EW)
#################################################################################################
log_frame = tk.Frame(self.master)
log_frame.pack(fill="both", padx=5,pady=5)
log_frame.columnconfigure(0, weight=1)
log_frame.columnconfigure(1, weight=1)
log_frame.columnconfigure(2, weight=1)
log_frame.columnconfigure(3, weight=1)
self.log_var = tkinter.StringVar()
@@ -460,36 +490,104 @@ class Application(tk.Frame):
self.update_ckpt_task()
self.log_com.bind("<<ComboboxSelected>>",select_log)
self.test_var = tkinter.StringVar()
self.test_com = ttk.Combobox(log_frame, textvariable=self.test_var)
self.test_com.grid(row=0,column=1,sticky=tk.EW)
log_update_button = tk.Button(log_frame, text = "Fresh",
font=font_list, command = self.UpdateLog, bg='#F4A460', fg='#F5F5F5')
log_update_button.grid(row=0,column=1,sticky=tk.EW)
log_update_button = tk.Button(log_frame, text = "Pull Log",
font=font_list, command = self.PullLog, bg='#F4A460', fg='#F5F5F5')
log_update_button.grid(row=0,column=2,sticky=tk.EW)
# log_update_button = tk.Button(log_frame, text = "Pull Log",
# font=font_list, command = self.PullLog, bg='#F4A460', fg='#F5F5F5')
# log_update_button.grid(row=0,column=2,sticky=tk.EW)
log_update_button = tk.Button(log_frame, text = "Fresh CKPT",
font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5')
log_update_button.grid(row=0,column=3,sticky=tk.EW)
#################################################################################################
test_frame = tk.Frame(self.master)
test_frame.pack(fill="both", padx=5,pady=5)
test_frame.columnconfigure(0, weight=1)
test_frame.columnconfigure(1, weight=1)
test_frame.columnconfigure(2, weight=1)
# test_frame.columnconfigure(3, weight=1)
self.test_var = tkinter.StringVar()
self.testscript_var = tkinter.StringVar()
self.test_com = ttk.Combobox(test_frame, textvariable=self.test_var)
self.test_com.grid(row=0,column=0,sticky=tk.EW)
self.testscript_com = ttk.Combobox(test_frame, textvariable=self.testscript_var)
self.testscript_com.grid(row=0,column=0,sticky=tk.EW)
test_update_button = tk.Button(test_frame, text = "Fresh CKPT",
font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5')
test_update_button.grid(row=0,column=1,sticky=tk.EW)
testscript_files = Path("./test_scripts").glob("*.py")
testscript_list = []
for item in testscript_files:
basename = item.name
basename = os.path.splitext(basename)[0]
testscript_list.append(basename)
self.testscript_com["value"] = testscript_list
# test_update_button = tk.Button(test_frame, text = "Fresh CKPT",
# font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5')
# test_update_button.grid(row=0,column=1,sticky=tk.EW)
test_update_button = tk.Button(test_frame, text = "Test",
font=font_list, command = self.Test, bg='#F4A460', fg='#F5F5F5')
test_update_button.grid(row=0,column=1,sticky=tk.EW)
# test_update_button = tk.Button(test_frame, text = "Test Config",
# font=font_list, command = self.TestConfig, bg='#660099', fg='#F5F5F5')
# test_update_button.grid(row=0,column=2,sticky=tk.EW)
test_update_button = tk.Button(test_frame, text = "Sample",
font=font_list, command = self.OpenSample, bg='#0033FF', fg='#F5F5F5')
test_update_button.grid(row=0,column=2,sticky=tk.EW)
# #################################################################################################
select_frame = tk.Frame(self.master)
select_frame.pack(fill="both", padx=5,pady=5)
select_frame.columnconfigure(0, weight=2)
select_frame.columnconfigure(1, weight=5)
select_frame.columnconfigure(2, weight=1)
select_frame.columnconfigure(3, weight=2)
select_frame.columnconfigure(4, weight=5)
select_frame.columnconfigure(5, weight=1)
select_frame.columnconfigure(6, weight=3)
self.preprocess_var = tkinter.StringVar()
self.preprocess_com = ttk.Combobox(select_frame, textvariable=self.preprocess_var)
self.preprocess_com.grid(row=0,column=6,sticky=tk.EW)
self.preprocess_com["value"] = ["True", "False"]
self.preprocess_com.current(1)
self.ID_path = tkinter.StringVar()
self.ID_path.set("...")
tk.Label(select_frame, text="ID:",font=font_list,justify="left")\
.grid(row=0,column=0,sticky=tk.EW)
tk.Entry(select_frame, textvariable= self.ID_path, font=font_list)\
.grid(row=0,column=1,sticky=tk.EW)
tk.Button(select_frame, text = "...", font=font_list,
command = self.Select_ID_path, bg='#F4A460', fg='#F5F5F5')\
.grid(row=0,column=2,sticky=tk.EW)
self.Attr_path = tkinter.StringVar()
self.Attr_path.set("...")
tk.Label(select_frame, text="Attr:",font=font_list,justify="left")\
.grid(row=0,column=3,sticky=tk.EW)
tk.Entry(select_frame, textvariable= self.Attr_path, font=font_list)\
.grid(row=0,column=4,sticky=tk.EW)
tk.Button(select_frame, text = "...", font=font_list,
command = self.Select_Attr_path, bg='#F4A460', fg='#F5F5F5')\
.grid(row=0,column=5,sticky=tk.EW)
# #################################################################################################
# tbtext_frame = tk.Frame(self.master)
@@ -527,23 +625,61 @@ class Application(tk.Frame):
self.master.protocol("WM_DELETE_WINDOW", self.on_closing)
# def __scaning_logs__(self):
def Select_ID_path(self):
thread_update = threading.Thread(target=self.select_ID_task)
thread_update.start()
def select_ID_task(self):
path = askopenfilename()
print("Selected ID: %s"%path)
self.ID_path.set(path)
def Select_Attr_path(self):
thread_update = threading.Thread(target=self.select_Attr_task)
thread_update.start()
def select_Attr_task(self):
path = askopenfilename()
print("Selected Attibutes: %s"%path)
self.Attr_path.set(path)
def UpdateCKPT(self):
thread_update = threading.Thread(target=self.update_ckpt_task)
thread_update.start()
def update_ckpt_task(self):
ip = self.list_com.get()
log = self.log_com.get()
cur_mac = self.machine_dict[ip]
files = Path('.',cur_mac["ckp_path"], log)
files = files.glob('*.pth')
all_files = []
for one_file in files:
all_files.append(one_file.name)
self.test_com["value"] =all_files
if len(all_files):
print("Loading checkpoints..........................")
remotemachine, mac = self.connection()
log = self.log_com.get()
remote_path = os.path.join(mac["path"],mac["ckp_path"], log, "checkpoints").replace("\\", "/")
if remotemachine == []:
files = Path(remote_path).glob("*/")
first_level = {}
for one_file in files:
first_level[one_file.name] = {
"t":"",
"p":""
}
else:
first_level = remotemachine.sshScpGetNames(remote_path)
if len(first_level) == 0:
self.test_com["value"] = [""]
self.test_com.current(0)
print("No checkpoint found!")
return
logs = []
for k,v in first_level.items():
logs.append([k,v["t"]])
# logs = sorted(logs)
logs = sorted(logs, key= lambda logs : logs[1],reverse=True)
self.test_com["value"] =[item[0] for item in logs]
self.test_com.current(0)
self.current_ckpt = first_level
print("Checkpoints list update success!")
def CopyPasswd(self):
def copy():
@@ -557,13 +693,25 @@ class Application(tk.Frame):
def Test(self):
def test_task():
ip = self.list_com.get()
log = self.log_com.get()
ckpt = self.test_com.get()
script_name = self.testscript_com.get()
id_path = self.ID_path.get()
attr_path = self.Attr_path.get()
preprocess = self.preprocess_com.get()
# if preprocess == "Preprocess-Off":
# preprocess = "off"
# else:
# preprocess = "on"
ckpt = re.sub("\D", "", ckpt)
cwd = os.getcwd()
files = str(Path(log, ckpt))
print(files)
# files = str(Path(log, ckpt))
# print(files)
print("start cmd /k \"cd /d %s && conda activate base \
&& python test.py -v %s -s %s -t %s -n %s -i %s -a %s --preprocess %s \""%(cwd, log, ckpt, script_name, ip, id_path, attr_path, preprocess))
subprocess.check_call("start cmd /k \"cd /d %s && conda activate base \
&& python test.py --model %s\""%(cwd, files), shell=True)
&& python test.py -v %s -s %s -t %s -n %s --preprocess %s -i %s -a %s\""%(cwd, log, ckpt, script_name, ip, preprocess , id_path, attr_path), shell=True)
thread_update = threading.Thread(target=test_task)
thread_update.start()
@@ -595,7 +743,7 @@ class Application(tk.Frame):
ssh_username = cur_mac["user"]
ssh_passwd = cur_mac["passwd"]
ssh_port = int(cur_mac["port"])
print(ssh_ip)
print("Processing IP: %s."%ssh_ip)
if ip.lower() == "local" or ip.lower() == "localhost":
print("localhost no need to connect!")
return [], cur_mac
@@ -607,6 +755,7 @@ class Application(tk.Frame):
print(cells)
def update_log_task(self):
print("Processing! Do not touch!")
remotemachine,mac = self.connection()
remote_path = os.path.join(mac["path"],mac["ckp_path"]).replace("\\", "/")
if remotemachine == []:
@@ -617,16 +766,23 @@ class Application(tk.Frame):
"t":"",
"p":""
}
# elif remotemachine == "localhost":
else:
first_level = remotemachine.sshScpGetNames(remote_path)
if len(first_level) == 0:
print("No training log found!")
return
logs = []
for k,v in first_level.items():
logs.append(k)
logs = sorted(logs)
self.log_com["value"] =logs
logs.append([k,v["t"]])
# logs = sorted(logs)
logs = sorted(logs, key= lambda logs : logs[1],reverse=True)
self.log_com["value"] = [item[0] for item in logs]
self.log_com.current(0)
self.current_log = first_level
self.update_ckpt_task()
print("Done!")
def UpdateLog(self):
thread_update = threading.Thread(target=self.update_log_task)
@@ -666,6 +822,21 @@ class Application(tk.Frame):
thread_update = threading.Thread(target=pull_log_task)
thread_update.start()
def TestConfig(self):
def test_config_task():
subprocess.call("start %s"%"test.py", shell=True)
thread_update = threading.Thread(target=test_config_task)
thread_update.start()
def OpenSample(self):
def open_cmd_task():
log = self.log_com.get()
cwd = os.getcwd()
sample = os.path.join(cwd,"test_logs",log,"samples")
subprocess.call("explorer "+sample, shell=False)
thread_update = threading.Thread(target=open_cmd_task)
thread_update.start()
def OpenCMD(self):
def open_cmd_task():
@@ -685,6 +856,7 @@ class Application(tk.Frame):
ip = self.list_com.get()
if ip.lower() == "local" or ip.lower() == "localhost":
print("localhost no need to connect!")
return
cur_mac = self.machine_dict[ip]
ssh_ip = cur_mac["ip"]
ssh_username = cur_mac["user"]
@@ -707,6 +879,10 @@ class Application(tk.Frame):
def GPUUsage(self):
def gpu_usage_task():
remotemachine,_ = self.connection()
if remotemachine == "local" or remotemachine == "localhost":
print("localhost no need to connect!")
return
results = remotemachine.sshExec("nvidia-smi")
print(results)
+281 -166
View File
@@ -1,91 +1,91 @@
{
"GUI.py": 1647657822.9152665,
"test.py": 1647879709.2723496,
"train.py": 1647657822.9562755,
"components\\Generator.py": 1647657822.93127,
"components\\projected_discriminator.py": 1647657822.938272,
"components\\pg_modules\\blocks.py": 1647657822.9362714,
"components\\pg_modules\\diffaug.py": 1647657822.9362714,
"components\\pg_modules\\discriminator.py": 1647657822.937271,
"components\\pg_modules\\networks_fastgan.py": 1647657822.937271,
"components\\pg_modules\\networks_stylegan2.py": 1647657822.937271,
"components\\pg_modules\\projector.py": 1647657822.938272,
"data_tools\\data_loader.py": 1647657822.9392715,
"data_tools\\data_loader_condition.py": 1647657822.9402719,
"data_tools\\data_loader_VGGFace2HQ.py": 1647657822.9392715,
"data_tools\\StyleResize.py": 1647657822.9392715,
"data_tools\\test_dataloader_dir.py": 1647657822.941272,
"losses\\PerceptualLoss.py": 1647657822.9432724,
"losses\\SliceWassersteinDistance.py": 1647657822.9432724,
"models\\arcface_models.py": 1647657822.9442725,
"models\\config.py": 1647657822.9442725,
"models\\__init__.py": 1647657822.9442725,
"test_scripts\\tester_common.py": 1647657822.9472733,
"GUI.py": 1649868392.5891902,
"test.py": 1649641910.555175,
"train.py": 1643397924.974299,
"components\\Generator.py": 1644689001.9005148,
"components\\projected_discriminator.py": 1642348101.4661522,
"components\\pg_modules\\blocks.py": 1640773190.0,
"components\\pg_modules\\diffaug.py": 1640773190.0,
"components\\pg_modules\\discriminator.py": 1642349784.9407308,
"components\\pg_modules\\networks_fastgan.py": 1640773190.0,
"components\\pg_modules\\networks_stylegan2.py": 1640773190.0,
"components\\pg_modules\\projector.py": 1642349764.3896568,
"data_tools\\data_loader.py": 1611123530.660446,
"data_tools\\data_loader_condition.py": 1625411562.8217106,
"data_tools\\data_loader_VGGFace2HQ.py": 1644234949.3769877,
"data_tools\\StyleResize.py": 1624954084.7176485,
"data_tools\\test_dataloader_dir.py": 1634041792.6743984,
"losses\\PerceptualLoss.py": 1615020169.668723,
"losses\\SliceWassersteinDistance.py": 1634022704.6082795,
"models\\arcface_models.py": 1642390690.623,
"models\\config.py": 1632643596.2908099,
"models\\__init__.py": 1642390864.8828168,
"test_scripts\\tester_common.py": 1625369535.199175,
"test_scripts\\tester_FastNST.py": 1634041357.607633,
"train_scripts\\trainer_base.py": 1647657822.9582758,
"train_scripts\\trainer_FM.py": 1647657822.957276,
"train_scripts\\trainer_naiv512.py": 1647657822.9602764,
"utilities\\checkpoint_manager.py": 1647657822.9652774,
"utilities\\figure.py": 1647657822.9652774,
"utilities\\json_config.py": 1647657822.9652774,
"utilities\\learningrate_scheduler.py": 1647657822.9652774,
"utilities\\logo_class.py": 1647657822.9662776,
"utilities\\plot.py": 1647657822.9662776,
"utilities\\reporter.py": 1647657822.9662776,
"utilities\\save_heatmap.py": 1647657822.967278,
"utilities\\sshupload.py": 1647657822.967278,
"utilities\\transfer_checkpoint.py": 1647657822.967278,
"utilities\\utilities.py": 1647657822.9682784,
"utilities\\yaml_config.py": 1647657822.9682784,
"train_yamls\\train_512FM.yaml": 1647657822.961277,
"train_scripts\\trainer_2layer_FM.py": 1647657822.957276,
"train_yamls\\train_2layer_FM.yaml": 1647657822.961277,
"components\\Generator_reduce.py": 1647657822.934271,
"train_scripts\\trainer_base.py": 1642396105.3868554,
"train_scripts\\trainer_FM.py": 1643021959.3577182,
"train_scripts\\trainer_naiv512.py": 1642315674.9740853,
"utilities\\checkpoint_manager.py": 1611123530.6624403,
"utilities\\figure.py": 1611123530.6634378,
"utilities\\json_config.py": 1611123530.6614666,
"utilities\\learningrate_scheduler.py": 1611123530.675422,
"utilities\\logo_class.py": 1633883995.3093486,
"utilities\\plot.py": 1641911100.7995758,
"utilities\\reporter.py": 1646311333.3067005,
"utilities\\save_heatmap.py": 1611123530.679439,
"utilities\\sshupload.py": 1649910787.5441866,
"utilities\\transfer_checkpoint.py": 1642397157.0163105,
"utilities\\utilities.py": 1649907294.9180465,
"utilities\\yaml_config.py": 1611123530.6614666,
"train_yamls\\train_512FM.yaml": 1643021615.8106658,
"train_scripts\\trainer_2layer_FM.py": 1642826548.2530458,
"train_yamls\\train_2layer_FM.yaml": 1642411635.5534878,
"components\\Generator_reduce.py": 1645020911.0651233,
"insightface_func\\face_detect_crop_multi.py": 1643796928.6362474,
"insightface_func\\face_detect_crop_single.py": 1638370471.7967434,
"insightface_func\\__init__.py": 1624197300.011183,
"insightface_func\\utils\\face_align_ffhqandnewarc.py": 1638370471.850638,
"losses\\PatchNCE.py": 1647677173.0239084,
"losses\\PatchNCE.py": 1647769120.567006,
"parsing_model\\model.py": 1626745709.554252,
"parsing_model\\resnet.py": 1626745709.554252,
"test_scripts\\tester_common copy.py": 1647657822.9472733,
"test_scripts\\tester_video.py": 1647657822.9482737,
"train_scripts\\trainer_cycleloss.py": 1647657822.9592762,
"train_scripts\\trainer_GramFM.py": 1647657822.9582758,
"utilities\\ImagenetNorm.py": 1647657822.9642777,
"utilities\\reverse2original.py": 1647657822.9662776,
"train_yamls\\train_cycleloss.yaml": 1647704989.2310138,
"train_yamls\\train_GramFM.yaml": 1647657822.9622767,
"train_yamls\\train_512FM_Modulation.yaml": 1647657822.961277,
"face_crop.py": 1647657822.9422722,
"face_crop_video.py": 1647657822.9422722,
"similarity.py": 1647657822.945273,
"train_multigpu.py": 1647967698.603863,
"components\\arcface_decoder.py": 1647657822.9352713,
"test_scripts\\tester_common copy.py": 1625369535.199175,
"test_scripts\\tester_video.py": 1649749352.843031,
"train_scripts\\trainer_cycleloss.py": 1642580463.495596,
"train_scripts\\trainer_GramFM.py": 1643095575.2628715,
"utilities\\ImagenetNorm.py": 1642732910.5280058,
"utilities\\reverse2original.py": 1648533907.6187606,
"train_yamls\\train_cycleloss.yaml": 1647769120.6110919,
"train_yamls\\train_GramFM.yaml": 1643398791.363959,
"train_yamls\\train_512FM_Modulation.yaml": 1643022022.3165789,
"face_crop.py": 1649079350.7075238,
"face_crop_video.py": 1649988435.4005315,
"similarity.py": 1643269705.1073737,
"train_multigpu.py": 1650004781.6307411,
"components\\arcface_decoder.py": 1643396144.2575414,
"components\\Generator_nobias.py": 1643179001.810856,
"data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1647657822.9402719,
"data_tools\\data_loader_VGGFace2HQ_Rec.py": 1647657822.9392715,
"test_scripts\\tester_arcface_Rec.py": 1647657822.946273,
"test_scripts\\tester_image.py": 1647657822.9472733,
"torch_utils\\custom_ops.py": 1647657822.9482737,
"torch_utils\\misc.py": 1647657822.9492736,
"torch_utils\\persistence.py": 1647657822.9552753,
"torch_utils\\training_stats.py": 1647657822.9562755,
"torch_utils\\utils_spectrum.py": 1647657822.9562755,
"torch_utils\\__init__.py": 1647657822.9482737,
"torch_utils\\ops\\bias_act.py": 1647657822.9502747,
"torch_utils\\ops\\conv2d_gradfix.py": 1647657822.9512744,
"torch_utils\\ops\\conv2d_resample.py": 1647657822.9512744,
"torch_utils\\ops\\filtered_lrelu.py": 1647657822.9532747,
"torch_utils\\ops\\fma.py": 1647657822.9532747,
"torch_utils\\ops\\grid_sample_gradfix.py": 1647657822.9542756,
"torch_utils\\ops\\upfirdn2d.py": 1647657822.9552753,
"torch_utils\\ops\\__init__.py": 1647657822.9492736,
"train_scripts\\trainer_arcface_rec.py": 1647657822.9582758,
"train_scripts\\trainer_multigpu_base.py": 1647657822.9602764,
"train_scripts\\trainer_multi_gpu.py": 1647657822.9592762,
"train_yamls\\train_arcface_rec.yaml": 1647657822.9622767,
"train_yamls\\train_multigpu.yaml": 1647657822.963277,
"data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1649177633.2426238,
"data_tools\\data_loader_VGGFace2HQ_Rec.py": 1643398754.86898,
"test_scripts\\tester_arcface_Rec.py": 1643431261.9333818,
"test_scripts\\tester_image.py": 1648028827.923354,
"torch_utils\\custom_ops.py": 1640773190.0,
"torch_utils\\misc.py": 1640773190.0,
"torch_utils\\persistence.py": 1640773190.0,
"torch_utils\\training_stats.py": 1640773190.0,
"torch_utils\\utils_spectrum.py": 1640773190.0,
"torch_utils\\__init__.py": 1640773190.0,
"torch_utils\\ops\\bias_act.py": 1640773190.0,
"torch_utils\\ops\\conv2d_gradfix.py": 1640773190.0,
"torch_utils\\ops\\conv2d_resample.py": 1640773190.0,
"torch_utils\\ops\\filtered_lrelu.py": 1640773190.0,
"torch_utils\\ops\\fma.py": 1640773190.0,
"torch_utils\\ops\\grid_sample_gradfix.py": 1640773190.0,
"torch_utils\\ops\\upfirdn2d.py": 1640773190.0,
"torch_utils\\ops\\__init__.py": 1640773190.0,
"train_scripts\\trainer_arcface_rec.py": 1643399647.0182135,
"train_scripts\\trainer_multigpu_base.py": 1644131205.772292,
"train_scripts\\trainer_multi_gpu.py": 1648285132.309124,
"train_yamls\\train_arcface_rec.yaml": 1643398807.3434353,
"train_yamls\\train_multigpu.yaml": 1644549590.0652373,
"wandb\\run-20220129_032741-340btp9k\\files\\conda-environment.yaml": 1643398065.409959,
"wandb\\run-20220129_032741-340btp9k\\files\\config.yaml": 1643398069.2392955,
"wandb\\run-20220129_032939-2nmaozxq\\files\\conda-environment.yaml": 1643398182.647548,
@@ -100,93 +100,208 @@
"wandb\\run-20220129_034859-2puk6sph\\files\\config.yaml": 1643399477.881678,
"wandb\\run-20220129_035624-3hmwgcgw\\files\\conda-environment.yaml": 1643399787.8899708,
"wandb\\run-20220129_035624-3hmwgcgw\\files\\config.yaml": 1643426465.6088357,
"dnnlib\\util.py": 1647657822.941272,
"dnnlib\\__init__.py": 1647657822.941272,
"components\\Generator_ori.py": 1647657822.9332705,
"losses\\cos.py": 1647657822.9442725,
"data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1647657822.9402719,
"speed_test.py": 1647657822.945273,
"components\\DeConv_Invo.py": 1647657822.9302697,
"dnnlib\\util.py": 1640773190.0,
"dnnlib\\__init__.py": 1640773190.0,
"components\\Generator_ori.py": 1644689174.414655,
"losses\\cos.py": 1644229583.4023254,
"data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1644860106.943826,
"speed_test.py": 1648982366.0803514,
"components\\DeConv_Invo.py": 1644426607.1588645,
"components\\Generator_reduce_up.py": 1644688655.2096283,
"components\\Generator_upsample.py": 1647657822.9352713,
"components\\misc\\Involution.py": 1647657822.9352713,
"train_yamls\\train_Invoup.yaml": 1647657822.9622767,
"flops.py": 1647657822.9422722,
"detection_test.py": 1647657822.941272,
"components\\DeConv_Depthwise.py": 1647657822.9292698,
"components\\DeConv_Depthwise1.py": 1647657822.9292698,
"components\\Generator_upsample.py": 1644689723.8293872,
"components\\misc\\Involution.py": 1644509321.5267963,
"train_yamls\\train_Invoup.yaml": 1644689981.9794765,
"flops.py": 1649040334.6186154,
"detection_test.py": 1644935512.6830947,
"components\\DeConv_Depthwise.py": 1645064447.4379447,
"components\\DeConv_Depthwise1.py": 1644946969.5054545,
"components\\Generator_modulation_depthwise.py": 1644861291.4467516,
"components\\Generator_modulation_depthwise_config.py": 1647657822.93227,
"components\\Generator_modulation_up.py": 1647657822.9332705,
"components\\Generator_oriae_modulation.py": 1647657822.934271,
"components\\Generator_ori_config.py": 1647657822.934271,
"train_scripts\\trainer_multi_gpu1.py": 1647657822.9602764,
"train_yamls\\train_Depthwise.yaml": 1647657822.961277,
"train_yamls\\train_depthwise_modulation.yaml": 1647657822.963277,
"train_yamls\\train_oriae_modulation.yaml": 1647657822.9642777,
"train_distillation_mgpu.py": 1647657822.9562755,
"components\\DeConv.py": 1647657822.9292698,
"components\\DeConv_Depthwise_ECA.py": 1647657822.9292698,
"components\\ECA.py": 1647657822.9302697,
"components\\ECA_Depthwise_Conv.py": 1647657822.93127,
"components\\Generator_eca_depthwise.py": 1647657822.93227,
"losses\\KA.py": 1647657822.9432724,
"train_scripts\\trainer_distillation_mgpu.py": 1647657822.9592762,
"train_yamls\\train_distillation.yaml": 1647657822.963277,
"annotation.py": 1647657822.9172668,
"components\\DeConv_ECA_Invo.py": 1647657822.9302697,
"components\\DeConv_Invobn.py": 1647657822.9302697,
"components\\Generator_Invobn_config.py": 1647657822.93127,
"components\\Generator_Invobn_config1.py": 1647657822.93127,
"components\\misc\\Involution_BN.py": 1647657822.9362714,
"components\\misc\\Involution_ECA.py": 1647657822.9362714,
"train_yamls\\train_Invobn_config.yaml": 1647657822.9622767,
"components\\Generator_Invobn_config2.py": 1647657822.93227,
"components\\Generator_Invobn_config3.py": 1647657822.93227,
"components\\Generator_ori_modulation_config.py": 1647657822.934271,
"test_scripts\\tester_image_allstep.py": 1647657822.9482737,
"train_yamls\\train_ori_modulation_config.yaml": 1647657822.9642777,
"test_arcface.py": 1647657822.946273,
"arcface_torch\\dataset.py": 1647657822.9222684,
"arcface_torch\\eval_ijbc.py": 1647657822.9242685,
"arcface_torch\\inference.py": 1647657822.9242685,
"arcface_torch\\losses.py": 1647657822.9242685,
"arcface_torch\\lr_scheduler.py": 1647657822.9242685,
"arcface_torch\\onnx_helper.py": 1647657822.9252684,
"arcface_torch\\onnx_ijbc.py": 1647657822.9252684,
"arcface_torch\\partial_fc.py": 1647657822.9252684,
"arcface_torch\\torch2onnx.py": 1647657822.9262686,
"arcface_torch\\train.py": 1647657822.9262686,
"arcface_torch\\backbones\\iresnet.py": 1647657822.918267,
"arcface_torch\\backbones\\iresnet2060.py": 1647657822.9192681,
"arcface_torch\\backbones\\mobilefacenet.py": 1647657822.9192681,
"arcface_torch\\backbones\\__init__.py": 1647657822.918267,
"arcface_torch\\configs\\3millions.py": 1647657822.9192681,
"arcface_torch\\configs\\base.py": 1647657822.9192681,
"arcface_torch\\configs\\glint360k_mobileface_lr02_bs4k.py": 1647657822.9202676,
"arcface_torch\\configs\\glint360k_r100_lr02_bs4k_16gpus.py": 1647657822.9202676,
"arcface_torch\\configs\\ms1mv3_mobileface_lr02.py": 1647657822.9202676,
"arcface_torch\\configs\\ms1mv3_r100_lr02.py": 1647657822.9202676,
"arcface_torch\\configs\\ms1mv3_r50_lr02.py": 1647657822.9202676,
"arcface_torch\\configs\\webface42m_mobilefacenet_pfc02_bs8k_16gpus.py": 1647657822.9212687,
"arcface_torch\\configs\\webface42m_r100_lr01_pfc02_bs4k_16gpus.py": 1647657822.9212687,
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_32gpus.py": 1647657822.9212687,
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_8gpus.py": 1647657822.9212687,
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs8k_16gpus.py": 1647657822.9212687,
"arcface_torch\\configs\\__init__.py": 1647657822.9192681,
"arcface_torch\\eval\\verification.py": 1647657822.923268,
"arcface_torch\\eval\\__init__.py": 1647657822.923268,
"arcface_torch\\utils\\plot.py": 1647657822.927269,
"arcface_torch\\utils\\utils_callbacks.py": 1647657822.927269,
"arcface_torch\\utils\\utils_config.py": 1647657822.927269,
"arcface_torch\\utils\\utils_logging.py": 1647657822.927269,
"arcface_torch\\utils\\__init__.py": 1647657822.9262686,
"components\\LSTU.py": 1647702612.240765,
"test_scripts\\tester_ID_Pose.py": 1647657822.946273,
"train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1647657822.9592762,
"train_scripts\\trainer_multi_gpu_CUT.py": 1647676964.475,
"train_scripts\\trainer_multi_gpu_cycle.py": 1647705628.7020626,
"components\\Generator_LSTU_config.py": 1647954099.1135788,
"components\\Generator_Res_config.py": 1648006159.4385264,
"train_yamls\\train_cycleloss_res.yaml": 1648006232.5734456
"components\\Generator_modulation_depthwise_config.py": 1645262162.9779513,
"components\\Generator_modulation_up.py": 1644946498.7005584,
"components\\Generator_oriae_modulation.py": 1644897798.1987727,
"components\\Generator_ori_config.py": 1646329319.6131227,
"train_scripts\\trainer_multi_gpu1.py": 1644859528.8428593,
"train_yamls\\train_Depthwise.yaml": 1644860961.099242,
"train_yamls\\train_depthwise_modulation.yaml": 1645035964.9551077,
"train_yamls\\train_oriae_modulation.yaml": 1644897891.2576747,
"train_distillation_mgpu.py": 1645554603.908166,
"components\\DeConv.py": 1645263338.9001615,
"components\\DeConv_Depthwise_ECA.py": 1645265769.1076133,
"components\\ECA.py": 1614848426.9604986,
"components\\ECA_Depthwise_Conv.py": 1645265754.2023985,
"components\\Generator_eca_depthwise.py": 1645266338.9750814,
"losses\\KA.py": 1646388425.4841197,
"train_scripts\\trainer_distillation_mgpu.py": 1645601961.4139585,
"train_yamls\\train_distillation.yaml": 1645600099.540936,
"annotation.py": 1648654581.017103,
"components\\DeConv_ECA_Invo.py": 1645869347.379311,
"components\\DeConv_Invobn.py": 1645862876.018001,
"components\\Generator_Invobn_config.py": 1645929418.6924264,
"components\\Generator_Invobn_config1.py": 1645862695.8743145,
"components\\misc\\Involution_BN.py": 1645867197.3984175,
"components\\misc\\Involution_ECA.py": 1645869012.4927464,
"train_yamls\\train_Invobn_config.yaml": 1646101598.499709,
"components\\Generator_Invobn_config2.py": 1645962618.7056074,
"components\\Generator_Invobn_config3.py": 1646302561.1984286,
"components\\Generator_ori_modulation_config.py": 1646329636.719998,
"test_scripts\\tester_image_allstep.py": 1646312637.9363256,
"train_yamls\\train_ori_modulation_config.yaml": 1646330406.200162,
"test_arcface.py": 1647448497.69041,
"arcface_torch\\dataset.py": 1647445446.261035,
"arcface_torch\\eval_ijbc.py": 1647445446.2630043,
"arcface_torch\\inference.py": 1647445446.2630043,
"arcface_torch\\losses.py": 1647445446.2630043,
"arcface_torch\\lr_scheduler.py": 1647445446.2630043,
"arcface_torch\\onnx_helper.py": 1647445446.2630043,
"arcface_torch\\onnx_ijbc.py": 1647445446.2640254,
"arcface_torch\\partial_fc.py": 1647445446.2640254,
"arcface_torch\\torch2onnx.py": 1647445446.2640254,
"arcface_torch\\train.py": 1647445446.2649992,
"arcface_torch\\backbones\\iresnet.py": 1647445446.2580183,
"arcface_torch\\backbones\\iresnet2060.py": 1647445446.2580183,
"arcface_torch\\backbones\\mobilefacenet.py": 1647445446.2580183,
"arcface_torch\\backbones\\__init__.py": 1647445446.25702,
"arcface_torch\\configs\\3millions.py": 1647445446.2580183,
"arcface_torch\\configs\\base.py": 1647445446.259039,
"arcface_torch\\configs\\glint360k_mobileface_lr02_bs4k.py": 1647445446.259039,
"arcface_torch\\configs\\glint360k_r100_lr02_bs4k_16gpus.py": 1647445446.259039,
"arcface_torch\\configs\\ms1mv3_mobileface_lr02.py": 1647445446.259039,
"arcface_torch\\configs\\ms1mv3_r100_lr02.py": 1647445446.259039,
"arcface_torch\\configs\\ms1mv3_r50_lr02.py": 1647445446.260039,
"arcface_torch\\configs\\webface42m_mobilefacenet_pfc02_bs8k_16gpus.py": 1647445446.260039,
"arcface_torch\\configs\\webface42m_r100_lr01_pfc02_bs4k_16gpus.py": 1647445446.260039,
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_32gpus.py": 1647445446.260039,
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_8gpus.py": 1647445446.260039,
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs8k_16gpus.py": 1647445446.260039,
"arcface_torch\\configs\\__init__.py": 1647445446.2580183,
"arcface_torch\\eval\\verification.py": 1647445446.2620306,
"arcface_torch\\eval\\__init__.py": 1647445446.2620306,
"arcface_torch\\utils\\plot.py": 1647445446.2649992,
"arcface_torch\\utils\\utils_callbacks.py": 1647445446.2649992,
"arcface_torch\\utils\\utils_config.py": 1647445446.2649992,
"arcface_torch\\utils\\utils_logging.py": 1647445446.2659965,
"arcface_torch\\utils\\__init__.py": 1647445446.2649992,
"components\\LSTU.py": 1648482475.4378786,
"test_scripts\\tester_ID_Pose.py": 1646558809.85301,
"train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1646391740.1106014,
"train_scripts\\trainer_multi_gpu_CUT.py": 1647769120.5968685,
"train_scripts\\trainer_multi_gpu_cycle.py": 1648313934.2140906,
"components\\Generator_LSTU_config.py": 1648028831.4087331,
"components\\Generator_Res_config.py": 1648053382.053794,
"train_yamls\\train_cycleloss_res.yaml": 1648103614.4515965,
"clear_dataset.py": 1648920044.7807672,
"id_cos.py": 1648569510.5593822,
"translation_list2json.py": 1648106406.478007,
"components\\Generator_ResSkip_config.py": 1648526573.2056787,
"components\\Generator_Res_config1.py": 1648092270.7609806,
"components\\Generator_Res_config2.py": 1648103885.6257715,
"test_scripts\\tester_image_list.py": 1648145245.0948818,
"test_scripts\\tester_image_nofusion.py": 1648096849.0402405,
"train_yamls\\train_cycleloss_resskip.yaml": 1648313962.968481,
"check_list.txt": 1648657338.5051336,
"test_imgs_list.txt": 1649574560.1498592,
"vggface2hq_failed.txt": 1648926017.999394,
"arcface_torch\\requirement.txt": 1647445446.2640254,
"wandb\\run-20220129_032741-340btp9k\\files\\requirements.txt": 1643398065.409959,
"wandb\\run-20220129_032939-2nmaozxq\\files\\requirements.txt": 1643398182.647548,
"wandb\\run-20220129_033051-21z19tyg\\files\\requirements.txt": 1643398254.926299,
"wandb\\run-20220129_033202-16la4gpu\\files\\requirements.txt": 1643398325.8784783,
"wandb\\run-20220129_034327-1bmseytq\\files\\requirements.txt": 1643399010.865907,
"wandb\\run-20220129_034859-2puk6sph\\files\\requirements.txt": 1643399343.3508356,
"wandb\\run-20220129_035624-3hmwgcgw\\files\\requirements.txt": 1643399787.8869605,
"components\\Generator_featout_config.py": 1648877243.4813964,
"components\\Generator_ResSkip_config1.py": 1648530486.7970698,
"components\\LSTU_Config.py": 1648528200.7229428,
"components\\Nonstau_Discriminator.py": 1648476236.8430562,
"components\\Nonstau_Discriminator_FM.py": 1649833901.6153808,
"metrics\\equivariance.py": 1640773190.0,
"metrics\\frechet_inception_distance.py": 1640773190.0,
"metrics\\inception_score.py": 1640773190.0,
"metrics\\kernel_inception_distance.py": 1640773190.0,
"metrics\\metric_main.py": 1640773190.0,
"metrics\\metric_utils.py": 1640773190.0,
"metrics\\perceptual_path_length.py": 1640773190.0,
"metrics\\precision_recall.py": 1640773190.0,
"test_scripts\\tester_image_list_w_mask.py": 1649074097.8263073,
"test_scripts\\tester_image_w_mask.py": 1648529828.1194224,
"train_scripts\\trainer_mgpu_fm.py": 1648878513.5066502,
"train_scripts\\trainer_multi_gpu_cycle_nonstatue_dis.py": 1648559801.6695006,
"train_yamls\\train_cycleloss_fm_nonstatu.yaml": 1648572102.661056,
"train_yamls\\train_cycleloss_resskip_nonstatu.yaml": 1648527833.9132054,
"components\\Generator_Res_config3.py": 1648628068.1252878,
"data_tools\\data_loader_FFHQ_multigpu.py": 1648640139.408,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\align_and_crop_dir.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\generate_mask.py": 1648652827.224725,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\test_enhance_dir_align.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\test_enhance_dir_unalign.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\test_enhance_single_unalign.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\train.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\base_dataset.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\celebahqmask_dataset.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\ffhq_dataset.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\image_folder.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\single_dataset.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\__init__.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\base_model.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\blocks.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\enhance_model.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\loss.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\networks.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\parse_model.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\psfrnet.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\__init__.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\base_options.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\test_options.py": 1648647365.0084722,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\train_options.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\__init__.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\utils\\logger.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\utils\\timer.py": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\utils\\utils.py": 1648653194.9661705,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\requirements.txt": 1640774152.0,
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\check_points\\experiment_name\\test_opt.txt": 1648653211.0647676,
"components\\Generator_maskhead_config.py": 1648919192.4128494,
"data_tools\\data_loader_VGGFace2HQ_multigpu_w_mask.py": 1648950658.8917239,
"train_scripts\\trainer_mgpu_fm_w_mask.py": 1649842622.1032736,
"train_scripts\\trainer_mgpu_maskloss.py": 1650003673.00616,
"train_yamls\\train_maskhead_fm_nonstatu.yaml": 1648918857.2325976,
"train_yamls\\train_maskhead_maskloss.yaml": 1648919545.869737,
"dataset.check.py": 1648925868.9100032,
"components\\Generator_involution_maskhead_config.py": 1648919192.4128494,
"components\\Generator_VGGStyle_maskhead_config.py": 1649177890.0438528,
"train_yamls\\train_maskhead_fm_vggstyle.yaml": 1649177995.515654,
"train_yamls\\train_maskloss.yaml": 1648920140.455,
"components\\Generator_maskhead_config1.py": 1649248730.7064345,
"train_yamls\\train_maskhead_hififace.yaml": 1649345218.4856737,
"filter.py": 1649836163.761168,
"test_json.py": 1649232568.5223434,
"components\\Generator_2maskhead_config copy.py": 1649816572.443475,
"components\\Generator_2maskhead_config.py": 1649816572.443475,
"components\\Generator_maskhead_config2.py": 1649833973.4127545,
"components\\Generator_starganv2.py": 1649845840.4322417,
"face_enhancer\\gfpgan\\train.py": 1644942770.0,
"face_enhancer\\gfpgan\\utils.py": 1644942770.0,
"face_enhancer\\gfpgan\\version.py": 1648618484.3356814,
"face_enhancer\\gfpgan\\__init__.py": 1644942770.0,
"face_enhancer\\gfpgan\\archs\\arcface_arch.py": 1644942770.0,
"face_enhancer\\gfpgan\\archs\\gfpganv1_arch.py": 1644942770.0,
"face_enhancer\\gfpgan\\archs\\gfpganv1_clean_arch.py": 1644942770.0,
"face_enhancer\\gfpgan\\archs\\gfpgan_bilinear_arch.py": 1644942770.0,
"face_enhancer\\gfpgan\\archs\\stylegan2_bilinear_arch.py": 1644942770.0,
"face_enhancer\\gfpgan\\archs\\stylegan2_clean_arch.py": 1644942770.0,
"face_enhancer\\gfpgan\\archs\\__init__.py": 1644942770.0,
"face_enhancer\\gfpgan\\data\\ffhq_degradation_dataset.py": 1644942770.0,
"face_enhancer\\gfpgan\\data\\__init__.py": 1644942770.0,
"face_enhancer\\gfpgan\\models\\gfpgan_model.py": 1644942770.0,
"face_enhancer\\gfpgan\\models\\__init__.py": 1644942770.0,
"face_enhancer\\scripts\\convert_gfpganv_to_clean.py": 1644942770.0,
"face_enhancer\\scripts\\parse_landmark.py": 1644942770.0,
"test_scripts\\tester_image_list_w_2mask.py": 1649768648.90081,
"test_scripts\\tester_image_w_2mask.py": 1649729361.2799067,
"test_scripts\\tester_image_w_mask_gfpgan.py": 1649872098.8747168,
"test_scripts\\tester_video_gfpgan.py": 1649907748.9359775,
"train_scripts\\trainer_mgpu_2maskloss.py": 1649699504.6697042,
"train_yamls\\train_1maskhead.yaml": 1650003571.40854,
"train_yamls\\train_2maskhead.yaml": 1649953481.822017,
"train_yamls\\train_maskhead_hififace1.yaml": 1649471700.4954374,
"components\\Generator_2mask.py": 1649953828.8860412
}
+2 -1
View File
@@ -2,7 +2,8 @@
"white_list": {
"extension": [
"py",
"yaml"
"yaml",
"txt"
],
"file": [],
"path": []
+9
View File
@@ -8,6 +8,15 @@
"ckp_path": "train_logs",
"logfilename": "filestate_machine0.json"
},
{
"ip": "119.29.91.52",
"user": "ubuntu",
"port": 22,
"passwd": "zpKlOW0sMlyt!xhE",
"path": "/home/ubuntu/CXH/simswap_plus",
"ckp_path": "train_logs",
"logfilename": "filestate_machine3.json"
},
{
"ip": "2001:da8:8000:6880:f284:d61c:3c76:f9cb",
"user": "ps",
+1
View File
@@ -1,6 +1,7 @@
# Simswap++
## Dependencies
- moviepy
- python >= 3.7
- yaml (pip install pyyaml)
- paramiko (for SSH file transfer)
+2 -2
View File
@@ -5,7 +5,7 @@
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 27th February 2022 11:03:58 am
# Last Modified: Wednesday, 30th March 2022 11:36:20 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -33,7 +33,7 @@ def str2bool(v):
def getParameters():
parser = argparse.ArgumentParser()
# general
parser.add_argument('--image_dir', type=str, default="G:\\VGGFace2-HQ\\VGGface2_None_norm_512_true_bygfpgan")
parser.add_argument('--image_dir', type=str, default="G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan")
parser.add_argument('--savetxt', type=str, default="./check_list.txt")
parser.add_argument('--winWidth', type=int, default=512)
parser.add_argument('--winHeight', type=int, default=512)
+2 -2
View File
@@ -1,6 +1,6 @@
{
"breakpoint": [
31,
110
54,
101
]
}
+1278
View File
File diff suppressed because it is too large Load Diff
+7 -4
View File
@@ -5,7 +5,7 @@
# Created Date: Thursday March 24th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 24th March 2022 3:32:40 pm
# Last Modified: Sunday, 3rd April 2022 1:20:44 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -18,11 +18,14 @@ if __name__ == "__main__":
savePath = "./vggface2hq_failed.txt"
env_config = readConfig('env/env.json')
env_config = env_config["path"]
dataset_root = env_config["dataset_paths"]["vggface2_hq"]
dataset_root = env_config["dataset_paths"]["vggface2_hq"]["images"]
# dataset_root = "G:/VGGFace2-HQ/newversion"
print(dataset_root)
with open(savePath,'r') as logf:
for line in logf:
img_path = os.path.join(dataset_root,line[:-1]).replace("\\","/")
img_path = os.path.join(dataset_root,line.replace("\n","")).replace("\\","/")
try:
os.rename(img_path,img_path+".deleted")
except: pass
except Exception as e:
print(e)
+304
View File
@@ -0,0 +1,304 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th April 2022 7:03:46 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x downsampling.

    The shortcut and residual paths are summed and divided by sqrt(2) so the
    output keeps roughly unit variance when both inputs do.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels; a 1x1 shortcut conv is learned
            when it differs from ``dim_in``.
        actv: activation module applied before each conv (pre-activation).
        normalize: "in" for InstanceNorm2d, "bn" for BatchNorm2d; any other
            value disables normalization.
        downsample: if True, average-pool both paths by a factor of 2.
    """
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        # Fix: the original tested the truthiness of the mode *string* in
        # _residual, so an unknown mode such as "none" passed the check and
        # later crashed with AttributeError (self.norm1 was never created).
        # Resolve the decision once, here.
        self._use_norm = normalize.lower() in ("in", "bn")
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # conv1 keeps the channel count; conv2 performs the channel change.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self._use_norm:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self._use_norm:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive instance normalization.

    Normalizes ``x`` per-sample/per-channel, then re-scales and shifts it
    with (gamma, beta) produced by a linear projection of the style code.
    """
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # Projects the style vector into per-channel (gamma, beta) pairs.
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        # (batch, 2*C) -> (batch, 2*C, 1, 1) so it broadcasts over H, W.
        params = self.fc(s).unsqueeze(-1).unsqueeze(-1)
        gamma, beta = params.chunk(2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
class ResUpBlk(nn.Module):
    """Residual block that doubles the spatial resolution (nearest upsample).

    Both the residual and the shortcut paths are upsampled 2x; their sum is
    divided by sqrt(2) to preserve unit variance.
    """
    def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        # Learn a 1x1 projection only when the channel count changes.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        kind = self.normalize.lower()
        if kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(up) if self.learned_sc else up

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = self.conv1(F.interpolate(h, scale_factor=2, mode='nearest'))
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
class AdainResBlk(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=512,
actv=nn.LeakyReLU(0.2), upsample=False):
super().__init__()
self.actv = actv
self.upsample = upsample
self.learned_sc = dim_in != dim_out
self.equal_var = math.sqrt(2)
self._build_weights(dim_in, dim_out, style_dim)
def _build_weights(self, dim_in, dim_out, style_dim=64):
self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
self.norm1 = AdaIN(style_dim, dim_in)
self.norm2 = AdaIN(style_dim, dim_out)
if self.learned_sc:
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
def _shortcut(self, x):
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.learned_sc:
x = self.conv1x1(x)
return x
def _residual(self, x, s):
x = self.norm1(x, s)
x = self.actv(x)
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = self.conv1(x)
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(x)
return x
def forward(self, x, s):
out = self._residual(x, s)
out = (out + self._shortcut(x)) / self.equal_var
return out
class Generator(nn.Module):
    """Mask-guided face generator: encoder -> AdaIN bottleneck -> decoder.

    ``forward(img, id)`` returns ``(res, mask_lr, mask_hr)``:
      * ``res``     - generated image, composited with the input via ``mask_hr``;
      * ``mask_lr`` - low-resolution soft mask used to fuse decoder features
                      with the ``down2`` skip connection;
      * ``mask_hr`` - full-resolution soft mask used to composite the final
                      RGB output with the input image.

    Expected kwargs: ``id_dim`` (identity-embedding size), ``g_kernel_size``,
    ``res_num`` (bottleneck depth), ``in_channel`` (base channel width),
    ``up_mode``, ``norm`` ("in" or "bn"), ``aggregator``, ``res_mode``.
    ``g_kernel_size``, ``up_mode``, ``aggregator`` and ``res_mode`` are read
    but unused in this configuration.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        in_channel = kwargs["in_channel"]
        up_mode = kwargs["up_mode"]  # unused in this variant
        norm = kwargs["norm"].lower()
        aggregator = kwargs["aggregator"]  # unused in this variant
        res_mode = kwargs["res_mode"]  # unused in this variant
        padding_size= int((k_size -1)/2)  # unused in this variant
        padding_type= 'reflect'  # unused in this variant
        # NOTE(review): norm_out / norm_mask are only bound when norm is
        # "in" or "bn"; any other value raises NameError below — confirm.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
            norm_mask= nn.InstanceNorm2d
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
            norm_mask = nn.BatchNorm2d
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        # 1x1 conv lifting RGB into the base feature width.
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                     nn.BatchNorm2d(64), activation)
        ### downsample
        # Trailing numbers presumably denote the spatial size after each
        # block for a 256x256 input — TODO confirm.
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 128
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 64
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 32
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        ### resnet blocks
        # Identity-conditioned bottleneck; spatial size unchanged.
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        # self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 1
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 32
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 64
        # self.maskhead = nn.Sequential(
        #     nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
        #     norm_mask, # 64
        #     activation,
        #     nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
        #     nn.Sigmoid())
        # Shared mask trunk: upsamples the deepest encoder feature 4x.
        self.maskhead_lr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel, affine=True), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//4, affine=True), # 64
            activation
        )
        # High-resolution mask head: two further 2x upsamplings + sigmoid.
        self.maskhead_hr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//4, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel//16, affine=True), # 128
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid() # 256
        )
        # Low-resolution mask head: 1x1 projection of the mask-trunk feature.
        self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//4, 1, kernel_size=1, stride=1),
                            nn.Sigmoid())
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
        # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                     nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    def forward(self, img, id):
        # ``id`` (shadows the builtin) is the identity code of shape
        # (batch, id_dim) — presumably an ArcFace embedding; TODO confirm.
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # kept for the mask-gated skip connection
        res = self.down3(skip)
        res = self.down4(res)
        mask_feat= self.maskhead_lr(res)  # shared trunk for both masks
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        mask_lr= self.maskhead_out(mask_feat)
        # res = (1-mask) * self.sigma(skip) + mask * res
        # Blend decoder features with the encoder skip via the learned mask.
        res = (1-mask_lr) * skip + mask_lr * res
        res = self.up2(res) # + skip
        res = self.up1(res)
        res = self.to_rgb(res)
        mask_hr=self.maskhead_hr(mask_feat)
        # Final compositing: only masked regions come from the generator.
        res = (1-mask_hr) * img + mask_hr * res
        return res, mask_lr, mask_hr
+312
View File
@@ -0,0 +1,312 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 15th April 2022 12:30:27 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x downsampling.

    The shortcut and residual paths are summed and divided by sqrt(2) so the
    output keeps roughly unit variance when both inputs do.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels; a 1x1 shortcut conv is learned
            when it differs from ``dim_in``.
        actv: activation module applied before each conv (pre-activation).
        normalize: "in" for InstanceNorm2d, "bn" for BatchNorm2d; any other
            value disables normalization.
        downsample: if True, average-pool both paths by a factor of 2.
    """
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        # Fix: the original tested the truthiness of the mode *string* in
        # _residual, so an unknown mode such as "none" passed the check and
        # later crashed with AttributeError (self.norm1 was never created).
        # Resolve the decision once, here.
        self._use_norm = normalize.lower() in ("in", "bn")
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # conv1 keeps the channel count; conv2 performs the channel change.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self._use_norm:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self._use_norm:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive instance normalization.

    Normalizes ``x`` per-sample/per-channel, then re-scales and shifts it
    with (gamma, beta) produced by a linear projection of the style code.
    """
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # Projects the style vector into per-channel (gamma, beta) pairs.
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        # (batch, 2*C) -> (batch, 2*C, 1, 1) so it broadcasts over H, W.
        params = self.fc(s).unsqueeze(-1).unsqueeze(-1)
        gamma, beta = params.chunk(2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
class ResUpBlk(nn.Module):
    """Residual block that doubles the spatial resolution (nearest upsample).

    Both the residual and the shortcut paths are upsampled 2x; their sum is
    divided by sqrt(2) to preserve unit variance.
    """
    def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        # Learn a 1x1 projection only when the channel count changes.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        kind = self.normalize.lower()
        if kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(up) if self.learned_sc else up

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = self.conv1(F.interpolate(h, scale_factor=2, mode='nearest'))
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
class AdainResBlk(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=512,
actv=nn.LeakyReLU(0.2), upsample=False):
super().__init__()
self.actv = actv
self.upsample = upsample
self.learned_sc = dim_in != dim_out
self.equal_var = math.sqrt(2)
self._build_weights(dim_in, dim_out, style_dim)
def _build_weights(self, dim_in, dim_out, style_dim=64):
self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
self.norm1 = AdaIN(style_dim, dim_in)
self.norm2 = AdaIN(style_dim, dim_out)
if self.learned_sc:
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
def _shortcut(self, x):
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.learned_sc:
x = self.conv1x1(x)
return x
def _residual(self, x, s):
x = self.norm1(x, s)
x = self.actv(x)
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = self.conv1(x)
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(x)
return x
def forward(self, x, s):
out = self._residual(x, s)
out = (out + self._shortcut(x)) / self.equal_var
return out
class Generator(nn.Module):
    """Mask-guided face generator (deeper variant with 5 encoder stages).

    ``forward(img, id)`` returns ``(res, mask_lr, mask_hr)``:
      * ``res``     - generated image, composited with the input via ``mask_hr``;
      * ``mask_lr`` - low-resolution soft mask used to fuse decoder features
                      with the ``down2`` skip connection;
      * ``mask_hr`` - full-resolution soft mask used to composite the final
                      RGB output with the input image.

    Expected kwargs: ``id_dim`` (identity-embedding size), ``g_kernel_size``,
    ``res_num`` (bottleneck depth), ``in_channel`` (base channel width),
    ``up_mode``, ``norm`` ("in" or "bn"), ``aggregator``, ``res_mode``.
    ``g_kernel_size``, ``up_mode``, ``aggregator`` and ``res_mode`` are read
    but unused in this configuration.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        in_channel = kwargs["in_channel"]
        up_mode = kwargs["up_mode"]  # unused in this variant
        norm = kwargs["norm"].lower()
        aggregator = kwargs["aggregator"]  # unused in this variant
        res_mode = kwargs["res_mode"]  # unused in this variant
        padding_size= int((k_size -1)/2)  # unused in this variant
        padding_type= 'reflect'  # unused in this variant
        # NOTE(review): norm_out / norm_mask are only bound when norm is
        # "in" or "bn"; any other value raises NameError below — confirm.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
            norm_mask= nn.InstanceNorm2d
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
            norm_mask = nn.BatchNorm2d
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        # 1x1 conv lifting RGB into the base feature width.
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                     nn.BatchNorm2d(64), activation)
        ### downsample
        # Trailing numbers presumably denote the spatial size after each
        # block for a 512x512 input — TODO confirm.
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        ### resnet blocks
        # Identity-conditioned bottleneck; spatial size unchanged.
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        # self.maskhead = nn.Sequential(
        #     nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
        #     norm_mask, # 64
        #     activation,
        #     nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
        #     nn.Sigmoid())
        # Shared mask trunk: upsamples the deepest encoder feature 8x.
        self.maskhead_lr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel, affine=True), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//4, affine=True), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1),
            norm_mask(in_channel//8, affine=True), # 128
            activation,
        )
        # High-resolution mask head: two further 2x upsamplings + sigmoid.
        self.maskhead_hr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel//16, affine=True), # 256
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid() # 512
        )
        # Low-resolution mask head: 1x1 projection of the mask-trunk feature.
        self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1),
                            nn.Sigmoid())
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
        # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                     nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    def forward(self, img, id):
        # ``id`` (shadows the builtin) is the identity code of shape
        # (batch, id_dim) — presumably an ArcFace embedding; TODO confirm.
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # kept for the mask-gated skip connection
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask_feat= self.maskhead_lr(res)  # shared trunk for both masks
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        mask_lr= self.maskhead_out(mask_feat)
        # res = (1-mask) * self.sigma(skip) + mask * res
        # Blend decoder features with the encoder skip via the learned mask.
        res = (1-mask_lr) * skip + mask_lr * res
        res = self.up2(res) # + skip
        res = self.up1(res)
        res = self.to_rgb(res)
        mask_hr=self.maskhead_hr(mask_feat)
        # Final compositing: only masked regions come from the generator.
        res = (1-mask_hr) * img + mask_hr * res
        return res, mask_lr, mask_hr
+312
View File
@@ -0,0 +1,312 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th April 2022 12:45:55 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x downsampling.

    The shortcut and residual paths are summed and divided by sqrt(2) so the
    output keeps roughly unit variance when both inputs do.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels; a 1x1 shortcut conv is learned
            when it differs from ``dim_in``.
        actv: activation module applied before each conv (pre-activation).
        normalize: "in" for InstanceNorm2d, "bn" for BatchNorm2d; any other
            value disables normalization.
        downsample: if True, average-pool both paths by a factor of 2.
    """
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        # Fix: the original tested the truthiness of the mode *string* in
        # _residual, so an unknown mode such as "none" passed the check and
        # later crashed with AttributeError (self.norm1 was never created).
        # Resolve the decision once, here.
        self._use_norm = normalize.lower() in ("in", "bn")
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # conv1 keeps the channel count; conv2 performs the channel change.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self._use_norm:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self._use_norm:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive instance normalization.

    Normalizes ``x`` per-sample/per-channel, then re-scales and shifts it
    with (gamma, beta) produced by a linear projection of the style code.
    """
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # Projects the style vector into per-channel (gamma, beta) pairs.
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        # (batch, 2*C) -> (batch, 2*C, 1, 1) so it broadcasts over H, W.
        params = self.fc(s).unsqueeze(-1).unsqueeze(-1)
        gamma, beta = params.chunk(2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
class ResUpBlk(nn.Module):
    """Residual block that doubles the spatial resolution (nearest upsample).

    Both the residual and the shortcut paths are upsampled 2x; their sum is
    divided by sqrt(2) to preserve unit variance.
    """
    def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        # Learn a 1x1 projection only when the channel count changes.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        kind = self.normalize.lower()
        if kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(up) if self.learned_sc else up

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = self.conv1(F.interpolate(h, scale_factor=2, mode='nearest'))
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
class AdainResBlk(nn.Module):
    """Style-conditioned residual block.

    Each half of the residual path is AdaIN -> activation -> 3x3 conv,
    conditioned on the style code *s*; both branches optionally upsample
    by 2x (nearest neighbor), and their sum is rescaled by 1/sqrt(2).
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out  # channel change => learned 1x1 shortcut
        self.equal_var = math.sqrt(2)        # keeps the branch sum at unit variance
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _upsample(self, x):
        # Shared nearest-neighbor 2x upsampling used by both branches.
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def _shortcut(self, x):
        if self.upsample:
            x = self._upsample(x)
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        x = self.actv(self.norm1(x, s))
        if self.upsample:
            x = self._upsample(x)
        x = self.conv1(x)
        x = self.actv(self.norm2(x, s))
        return self.conv2(x)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """Identity-conditioned encoder/decoder generator with mask-based blending.

    Encoder: `from_rgb` + five downsampling ResBlk stages (each halves the
    spatial size).  Bottleneck: `res_num` AdainResBlk blocks conditioned on
    the identity code.  Decoder: AdainResBlk upsampling stages back to input
    resolution.  A shared mask trunk (`maskhead_lr`) predicts two soft masks
    from the bottleneck features: `maskhead_out` yields mask_lr (blends
    decoded features with the encoder skip features) and `maskhead_hr`
    yields mask_hr (blends the generated RGB with the input image).

    Expected kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm ("in" or "bn"), aggregator, res_mode.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]          # dimensionality of the identity/style code
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]        # number of bottleneck blocks
        in_channel = kwargs["in_channel"]  # base channel width
        up_mode = kwargs["up_mode"]
        norm = kwargs["norm"].lower()
        aggregator = kwargs["aggregator"]
        res_mode = kwargs["res_mode"]
        padding_size= int((k_size -1)/2)
        padding_type= 'reflect'
        # NOTE(review): up_mode, aggregator, res_mode, padding_size and
        # padding_type are read from kwargs but never used below.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
            norm_mask= nn.InstanceNorm2d
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
            norm_mask = nn.BatchNorm2d
        # NOTE(review): any other `norm` value leaves norm_out/norm_mask
        # unbound and raises NameError when they are used further down.
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                                  nn.BatchNorm2d(64), activation)
        ### downsample: each stage halves the spatial size (trailing comments
        ### give the output resolution assuming a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        ### resnet blocks: identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        # self.maskhead = nn.Sequential(
        #     nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
        #     norm_mask, # 64
        #     activation,
        #     nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
        #     nn.Sigmoid())
        # Shared mask trunk: upsamples bottleneck features 16 -> 128.
        self.maskhead_lr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel, affine=True), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//4, affine=True), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1),
            norm_mask(in_channel//8, affine=True), # 128
            activation,
            )
        # High-resolution mask head: continues the trunk 128 -> 512, sigmoid output.
        self.maskhead_hr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel//16, affine=True), # 256
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid() # 512
            )
        # Low-resolution mask output: 1x1 conv + sigmoid on the trunk features.
        self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1),
            nn.Sigmoid())
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        # self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                                 nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    def forward(self, img, id):
        """Generate the identity-conditioned output image.

        Args:
            img: input RGB batch.
            id: identity/style code batch (note: shadows the builtin `id`;
                kept for caller compatibility).

        Returns:
            (res, mask_lr, mask_hr): blended output image plus the low- and
            high-resolution soft blending masks.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # encoder features kept for mask blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask_feat= self.maskhead_lr(res)  # shared mask trunk from bottleneck features
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        mask_lr= self.maskhead_out(mask_feat)
        # res = (1-mask) * self.sigma(skip) + mask * res
        # Blend decoded features with the encoder skip features (same shape).
        res = (1-mask_lr) * skip + mask_lr * res
        res = self.up2(res,id) # + skip
        res = self.up1(res,id)
        res = self.to_rgb(res)
        mask_hr=self.maskhead_hr(mask_feat)
        # Blend the generated RGB with the input image at full resolution.
        res = (1-mask_hr) * img + mask_hr * res
        return res, mask_lr, mask_hr
+453
View File
@@ -0,0 +1,453 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 18th April 2022 10:20:12 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import math
import torch
from torch import nn
import torch.nn.functional as F
from components.ModulatedDWConv import ModulatedDWConv2d
class ResBlk(nn.Module):
    """Pre-activation residual block built from depthwise-separable 3x3 convs,
    with optional normalization and 2x average-pool downsampling.

    The shortcut applies a learned 1x1 projection when dim_in != dim_out, and
    the sum of the two branches is divided by sqrt(2) to preserve unit
    variance.

    Args:
        dim_in: input channel count.
        dim_out: output channel count.
        actv: activation module applied after each normalization.
        normalize: "in" (InstanceNorm2d), "bn" (BatchNorm2d), or "" to
            disable normalization.
        downsample: when True, average-pool both branches by 2.

    Raises:
        ValueError: for an unrecognized non-empty `normalize` value (the
        original code deferred this to an AttributeError in `_residual`).
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)  # rescales the two-branch sum to unit variance
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Depthwise 3x3 followed by pointwise 1x1 (separable convolution).
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, 3, 1, 1, groups=dim_in),
            nn.Conv2d(dim_in, dim_in, 1, 1)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, 3, 1, 1, groups=dim_in),
            nn.Conv2d(dim_in, dim_out, 1, 1)
        )
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        elif norm:
            # Fail fast instead of crashing later in _residual with an
            # opaque AttributeError on self.norm1.
            raise ValueError(
                f"unsupported normalization '{self.normalize}' (expected 'in' or 'bn')")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        # Identity path, projected/pooled to match the residual branch.
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    Instance-normalizes the feature map, then applies a per-channel affine
    transform (scale 1 + gamma, shift beta) predicted from a style code.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Predict [gamma | beta] from the style vector, broadcast over H x W.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        normalized = self.norm(x)
        return normalized * (1 + gamma) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples spatially by 2x (nearest neighbor).

    Both the shortcut and the residual branch are upsampled, so the output
    resolution is always twice the input; channels go dim_in -> dim_out.
    The branch sum is divided by sqrt(2) to keep unit variance.

    Args:
        dim_in: input channel count.
        dim_out: output channel count (a 1x1 conv shortcut is learned when
            it differs from dim_in).
        actv: activation module applied after each normalization.
        normalize: "in" (InstanceNorm2d) or "bn" (BatchNorm2d).

    Raises:
        ValueError: if `normalize` is not "in" or "bn".  (The original code
        deferred this to an AttributeError on the first forward call.)
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)  # rescales the two-branch sum to unit variance
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        else:
            # Fail fast: _residual uses self.norm1/self.norm2 unconditionally.
            raise ValueError(
                f"unsupported normalization '{self.normalize}' (expected 'in' or 'bn')")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        # Upsample first so both branches agree on spatial size.
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        out = self._residual(x)
        out = (out + self._shortcut(x)) / self.equal_var
        return out
# class ResUpBlk(nn.Module):
# def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"):
# super().__init__()
# self.actv = actv
# self.normalize = normalize
# self.learned_sc = dim_in != dim_out
# self.equal_var = math.sqrt(2)
# self._build_weights(dim_in, dim_out)
# def _build_weights(self, dim_in, dim_out):
# self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
# self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
# if self.normalize.lower() == "in":
# self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
# self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
# elif self.normalize.lower() == "bn":
# self.norm1 = nn.BatchNorm2d(dim_in)
# self.norm2 = nn.BatchNorm2d(dim_out)
# if self.learned_sc:
# self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
# def _shortcut(self, x):
# x = F.interpolate(x, scale_factor=2, mode='nearest')
# if self.learned_sc:
# x = self.conv1x1(x)
# return x
# def _residual(self, x):
# x = self.norm1(x)
# x = self.actv(x)
# x = F.interpolate(x, scale_factor=2, mode='nearest')
# x = self.conv1(x)
# x = self.norm2(x)
# x = self.actv(x)
# x = self.conv2(x)
# return x
# def forward(self, x):
# out = self._residual(x)
# out = (out + self._shortcut(x)) / self.equal_var
# return out
class AdainResBlk(nn.Module):
    """Style-conditioned residual block using depthwise-separable 3x3 convs
    (a depthwise 3x3 followed by a pointwise 1x1).

    Each half of the residual path is AdaIN -> activation -> conv, driven by
    the style code *s*; both branches optionally upsample by 2x (nearest
    neighbor), and the branch sum is rescaled by 1/sqrt(2).
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out  # channel change => learned 1x1 shortcut
        self.equal_var = math.sqrt(2)        # keeps the branch sum at unit variance
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        # Depthwise-separable replacements for plain 3x3 convolutions.
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, 3, 1, 1, groups=dim_in),  # depthwise
            nn.Conv2d(dim_in, dim_out, 1, 1)                    # pointwise
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(dim_out, dim_out, 3, 1, 1, groups=dim_out),
            nn.Conv2d(dim_out, dim_out, 1)
        )
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        x = self.actv(self.norm1(x, s))
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.actv(self.norm2(x, s))
        return self.conv2(x)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class ModulatedResBlk(nn.Module):
    """Residual block built from style-modulated depthwise convolutions.

    Unlike AdainResBlk there are no explicit normalization layers: the
    ModulatedDWConv2d layers consume the style code *s* directly.  Both
    branches optionally upsample by 2x (nearest neighbor) and the branch
    sum is rescaled by 1/sqrt(2).
    """

    def __init__(self,
                 dim_in,
                 dim_out,
                 style_dim=512,
                 actv=nn.LeakyReLU(0.2),
                 upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out  # channel change => learned 1x1 shortcut
        self.equal_var = math.sqrt(2)        # keeps the branch sum at unit variance
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = ModulatedDWConv2d(dim_in, dim_out, style_dim, 3)
        self.conv2 = ModulatedDWConv2d(dim_out, dim_out, style_dim, 3)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        x = self.actv(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x, s)
        x = self.conv2(self.actv(x), s)
        return x

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """Identity-conditioned encoder/decoder generator with mask-based blending
    (depthwise-separable variant: the ResBlk/AdainResBlk classes in this file
    use separable convolutions).

    Encoder: `from_rgb` + five downsampling ResBlk stages.  Bottleneck:
    `res_num` AdainResBlk blocks conditioned on the identity code.  Decoder:
    AdainResBlk upsampling stages back to input resolution.  A shared mask
    trunk (`maskhead_lr`) predicts two soft masks from the bottleneck
    features: `maskhead_out` yields mask_lr (feature-level blending with the
    encoder skip) and `maskhead_hr` yields mask_hr (image-level blending).

    Expected kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm ("in" or "bn"), aggregator, res_mode.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]          # dimensionality of the identity/style code
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]        # number of bottleneck blocks
        in_channel = kwargs["in_channel"]  # base channel width
        up_mode = kwargs["up_mode"]
        norm = kwargs["norm"].lower()
        aggregator = kwargs["aggregator"]
        res_mode = kwargs["res_mode"]
        padding_size= int((k_size -1)/2)
        padding_type= 'reflect'
        # NOTE(review): up_mode, aggregator, res_mode, padding_size and
        # padding_type are read from kwargs but never used below.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
            norm_mask= nn.InstanceNorm2d
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
            norm_mask = nn.BatchNorm2d
        # NOTE(review): any other `norm` value leaves norm_out/norm_mask
        # unbound and raises NameError when they are used further down.
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                                  nn.BatchNorm2d(64), activation)
        ### downsample: each stage halves the spatial size (trailing comments
        ### give the output resolution assuming a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        ### resnet blocks: identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        # self.maskhead = nn.Sequential(
        #     nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
        #     norm_mask, # 64
        #     activation,
        #     nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
        #     nn.Sigmoid())
        # Shared mask trunk: upsamples bottleneck features 16 -> 128.
        self.maskhead_lr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel, affine=True), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//4, affine=True), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1),
            norm_mask(in_channel//8, affine=True), # 128
            activation,
            )
        # High-resolution mask head: continues the trunk 128 -> 512, sigmoid output.
        self.maskhead_hr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel//16, affine=True), # 256
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid() # 512
            )
        # Low-resolution mask output: 1x1 conv + sigmoid on the trunk features.
        self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1),
            nn.Sigmoid())
        # Depthwise-separable alternative for the mask heads, kept for reference:
        # self.maskhead_lr = nn.Sequential(
        #     nn.UpsamplingNearest2d(scale_factor = 2),
        #     nn.Conv2d(in_channel*8, in_channel*8, 3, 1, 1,groups=in_channel*8),
        #     nn.Conv2d(in_channel*8, in_channel, 1, bias=False),
        #     # nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
        #     norm_mask(in_channel, affine=True), # 32
        #     activation,
        #     nn.UpsamplingNearest2d(scale_factor = 2),
        #     nn.Conv2d(in_channel, in_channel, 3, 1, 1,groups=in_channel),
        #     nn.Conv2d(in_channel, in_channel//4, 1, bias=False),
        #     # nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
        #     norm_mask(in_channel//4, affine=True), # 64
        #     activation,
        #     nn.UpsamplingNearest2d(scale_factor = 2),
        #     # nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1),
        #     nn.Conv2d(in_channel//4, in_channel//4, 3, 1, 1,groups=in_channel//4),
        #     nn.Conv2d(in_channel//4, in_channel//8, 1, bias=False),
        #     norm_mask(in_channel//8, affine=True), # 128
        #     activation,
        #     )
        # self.maskhead_hr = nn.Sequential(
        #     nn.UpsamplingNearest2d(scale_factor = 2),
        #     # nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False),
        #     nn.Conv2d(in_channel//8, in_channel//8, 3, 1, 1,groups=in_channel//8),
        #     nn.Conv2d(in_channel//8, in_channel//16, 1, bias=False),
        #     norm_mask(in_channel//16, affine=True), # 256
        #     activation,
        #     nn.UpsamplingNearest2d(scale_factor = 2),
        #     nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1),
        #     nn.Sigmoid() # 512
        #     )
        # self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1),
        #     nn.Sigmoid())
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel*2, in_channel, normalize=norm)
        # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel, in_channel, normalize=norm)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                                 nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    def forward(self, img, id):
        """Generate the identity-conditioned output image.

        Args:
            img: input RGB batch.
            id: identity/style code batch (note: shadows the builtin `id`;
                kept for caller compatibility).

        Returns:
            (res, mask_lr, mask_hr): blended output image plus the low- and
            high-resolution soft blending masks.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # encoder features kept for mask blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask_feat= self.maskhead_lr(res)  # shared mask trunk from bottleneck features
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        mask_lr= self.maskhead_out(mask_feat)
        # res = (1-mask) * self.sigma(skip) + mask * res
        # Blend decoded features with the encoder skip features (same shape).
        res = (1-mask_lr) * skip + mask_lr * res
        res = self.up2(res,id) # + skip
        res = self.up1(res,id)
        res = self.to_rgb(res)
        mask_hr=self.maskhead_hr(mask_feat)
        # Blend the generated RGB with the input image at full resolution.
        res = (1-mask_hr) * img + mask_hr * res
        return res, mask_lr, mask_hr
@@ -0,0 +1,293 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 13th April 2022 10:22:52 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional normalization and 2x
    average-pool downsampling.

    The shortcut applies a learned 1x1 projection when dim_in != dim_out, and
    the sum of the two branches is divided by sqrt(2) to preserve unit
    variance.

    Args:
        dim_in: input channel count.
        dim_out: output channel count.
        actv: activation module applied after each normalization.
        normalize: "in" (InstanceNorm2d), "bn" (BatchNorm2d), or "" to
            disable normalization.
        downsample: when True, average-pool both branches by 2.

    Raises:
        ValueError: for an unrecognized non-empty `normalize` value (the
        original code deferred this to an AttributeError in `_residual`).
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="bn", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)  # rescales the two-branch sum to unit variance
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        elif norm:
            # Fail fast instead of crashing later in _residual with an
            # opaque AttributeError on self.norm1.
            raise ValueError(
                f"unsupported normalization '{self.normalize}' (expected 'in' or 'bn')")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        # Identity path, projected/pooled to match the residual branch.
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    Instance-normalizes the feature map, then applies a per-channel affine
    transform (scale 1 + gamma, shift beta) predicted from a style code.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Predict [gamma | beta] from the style vector, broadcast over H x W.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        normalized = self.norm(x)
        return normalized * (1 + gamma) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples spatially by 2x (nearest neighbor).

    Both the shortcut and the residual branch are upsampled, so the output
    resolution is always twice the input; channels go dim_in -> dim_out.
    The branch sum is divided by sqrt(2) to keep unit variance.

    Args:
        dim_in: input channel count.
        dim_out: output channel count (a 1x1 conv shortcut is learned when
            it differs from dim_in).
        actv: activation module applied after each normalization.
        normalize: "in" (InstanceNorm2d) or "bn" (BatchNorm2d).

    Raises:
        ValueError: if `normalize` is not "in" or "bn".  (The original code
        deferred this to an AttributeError on the first forward call.)
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)  # rescales the two-branch sum to unit variance
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        else:
            # Fail fast: _residual uses self.norm1/self.norm2 unconditionally.
            raise ValueError(
                f"unsupported normalization '{self.normalize}' (expected 'in' or 'bn')")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        # Upsample first so both branches agree on spatial size.
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        out = self._residual(x)
        out = (out + self._shortcut(x)) / self.equal_var
        return out
class AdainResBlk(nn.Module):
    """Style-conditioned residual block.

    Each half of the residual path is AdaIN -> activation -> 3x3 conv,
    conditioned on the style code *s*; both branches optionally upsample
    by 2x (nearest neighbor), and their sum is rescaled by 1/sqrt(2).
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out  # channel change => learned 1x1 shortcut
        self.equal_var = math.sqrt(2)        # keeps the branch sum at unit variance
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _upsample(self, x):
        # Shared nearest-neighbor 2x upsampling used by both branches.
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def _shortcut(self, x):
        if self.upsample:
            x = self._upsample(x)
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        x = self.actv(self.norm1(x, s))
        if self.upsample:
            x = self._upsample(x)
        x = self.conv1(x)
        x = self.actv(self.norm2(x, s))
        return self.conv2(x)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """Identity-conditioned encoder/decoder generator, variant with plain
    (non-style) ResUpBlk decoder tail and simpler mask heads.

    Encoder: `from_rgb` + five downsampling ResBlk stages.  Bottleneck:
    `res_num` AdainResBlk blocks conditioned on the identity code.  Decoder:
    three AdainResBlk upsampling stages, then two unconditioned ResUpBlk
    stages.  `maskhead_lr` predicts a soft mask from the up3 features for
    feature-level blending with the encoder skip; `maskhead_hr` predicts a
    full-resolution mask from the up1 features for image-level blending.

    Expected kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm ("in" or "bn"), aggregator, res_mode.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]          # dimensionality of the identity/style code
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]        # number of bottleneck blocks
        in_channel = kwargs["in_channel"]  # base channel width
        up_mode = kwargs["up_mode"]
        norm = kwargs["norm"]
        norm = norm.lower()
        aggregator = kwargs["aggregator"]
        res_mode = kwargs["res_mode"]
        padding_size= int((k_size -1)/2)
        padding_type= 'reflect'
        # NOTE(review): up_mode, aggregator, res_mode, padding_size and
        # padding_type are read from kwargs but never used below.
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                                  nn.BatchNorm2d(64), activation)
        ### downsample: each stage halves the spatial size (trailing comments
        ### give the output resolution assuming a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
        # self.sigma = ResBlk(in_channel*2,in_channel*2)
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        ### resnet blocks: identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        # Low-resolution mask head, applied to the up3 output.
        # NOTE(review): uses BatchNorm2d regardless of `norm` -- confirm intended.
        self.maskhead_lr = nn.Sequential(
            nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel), # 64
            activation,
            nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
            )
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
        # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        # High-resolution mask head, applied to the up1 output.
        # NOTE(review): also hard-codes BatchNorm2d regardless of `norm`.
        self.maskhead_hr = nn.Sequential(
            nn.Conv2d(in_channel, in_channel//8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//8), # 64
            activation,
            nn.Conv2d(in_channel//8, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
            )
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        # NOTE(review): any other `norm` value leaves norm_out unbound and
        # raises NameError just below.
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                                 nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    def forward(self, img, id):
        """Generate the identity-conditioned output image.

        Args:
            img: input RGB batch.
            id: identity/style code batch (note: shadows the builtin `id`;
                kept for caller compatibility).

        Returns:
            (res, mask, mask_hr): blended output image plus the low- and
            high-resolution soft blending masks.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # encoder features kept for mask blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        mask= self.maskhead_lr(res)  # mask predicted from decoded (not bottleneck) features
        # res = (1-mask) * self.sigma(skip) + mask * res
        # Blend decoded features with the encoder skip features (same shape).
        res = (1-mask) * skip + mask * res
        res = self.up2(res) # + skip
        res = self.up1(res)
        mask_hr = self.maskhead_hr(res)  # full-resolution mask from up1 features
        res = self.to_rgb(res)
        # Blend the generated RGB with the input image at full resolution.
        res = (1-mask_hr)*img + mask_hr*res
        return res, mask, mask_hr
+293
View File
@@ -0,0 +1,293 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 13th April 2022 10:22:52 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional normalization and 2x
    average-pool downsampling.

    The shortcut applies a learned 1x1 projection when dim_in != dim_out, and
    the sum of the two branches is divided by sqrt(2) to preserve unit
    variance.

    Args:
        dim_in: input channel count.
        dim_out: output channel count.
        actv: activation module applied after each normalization.
        normalize: "in" (InstanceNorm2d), "bn" (BatchNorm2d), or "" to
            disable normalization.
        downsample: when True, average-pool both branches by 2.

    Raises:
        ValueError: for an unrecognized non-empty `normalize` value (the
        original code deferred this to an AttributeError in `_residual`).
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="bn", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)  # rescales the two-branch sum to unit variance
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        elif norm:
            # Fail fast instead of crashing later in _residual with an
            # opaque AttributeError on self.norm1.
            raise ValueError(
                f"unsupported normalization '{self.normalize}' (expected 'in' or 'bn')")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        # Identity path, projected/pooled to match the residual branch.
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    Instance-normalizes the feature map, then applies a per-channel affine
    transform (scale 1 + gamma, shift beta) predicted from a style code.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Predict [gamma | beta] from the style vector, broadcast over H x W.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        normalized = self.norm(x)
        return normalized * (1 + gamma) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples spatially by 2x (nearest neighbor).

    Both the shortcut and the residual branch are upsampled, so the output
    resolution is always twice the input; channels go dim_in -> dim_out.
    The branch sum is divided by sqrt(2) to keep unit variance.

    Args:
        dim_in: input channel count.
        dim_out: output channel count (a 1x1 conv shortcut is learned when
            it differs from dim_in).
        actv: activation module applied after each normalization.
        normalize: "in" (InstanceNorm2d) or "bn" (BatchNorm2d).

    Raises:
        ValueError: if `normalize` is not "in" or "bn".  (The original code
        deferred this to an AttributeError on the first forward call.)
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)  # rescales the two-branch sum to unit variance
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        else:
            # Fail fast: _residual uses self.norm1/self.norm2 unconditionally.
            raise ValueError(
                f"unsupported normalization '{self.normalize}' (expected 'in' or 'bn')")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        # Upsample first so both branches agree on spatial size.
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        out = self._residual(x)
        out = (out + self._shortcut(x)) / self.equal_var
        return out
class AdainResBlk(nn.Module):
    """Style-conditioned residual block.

    Each half of the residual path is AdaIN -> activation -> 3x3 conv,
    conditioned on the style code *s*; both branches optionally upsample
    by 2x (nearest neighbor), and their sum is rescaled by 1/sqrt(2).
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out  # channel change => learned 1x1 shortcut
        self.equal_var = math.sqrt(2)        # keeps the branch sum at unit variance
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _upsample(self, x):
        # Shared nearest-neighbor 2x upsampling used by both branches.
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def _shortcut(self, x):
        if self.upsample:
            x = self._upsample(x)
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        x = self.actv(self.norm1(x, s))
        if self.upsample:
            x = self._upsample(x)
        x = self.conv1(x)
        x = self.actv(self.norm2(x, s))
        return self.conv2(x)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """U-Net style face generator with identity-conditioned (AdaIN) decoding.

    Encodes the input image with plain residual blocks, modulates the
    bottleneck and upper decoder stages with the identity embedding, and
    predicts two soft blending masks: a low-resolution mask that blends
    decoder features with the encoder skip, and a high-resolution mask
    that blends the synthesized RGB with the source image.
    """
    def __init__(
        self,
        **kwargs
    ):
        """
        Expected kwargs:
            id_dim (int): dimension of the identity embedding.
            g_kernel_size (int): generator kernel size (only used to derive
                a padding size that this architecture does not consume).
            res_num (int): number of AdaIN bottleneck blocks.
            in_channel (int): base channel width.
            norm (str): "in" or "bn" normalization for encoder/decoder.
            up_mode, aggregator, res_mode: read so missing config keys fail
                early; currently unused by this architecture.

        Raises:
            ValueError: if `norm` is neither "in" nor "bn".
        """
        super().__init__()
        id_dim      = kwargs["id_dim"]
        k_size      = kwargs["g_kernel_size"]
        res_num     = kwargs["res_num"]
        in_channel  = kwargs["in_channel"]
        up_mode     = kwargs["up_mode"]     # unused; kept to validate config
        norm        = kwargs["norm"].lower()
        aggregator  = kwargs["aggregator"]  # unused; kept to validate config
        res_mode    = kwargs["res_mode"]    # unused; kept to validate config
        padding_size = (k_size - 1) // 2    # unused; kept for reference
        activation  = nn.LeakyReLU(0.2)
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        ### downsample (trailing numbers: spatial size for a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16
        ### identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128
        # Low-resolution mask head: blends decoder features with the skip.
        self.maskhead_lr = nn.Sequential(
            nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel),
            activation,
            nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )
        # The last two decoder stages are identity-agnostic residual upsamplers.
        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)
        # High-resolution mask head: blends the RGB output with the input.
        self.maskhead_hr = nn.Sequential(
            nn.Conv2d(in_channel, in_channel//8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//8),
            activation,
            nn.Conv2d(in_channel//8, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )
        if norm == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        else:
            # Previously fell through to a NameError on norm_out; fail clearly.
            raise ValueError(f"unsupported norm mode: {kwargs['norm']!r}")
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Return (generated image, low-res mask, high-res mask)."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        mask = self.maskhead_lr(res)
        # Soft blend between decoder features and encoder skip features.
        res = (1 - mask) * skip + mask * res
        res = self.up2(res)
        res = self.up1(res)
        mask_hr = self.maskhead_hr(res)
        res = self.to_rgb(res)
        # Soft blend between the synthesized image and the source image.
        res = (1 - mask_hr) * img + mask_hr * res
        return res, mask, mask_hr
+276
View File
@@ -0,0 +1,276 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 29th March 2022 12:02:53 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
import torch.nn.functional as F
import math
from components.LSTU import LSTU
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool
    downsampling; branch sum is rescaled by 1/sqrt(2)."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        """
        Args:
            dim_in (int): number of input channels.
            dim_out (int): number of output channels.
            actv (nn.Module): activation applied after each normalization.
            normalize (str): "in" for InstanceNorm, "bn" for BatchNorm.
            downsample (bool): if True, average-pool both branches by 2.

        Raises:
            ValueError: if `normalize` is neither "in" nor "bn".
        """
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        # 1x1 shortcut projection only needed when channel counts differ.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Both convs preserve spatial size; downsampling uses avg_pool2d.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        else:
            # The old `if self.normalize:` guard in _residual() was truthy
            # for ANY non-empty string, so an unsupported mode crashed with
            # an AttributeError at forward time. Fail fast instead.
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive instance normalization: scale and shift the instance-
    normalized feature map with an affine transform predicted from a
    style vector."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Parameter-free IN; the modulation below supplies scale and shift.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # Predicts per-channel (gamma, beta) from the style code.
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        params = self.fc(s).unsqueeze(-1).unsqueeze(-1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1.0 + gamma) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples its input by 2x (nearest neighbor).

    Both branches are upsampled; their sum is divided by sqrt(2) so the
    output variance stays close to the input variance.
    """
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        """
        Args:
            dim_in (int): number of input channels.
            dim_out (int): number of output channels.
            actv (nn.Module): activation applied after each normalization.
            normalize (str): "in" for InstanceNorm, "bn" for BatchNorm.

        Raises:
            ValueError: if `normalize` is neither "in" nor "bn".
        """
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        # 1x1 shortcut projection only needed when channel counts differ.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        else:
            # Fail fast: _residual() would otherwise hit an AttributeError
            # on self.norm1 the first time forward() runs.
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        out = self._residual(x)
        # Rescale the branch sum to approximately preserve unit variance.
        out = (out + self._shortcut(x)) / self.equal_var
        return out
class AdainResBlk(nn.Module):
    """Residual block whose normalization layers are style-conditioned
    (AdaIN), with optional 2x nearest-neighbor upsampling."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        # 1x1 shortcut projection only when the channel count changes.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.actv(self.norm2(self.conv1(h), s))
        return self.conv2(h)

    def forward(self, x, s):
        # Sum of branches, rescaled to approximately preserve unit variance.
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """U-Net style face generator whose encoder skip is fused with the
    decoder features by an LSTU module; decoding is identity-conditioned
    via AdaIN residual blocks.

    Returns the synthesized image together with the LSTU blending mask.
    """
    def __init__(
        self,
        **kwargs
    ):
        """
        Expected kwargs:
            id_dim (int): dimension of the identity embedding.
            g_kernel_size (int): generator kernel size (only used to derive
                a padding size that this architecture does not consume).
            res_num (int): number of AdaIN bottleneck blocks.
            in_channel (int): base channel width.
            norm (str): "in" or "bn" normalization.
            up_mode, aggregator, res_mode: read so missing config keys fail
                early; currently unused by this architecture.

        Raises:
            ValueError: if `norm` is neither "in" nor "bn".
        """
        super().__init__()
        id_dim      = kwargs["id_dim"]
        k_size      = kwargs["g_kernel_size"]
        res_num     = kwargs["res_num"]
        in_channel  = kwargs["in_channel"]
        up_mode     = kwargs["up_mode"]     # unused; kept to validate config
        norm        = kwargs["norm"]
        aggregator  = kwargs["aggregator"]  # unused; kept to validate config
        res_mode    = kwargs["res_mode"]    # unused; kept to validate config
        padding_size = (k_size - 1) // 2    # unused; kept for reference
        activation  = nn.LeakyReLU(0.2)
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        ### downsample (trailing numbers: spatial size for a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16
        ### identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128
        # LSTU fuses the encoder skip with the decoder features and also
        # produces the blending mask returned by forward().
        self.lstu = LSTU(in_channel*2, norm)
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)  # 256
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)    # 512
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        else:
            # Previously fell through to a NameError on norm_out; fail clearly.
            raise ValueError(f"unsupported norm mode: {norm!r}")
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Return (generated image, LSTU blending mask)."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        # Fuse decoder features with the encoder skip connection.
        res, mask = self.lstu(skip, res)
        res = self.up2(res, id)
        res = self.up1(res, id)
        res = self.to_rgb(res)
        return res, mask
+285
View File
@@ -0,0 +1,285 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 29th March 2022 1:08:05 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool
    downsampling; branch sum is rescaled by 1/sqrt(2)."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        """
        Args:
            dim_in (int): number of input channels.
            dim_out (int): number of output channels.
            actv (nn.Module): activation applied after each normalization.
            normalize (str): "in" for InstanceNorm, "bn" for BatchNorm.
            downsample (bool): if True, average-pool both branches by 2.

        Raises:
            ValueError: if `normalize` is neither "in" nor "bn".
        """
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        # 1x1 shortcut projection only needed when channel counts differ.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Both convs preserve spatial size; downsampling uses avg_pool2d.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        else:
            # The old `if self.normalize:` guard in _residual() was truthy
            # for ANY non-empty string, so an unsupported mode crashed with
            # an AttributeError at forward time. Fail fast instead.
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive instance normalization: scale and shift the instance-
    normalized feature map with an affine transform predicted from a
    style vector."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Parameter-free IN; the modulation below supplies scale and shift.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # Predicts per-channel (gamma, beta) from the style code.
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        params = self.fc(s).unsqueeze(-1).unsqueeze(-1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1.0 + gamma) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples its input by 2x (nearest neighbor).

    Both branches are upsampled; their sum is divided by sqrt(2) so the
    output variance stays close to the input variance.
    """
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        """
        Args:
            dim_in (int): number of input channels.
            dim_out (int): number of output channels.
            actv (nn.Module): activation applied after each normalization.
            normalize (str): "in" for InstanceNorm, "bn" for BatchNorm.

        Raises:
            ValueError: if `normalize` is neither "in" nor "bn".
        """
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        # 1x1 shortcut projection only needed when channel counts differ.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        else:
            # Fail fast: _residual() would otherwise hit an AttributeError
            # on self.norm1 the first time forward() runs.
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        out = self._residual(x)
        # Rescale the branch sum to approximately preserve unit variance.
        out = (out + self._shortcut(x)) / self.equal_var
        return out
class AdainResBlk(nn.Module):
    """Residual block whose normalization layers are style-conditioned
    (AdaIN), with optional 2x nearest-neighbor upsampling."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        # 1x1 shortcut projection only when the channel count changes.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.actv(self.norm2(self.conv1(h), s))
        return self.conv2(h)

    def forward(self, x, s):
        # Sum of branches, rescaled to approximately preserve unit variance.
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """U-Net style face generator whose skip-fusion (LSTU) module is chosen
    dynamically from the config; decoding is identity-conditioned via AdaIN
    residual blocks. Returns only the synthesized image (the LSTU mask is
    discarded).
    """
    def __init__(
        self,
        **kwargs
    ):
        """
        Expected kwargs:
            id_dim (int): dimension of the identity embedding.
            g_kernel_size (int): generator kernel size (only used to derive
                a padding size that this architecture does not consume).
            res_num (int): number of AdaIN bottleneck blocks.
            in_channel (int): base channel width.
            norm (str): "in" or "bn" normalization.
            lstu_script (str): module name under `components` holding LSTU.
            lstu_class (str): class name to load from that module.
            up_mode, aggregator, res_mode: read so missing config keys fail
                early; currently unused by this architecture.

        Raises:
            ValueError: if `norm` is neither "in" nor "bn".
        """
        super().__init__()
        import importlib  # local: only needed while building this model
        id_dim      = kwargs["id_dim"]
        k_size      = kwargs["g_kernel_size"]
        res_num     = kwargs["res_num"]
        in_channel  = kwargs["in_channel"]
        up_mode     = kwargs["up_mode"]     # unused; kept to validate config
        norm        = kwargs["norm"]
        lstu_script = kwargs["lstu_script"]
        lstu_class  = kwargs["lstu_class"]
        aggregator  = kwargs["aggregator"]  # unused; kept to validate config
        res_mode    = kwargs["res_mode"]    # unused; kept to validate config
        padding_size = (k_size - 1) // 2    # unused; kept for reference
        # importlib.import_module replaces the old
        # `__import__(name, fromlist=True)` hack (fromlist should be a
        # sequence of names, not a bool).
        package = importlib.import_module("components." + lstu_script)
        lstu_class = getattr(package, lstu_class)
        activation = nn.LeakyReLU(0.2)
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        ### downsample (trailing numbers: spatial size for a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16
        ### identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128
        # Dynamically selected LSTU variant fuses skip and decoder features.
        self.lstu = lstu_class(in_channel*2, norm)
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)  # 256
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)    # 512
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        else:
            # Previously fell through to a NameError on norm_out; fail clearly.
            raise ValueError(f"unsupported norm mode: {norm!r}")
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Return the generated image (the LSTU mask is computed but not returned)."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        # Fuse decoder features with the encoder skip connection.
        res, mask = self.lstu(skip, res)
        res = self.up2(res, id)
        res = self.up1(res, id)
        res = self.to_rgb(res)
        return res
+285
View File
@@ -0,0 +1,285 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 30th March 2022 4:14:27 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool
    downsampling; branch sum is rescaled by 1/sqrt(2)."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        """
        Args:
            dim_in (int): number of input channels.
            dim_out (int): number of output channels.
            actv (nn.Module): activation applied after each normalization.
            normalize (str): "in" for InstanceNorm, "bn" for BatchNorm.
            downsample (bool): if True, average-pool both branches by 2.

        Raises:
            ValueError: if `normalize` is neither "in" nor "bn".
        """
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        # 1x1 shortcut projection only needed when channel counts differ.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Both convs preserve spatial size; downsampling uses avg_pool2d.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        else:
            # The old `if self.normalize:` guard in _residual() was truthy
            # for ANY non-empty string, so an unsupported mode crashed with
            # an AttributeError at forward time. Fail fast instead.
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive instance normalization: scale and shift the instance-
    normalized feature map with an affine transform predicted from a
    style vector."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Parameter-free IN; the modulation below supplies scale and shift.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # Predicts per-channel (gamma, beta) from the style code.
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        params = self.fc(s).unsqueeze(-1).unsqueeze(-1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1.0 + gamma) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples its input by 2x (nearest neighbor).

    Both branches are upsampled; their sum is divided by sqrt(2) so the
    output variance stays close to the input variance.
    """
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        """
        Args:
            dim_in (int): number of input channels.
            dim_out (int): number of output channels.
            actv (nn.Module): activation applied after each normalization.
            normalize (str): "in" for InstanceNorm, "bn" for BatchNorm.

        Raises:
            ValueError: if `normalize` is neither "in" nor "bn".
        """
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        # 1x1 shortcut projection only needed when channel counts differ.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm = self.normalize.lower()
        if norm == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        else:
            # Fail fast: _residual() would otherwise hit an AttributeError
            # on self.norm1 the first time forward() runs.
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        out = self._residual(x)
        # Rescale the branch sum to approximately preserve unit variance.
        out = (out + self._shortcut(x)) / self.equal_var
        return out
class AdainResBlk(nn.Module):
    """Residual block whose normalization layers are style-conditioned
    (AdaIN), with optional 2x nearest-neighbor upsampling."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        # 1x1 shortcut projection only when the channel count changes.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.actv(self.norm2(self.conv1(h), s))
        return self.conv2(h)

    def forward(self, x, s):
        # Sum of branches, rescaled to approximately preserve unit variance.
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """U-Net style face generator with identity-conditioned (AdaIN) decoding.

    A skip-fusion (LSTU) module is loaded dynamically from the config and
    instantiated, but its call in forward() is currently disabled, so the
    decoder ignores the encoder skip connection.
    """
    def __init__(
        self,
        **kwargs
    ):
        """
        Expected kwargs:
            id_dim (int): dimension of the identity embedding.
            g_kernel_size (int): generator kernel size (only used to derive
                a padding size that this architecture does not consume).
            res_num (int): number of AdaIN bottleneck blocks.
            in_channel (int): base channel width.
            norm (str): "in" or "bn" normalization.
            lstu_script (str): module name under `components` holding LSTU.
            lstu_class (str): class name to load from that module.
            up_mode, aggregator, res_mode: read so missing config keys fail
                early; currently unused by this architecture.

        Raises:
            ValueError: if `norm` is neither "in" nor "bn".
        """
        super().__init__()
        import importlib  # local: only needed while building this model
        id_dim      = kwargs["id_dim"]
        k_size      = kwargs["g_kernel_size"]
        res_num     = kwargs["res_num"]
        in_channel  = kwargs["in_channel"]
        up_mode     = kwargs["up_mode"]     # unused; kept to validate config
        norm        = kwargs["norm"]
        lstu_script = kwargs["lstu_script"]
        lstu_class  = kwargs["lstu_class"]
        aggregator  = kwargs["aggregator"]  # unused; kept to validate config
        res_mode    = kwargs["res_mode"]    # unused; kept to validate config
        padding_size = (k_size - 1) // 2    # unused; kept for reference
        # importlib.import_module replaces the old
        # `__import__(name, fromlist=True)` hack (fromlist should be a
        # sequence of names, not a bool).
        package = importlib.import_module("components." + lstu_script)
        lstu_class = getattr(package, lstu_class)
        activation = nn.LeakyReLU(0.2)
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        ### downsample (trailing numbers: spatial size for a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16
        ### identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128
        # NOTE(review): lstu is built but its call in forward() is disabled;
        # presumably kept so checkpoints keep a matching state_dict — confirm.
        self.lstu = lstu_class(in_channel*2, norm)
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)  # 256
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)    # 512
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        else:
            # Previously fell through to a NameError on norm_out; fail clearly.
            raise ValueError(f"unsupported norm mode: {norm!r}")
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Return the generated image (skip fusion currently disabled)."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        # res, mask = self.lstu(skip, res)  # skip fusion intentionally disabled
        res = self.up2(res, id)
        res = self.up1(res, id)
        res = self.to_rgb(res)
        return res
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 6th April 2022 12:55:51 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class AdaIN(nn.Module):
    """Adaptive instance normalization: scale and shift the instance-
    normalized feature map with an affine transform predicted from a
    style vector."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Parameter-free IN; the modulation below supplies scale and shift.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # Predicts per-channel (gamma, beta) from the style code.
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        params = self.fc(s).unsqueeze(-1).unsqueeze(-1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1.0 + gamma) + beta
class AdainResBlk(nn.Module):
    """Residual block whose normalization layers are style-conditioned
    (AdaIN), with optional 2x nearest-neighbor upsampling."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        # 1x1 shortcut projection only when the channel count changes.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.actv(self.norm2(self.conv1(h), s))
        return self.conv2(h)

    def forward(self, x, s):
        # Sum of branches, rescaled to approximately preserve unit variance.
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class AdainUpBlock(nn.Module):
    """Plain (non-residual) decoder stage: nearest 2x upsample followed by
    conv -> AdaIN -> activation."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.actv = actv
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.norm = AdaIN(style_dim, dim_out)

    def forward(self, x, s):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.actv(self.norm(self.conv(up), s))
class Generator(nn.Module):
    """Encoder/decoder face generator with a strided-conv encoder, an
    identity-conditioned (AdaIN) bottleneck and decoder, and a mask head
    that predicts a soft blending mask from the deepest encoder features.

    Returns the synthesized image and the blending mask.
    """
    def __init__(
        self,
        **kwargs
    ):
        """
        Expected kwargs:
            id_dim (int): dimension of the identity embedding.
            g_kernel_size (int): generator kernel size (only used to derive
                a padding size that this architecture does not consume).
            res_num (int): number of AdaIN bottleneck blocks.
            in_channel (int): base channel width.
            norm, up_mode, aggregator, res_mode: read so missing config
                keys fail early; this variant hard-codes BatchNorm in the
                encoder and does not consume these values.
        """
        super().__init__()
        id_dim      = kwargs["id_dim"]
        k_size      = kwargs["g_kernel_size"]
        res_num     = kwargs["res_num"]
        in_channel  = kwargs["in_channel"]
        up_mode     = kwargs["up_mode"]     # unused; kept to validate config
        norm        = kwargs["norm"]        # unused; encoder hard-codes BatchNorm
        aggregator  = kwargs["aggregator"]  # unused; kept to validate config
        res_mode    = kwargs["res_mode"]    # unused; kept to validate config
        padding_size = (k_size - 1) // 2    # unused; kept for reference
        activation  = nn.LeakyReLU(0.2)
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        ### downsample (trailing numbers: spatial size for a 512x512 input)
        self.down1 = nn.Sequential(nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=2, padding=1, bias=False),      # 256
                                   nn.BatchNorm2d(in_channel), activation)
        self.down2 = nn.Sequential(nn.Conv2d(in_channel, in_channel*2, kernel_size=3, stride=2, padding=1, bias=False),    # 128
                                   nn.BatchNorm2d(in_channel*2), activation)
        self.down3 = nn.Sequential(nn.Conv2d(in_channel*2, in_channel*4, kernel_size=3, stride=2, padding=1, bias=False),  # 64
                                   nn.BatchNorm2d(in_channel*4), activation)
        self.down4 = nn.Sequential(nn.Conv2d(in_channel*4, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False),  # 32
                                   nn.BatchNorm2d(in_channel*8), activation)
        self.down5 = nn.Sequential(nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False),  # 16
                                   nn.BatchNorm2d(in_channel*8), activation)
        # Mask head: upsamples the deepest encoder features by 8x to match
        # the skip resolution and emits a 1-channel sigmoid blending mask.
        self.maskhead = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel),
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//2),
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Sigmoid()
        )
        ### identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        self.up5 = AdainUpBlock(in_channel*8, in_channel*8, style_dim=id_dim)  # 32
        self.up4 = AdainUpBlock(in_channel*8, in_channel*4, style_dim=id_dim)  # 64
        self.up3 = AdainUpBlock(in_channel*4, in_channel*2, style_dim=id_dim)  # 128
        self.up2 = AdainUpBlock(in_channel*2, in_channel, style_dim=id_dim)    # 256
        self.up1 = AdainUpBlock(in_channel, in_channel, style_dim=id_dim)      # 512
        self.to_rgb = nn.Sequential(nn.ReflectionPad2d(1),
                                    nn.Conv2d(in_channel, 3, kernel_size=3, padding=0))

    def forward(self, img, id):
        """Return (generated image, blending mask)."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        # Mask is predicted from the deepest encoder features (pre-bottleneck).
        mask = self.maskhead(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        # Soft blend between decoder features and encoder skip features.
        res = (1 - mask) * skip + mask * res
        res = self.up2(res, id)
        res = self.up1(res, id)
        res = self.to_rgb(res)
        return res, mask
+298
View File
@@ -0,0 +1,298 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 2nd April 2022 1:27:23 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    The output is (shortcut + residual) / sqrt(2) to keep roughly unit variance.
    normalize selects "in" (InstanceNorm) or "bn" (BatchNorm); a falsy value
    skips normalization entirely.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out      # 1x1 conv needed when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Module creation order is kept stable so seeded initialization and
        # state_dict keys match the original layout.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = self.conv1x1(x) if self.learned_sc else x
        return F.avg_pool2d(out, 2) if self.downsample else out

    def _residual(self, x):
        out = x
        if self.normalize:
            out = self.norm1(out)
        out = self.conv1(self.actv(out))
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        if self.normalize:
            out = self.norm2(out)
        out = self.conv2(self.actv(out))
        return out

    def forward(self, x):
        # Unit-variance merge of the two branches.
        return (self._shortcut(x) + self._residual(x)) / self.equal_var
class AdaIN(nn.Module):
    """Adaptive instance normalization.

    Normalizes x per-instance (no affine), then scales/shifts it with a
    gamma/beta pair predicted from the style vector s:
    (1 + gamma) * IN(x) + beta.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples by 2x (nearest) on both branches.

    Output is (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Creation order kept stable for reproducible initialization.
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x):
        out = self.actv(self.norm1(x))
        out = self.conv1(F.interpolate(out, scale_factor=2, mode='nearest'))
        out = self.conv2(self.actv(self.norm2(out)))
        return out

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
class AdainResBlk(nn.Module):
    """Residual block whose normalization is AdaIN-conditioned on a style code.

    Optionally upsamples both branches by 2x (nearest). Output is
    (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        # Creation order kept stable for reproducible initialization.
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.conv2(self.actv(self.norm2(out, s)))
        return out

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """Face generator: ResBlk encoder, AdaIN-conditioned decoder, LSTU skip fusion.

    Required kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm, lstu_script, lstu_class, aggregator, res_mode.

    forward(img, id, feat_out=False) returns the RGB output, or the
    bottleneck feature map when feat_out is True.
    """

    def __init__(self, **kwargs):
        super().__init__()
        id_dim = kwargs["id_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        in_channel = kwargs["in_channel"]
        up_mode = kwargs["up_mode"]          # read for config compatibility; unused below
        norm = kwargs["norm"]
        lstu_script = kwargs["lstu_script"]
        lstu_class = kwargs["lstu_class"]
        aggregator = kwargs["aggregator"]    # read for config compatibility; unused below
        res_mode = kwargs["res_mode"]        # read for config compatibility; unused below
        padding_size = int((k_size - 1) / 2)
        padding_type = 'reflect'
        # Resolve the LSTU fusion block dynamically from the components package.
        package = __import__("components." + lstu_script, fromlist=True)
        lstu_class = getattr(package, lstu_class)
        activation = nn.LeakyReLU(0.2)
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        ### downsample (each stage halves the spatial size; comments assume 512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)    # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)  # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # NOTE: maskhead is constructed but never called by forward() in this
        # variant (the LSTU block produces the mask instead). Channel counts
        # fixed below; loading checkpoints saved with the old (mismatched)
        # shapes may require strict=False.
        self.maskhead = nn.Sequential(
            nn.ConvTranspose2d(in_channel*8, in_channel, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(in_channel),     # 32
            activation,
            nn.ConvTranspose2d(in_channel, in_channel//2, kernel_size=4, stride=2, padding=1, bias=False),
            # BUG FIX: the previous deconv outputs in_channel//2 channels
            # (was nn.BatchNorm2d(in_channel), a channel mismatch).
            nn.BatchNorm2d(in_channel//2),  # 64
            activation,
            nn.ConvTranspose2d(in_channel//2, 1, kernel_size=4, stride=2, padding=1, bias=False),
            # BUG FIX: the mask is single-channel (was nn.BatchNorm2d(in_channel)).
            nn.BatchNorm2d(1),              # 128
            nn.Sigmoid()
        )
        ### bottleneck: res_num AdaIN residual blocks conditioned on the id code
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        ### upsampling decoder
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        self.lstu = lstu_class(in_channel*2, norm)
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        else:
            # BUG FIX: previously fell through and raised NameError on norm_out.
            raise ValueError("unsupported norm type: %s" % norm)
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id, feat_out=False):
        """Generate an image conditioned on the identity code.

        img: (B, 3, H, W) source image; id: (B, id_dim) identity code.
        Returns the RGB output tensor, or the bottleneck feature map when
        feat_out is True.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)     # encoder feature kept for LSTU fusion
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        if feat_out:
            return res
        for blk in self.BottleNeck:
            res = blk(res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        res, mask = self.lstu(skip, res)   # fuse encoder skip with decoded features
        res = self.up2(res, id)
        res = self.up1(res, id)
        res = self.to_rgb(res)
        return res
@@ -0,0 +1,280 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 3rd April 2022 1:06:31 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    The output is (shortcut + residual) / sqrt(2) to keep roughly unit variance.
    normalize selects "in" (InstanceNorm) or "bn" (BatchNorm); a falsy value
    skips normalization entirely.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out      # 1x1 conv needed when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Module creation order is kept stable so seeded initialization and
        # state_dict keys match the original layout.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = self.conv1x1(x) if self.learned_sc else x
        return F.avg_pool2d(out, 2) if self.downsample else out

    def _residual(self, x):
        out = x
        if self.normalize:
            out = self.norm1(out)
        out = self.conv1(self.actv(out))
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        if self.normalize:
            out = self.norm2(out)
        out = self.conv2(self.actv(out))
        return out

    def forward(self, x):
        # Unit-variance merge of the two branches.
        return (self._shortcut(x) + self._residual(x)) / self.equal_var
class AdaIN(nn.Module):
    """Adaptive instance normalization.

    Normalizes x per-instance (no affine), then scales/shifts it with a
    gamma/beta pair predicted from the style vector s:
    (1 + gamma) * IN(x) + beta.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples by 2x (nearest) on both branches.

    Output is (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Creation order kept stable for reproducible initialization.
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x):
        out = self.actv(self.norm1(x))
        out = self.conv1(F.interpolate(out, scale_factor=2, mode='nearest'))
        out = self.conv2(self.actv(self.norm2(out)))
        return out

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
class AdainResBlk(nn.Module):
    """Residual block whose normalization is AdaIN-conditioned on a style code.

    Optionally upsamples both branches by 2x (nearest). Output is
    (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        # Creation order kept stable for reproducible initialization.
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.conv2(self.actv(self.norm2(out, s)))
        return out

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """Face generator with a learned blending mask.

    Encoder: five ResBlk downsampling stages. A mask head predicts a
    single-channel sigmoid mask from the bottleneck feature (upsampled 8x
    back to the down2 resolution); the decoder blends the down2 encoder
    skip with the decoded feature using that mask before the last two
    upsampling stages.

    Required kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm, aggregator, res_mode.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        in_channel = kwargs["in_channel"]
        up_mode = kwargs["up_mode"]        # NOTE(review): read but unused below
        norm = kwargs["norm"]
        aggregator = kwargs["aggregator"]  # NOTE(review): read but unused below
        res_mode = kwargs["res_mode"]      # NOTE(review): read but unused below
        padding_size= int((k_size -1)/2)   # NOTE(review): unused below
        padding_type= 'reflect'            # NOTE(review): unused below
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                             nn.BatchNorm2d(64), activation)
        ### downsample (spatial-size comments presumably assume a 512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        # Predicts the single-channel blending mask from the bottleneck
        # feature; three nearest-upsample + conv stages (8x total).
        self.maskhead = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
            nn.BatchNorm2d(in_channel), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//2), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Sigmoid()
        )
        ### resnet blocks (AdaIN bottleneck conditioned on the id code)
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        # NOTE(review): if norm is neither "in" nor "bn", norm_out stays
        # unbound and the nn.Sequential below raises NameError.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                             nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    def forward(self, img, id):
        """Return (rgb, mask): the generated image and the blending mask.

        img: (B, 3, H, W) input image; id: (B, id_dim) identity/style code.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # encoder feature reused via mask blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask= self.maskhead(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        # Blend the encoder skip with the decoded feature using the mask.
        res = (1-mask) * skip + mask * res
        res = self.up2(res,id) # + skip
        res = self.up1(res,id)
        res = self.to_rgb(res)
        return res, mask
+280
View File
@@ -0,0 +1,280 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 3rd April 2022 1:06:31 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    The output is (shortcut + residual) / sqrt(2) to keep roughly unit variance.
    normalize selects "in" (InstanceNorm) or "bn" (BatchNorm); a falsy value
    skips normalization entirely.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out      # 1x1 conv needed when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Module creation order is kept stable so seeded initialization and
        # state_dict keys match the original layout.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = self.conv1x1(x) if self.learned_sc else x
        return F.avg_pool2d(out, 2) if self.downsample else out

    def _residual(self, x):
        out = x
        if self.normalize:
            out = self.norm1(out)
        out = self.conv1(self.actv(out))
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        if self.normalize:
            out = self.norm2(out)
        out = self.conv2(self.actv(out))
        return out

    def forward(self, x):
        # Unit-variance merge of the two branches.
        return (self._shortcut(x) + self._residual(x)) / self.equal_var
class AdaIN(nn.Module):
    """Adaptive instance normalization.

    Normalizes x per-instance (no affine), then scales/shifts it with a
    gamma/beta pair predicted from the style vector s:
    (1 + gamma) * IN(x) + beta.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples by 2x (nearest) on both branches.

    Output is (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Creation order kept stable for reproducible initialization.
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x):
        out = self.actv(self.norm1(x))
        out = self.conv1(F.interpolate(out, scale_factor=2, mode='nearest'))
        out = self.conv2(self.actv(self.norm2(out)))
        return out

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
class AdainResBlk(nn.Module):
    """Residual block whose normalization is AdaIN-conditioned on a style code.

    Optionally upsamples both branches by 2x (nearest). Output is
    (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        # Creation order kept stable for reproducible initialization.
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.conv2(self.actv(self.norm2(out, s)))
        return out

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """Face generator with a learned blending mask.

    Encoder: five ResBlk downsampling stages. A mask head predicts a
    single-channel sigmoid mask from the bottleneck feature (upsampled 8x
    back to the down2 resolution); the decoder blends the down2 encoder
    skip with the decoded feature using that mask before the last two
    upsampling stages.

    Required kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm, aggregator, res_mode.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        in_channel = kwargs["in_channel"]
        up_mode = kwargs["up_mode"]        # NOTE(review): read but unused below
        norm = kwargs["norm"]
        aggregator = kwargs["aggregator"]  # NOTE(review): read but unused below
        res_mode = kwargs["res_mode"]      # NOTE(review): read but unused below
        padding_size= int((k_size -1)/2)   # NOTE(review): unused below
        padding_type= 'reflect'            # NOTE(review): unused below
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                             nn.BatchNorm2d(64), activation)
        ### downsample (spatial-size comments presumably assume a 512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        # Predicts the single-channel blending mask from the bottleneck
        # feature; three nearest-upsample + conv stages (8x total).
        self.maskhead = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
            nn.BatchNorm2d(in_channel), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//2), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Sigmoid()
        )
        ### resnet blocks (AdaIN bottleneck conditioned on the id code)
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        # NOTE(review): if norm is neither "in" nor "bn", norm_out stays
        # unbound and the nn.Sequential below raises NameError.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                             nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    def forward(self, img, id):
        """Return (rgb, mask): the generated image and the blending mask.

        img: (B, 3, H, W) input image; id: (B, id_dim) identity/style code.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # encoder feature reused via mask blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask= self.maskhead(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        # Blend the encoder skip with the decoded feature using the mask.
        res = (1-mask) * skip + mask * res
        res = self.up2(res,id) # + skip
        res = self.up1(res,id)
        res = self.to_rgb(res)
        return res, mask
+279
View File
@@ -0,0 +1,279 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 6th April 2022 8:38:50 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    The output is (shortcut + residual) / sqrt(2) to keep roughly unit variance.
    normalize selects "in" (InstanceNorm) or "bn" (BatchNorm); a falsy value
    skips normalization entirely.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="bn", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out      # 1x1 conv needed when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Module creation order is kept stable so seeded initialization and
        # state_dict keys match the original layout.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = self.conv1x1(x) if self.learned_sc else x
        return F.avg_pool2d(out, 2) if self.downsample else out

    def _residual(self, x):
        out = x
        if self.normalize:
            out = self.norm1(out)
        out = self.conv1(self.actv(out))
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        if self.normalize:
            out = self.norm2(out)
        out = self.conv2(self.actv(out))
        return out

    def forward(self, x):
        # Unit-variance merge of the two branches.
        return (self._shortcut(x) + self._residual(x)) / self.equal_var
class AdaIN(nn.Module):
    """Adaptive instance normalization.

    Normalizes x per-instance (no affine), then scales/shifts it with a
    gamma/beta pair predicted from the style vector s:
    (1 + gamma) * IN(x) + beta.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples by 2x (nearest) on both branches.

    Output is (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Creation order kept stable for reproducible initialization.
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x):
        out = self.actv(self.norm1(x))
        out = self.conv1(F.interpolate(out, scale_factor=2, mode='nearest'))
        out = self.conv2(self.actv(self.norm2(out)))
        return out

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
class AdainResBlk(nn.Module):
    """Residual block whose normalization is AdaIN-conditioned on a style code.

    Optionally upsamples both branches by 2x (nearest). Output is
    (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        # Creation order kept stable for reproducible initialization.
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.conv2(self.actv(self.norm2(out, s)))
        return out

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """Face generator variant: mask predicted late, skip transformed by sigma.

    Unlike sibling variants, the mask head here runs on the decoder feature
    after up3 (same width as the down2 skip), and the skip branch is first
    passed through an extra ResBlk (`sigma`) before mask blending.

    Required kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm, aggregator, res_mode.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        in_channel = kwargs["in_channel"]
        up_mode = kwargs["up_mode"]        # NOTE(review): read but unused below
        norm = kwargs["norm"]
        aggregator = kwargs["aggregator"]  # NOTE(review): read but unused below
        res_mode = kwargs["res_mode"]      # NOTE(review): read but unused below
        padding_size= int((k_size -1)/2)   # NOTE(review): unused below
        padding_type= 'reflect'            # NOTE(review): unused below
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                             nn.BatchNorm2d(64), activation)
        ### downsample (spatial-size comments presumably assume a 512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
        # Extra residual transform applied to the skip branch before blending
        # (uses this file's ResBlk default normalization).
        self.sigma = ResBlk(in_channel*2,in_channel*2)
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        ### resnet blocks (AdaIN bottleneck conditioned on the id code)
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        # Predicts the single-channel blending mask from the up3 output.
        self.maskhead = nn.Sequential(
            nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel), # 64
            activation,
            nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        # NOTE(review): if norm is neither "in" nor "bn", norm_out stays
        # unbound and the nn.Sequential below raises NameError.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                             nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    def forward(self, img, id):
        """Return (rgb, mask): the generated image and the blending mask.

        img: (B, 3, H, W) input image; id: (B, id_dim) identity/style code.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # encoder feature reused via mask blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        mask= self.maskhead(res)
        # Blend the sigma-transformed skip with the decoded feature.
        res = (1-mask) * self.sigma(skip) + mask * res
        res = self.up2(res,id) # + skip
        res = self.up1(res,id)
        res = self.to_rgb(res)
        return res, mask
+283
View File
@@ -0,0 +1,283 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 13th April 2022 3:12:53 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    The residual and shortcut paths are summed and divided by sqrt(2) to keep
    the output at unit variance (StarGAN v2-style block).

    Args:
        dim_in:     number of input channels.
        dim_out:    number of output channels; a learned 1x1 conv is added to
                    the shortcut when it differs from ``dim_in``.
        actv:       activation module shared by both conv stages.
        normalize:  "in" (InstanceNorm), "bn" (BatchNorm) or "none".
        downsample: when True, both paths are average-pooled by 2.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="bn", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        # Divisor that renormalizes the two-path sum back to unit variance.
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        elif self.normalize.lower() == "none":
            # Consistent with the discriminator's ResBlk: disable the norm
            # layers entirely.  Previously "none" fell through, left
            # `self.normalize` truthy, and `_residual` crashed on the missing
            # `norm1`/`norm2` attributes.
            self.normalize = False
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
class AdaIN(nn.Module):
    """Adaptive instance normalization: a style vector supplies per-channel
    (gamma, beta) applied on top of parameter-free InstanceNorm statistics."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Project the style vector to per-channel scale/shift pairs.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
class ResUpBlk(nn.Module):
    """Residual block that upsamples by 2x (nearest) on both paths; the
    two-path sum is divided by sqrt(2) to preserve unit variance."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x):
        out = self.actv(self.norm1(x))
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out))
        return self.conv2(out)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
class AdainResBlk(nn.Module):
    """Residual block whose norm layers are AdaIN modules driven by a style
    code ``s``; optional 2x nearest upsampling on both paths."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)  # keeps the sum at unit variance
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out, s))
        return self.conv2(out)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """Mask-guided face generator (variant with BN-normalized final up blocks).

    Pipeline: RGB -> five downsampling ResBlks -> ``res_num`` identity-
    conditioned AdaIN bottleneck blocks -> three AdaIN up blocks -> a soft
    mask predicted at 1/4 scale blends the 1/4-scale encoder skip with the
    decoded features -> two identity-free up blocks -> RGB.

    Required kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm ("in" or "bn"), aggregator, res_mode.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]          # identity-embedding size for AdaIN conditioning
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]        # number of AdaIN bottleneck blocks
        in_channel = kwargs["in_channel"]  # base channel width
        up_mode = kwargs["up_mode"]        # NOTE(review): read but unused in this variant
        norm = kwargs["norm"].lower()
        aggregator = kwargs["aggregator"]  # NOTE(review): read but unused in this variant
        res_mode = kwargs["res_mode"]      # NOTE(review): read but unused in this variant
        padding_size= int((k_size -1)/2)   # NOTE(review): unused below
        padding_type= 'reflect'            # NOTE(review): unused below
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                     nn.BatchNorm2d(64), activation)
        ### downsample (trailing comments give the resolution for a 512px input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
        # self.sigma = ResBlk(in_channel*2,in_channel*2)
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        ### resnet blocks: identity-conditioned bottleneck at the deepest scale
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        # Predicts a single-channel blending mask from the 1/4-scale features.
        self.maskhead = nn.Sequential(
            nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel), # 64
            activation,
            nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
            )
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # Final two up blocks are identity-free residual upsamplers (BN).
        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize="bn")
        # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize="bn")
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                     nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, img, id):
        """Return (generated image, 1/4-scale blending mask).

        Args:
            img: input batch, presumably (B, 3, H, W) with H = W a power of
                 two -- TODO confirm against the data pipeline.
            id:  identity embedding of shape (B, id_dim).
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # 1/4-scale encoder features kept for blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        # Sequential cannot forward two arguments, hence the manual loop.
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        mask= self.maskhead(res)
        # res = (1-mask) * self.sigma(skip) + mask * res
        res = (1-mask) * skip + mask * res  # soft blend of skip and decoded features
        res = self.up2(res) # + skip
        res = self.up1(res)
        res = self.to_rgb(res)
        return res, mask
+297
View File
@@ -0,0 +1,297 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator_Invobn_config1.py
# Created Date: Saturday February 26th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 13th April 2022 6:30:26 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import torch
from torch import nn
import torch.nn.functional as F
import math
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x avg-pool downsampling.

    ``normalize`` selects "in" or "bn"; the two-path sum is divided by
    sqrt(2) to keep the output at unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)  # rescales the two-path sum
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        kind = self.normalize.lower()
        if kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = self.conv1x1(x) if self.learned_sc else x
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        return out

    def _residual(self, x):
        out = self.norm1(x) if self.normalize else x
        out = self.conv1(self.actv(out))
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        if self.normalize:
            out = self.norm2(out)
        return self.conv2(self.actv(out))

    def forward(self, x):
        return (self._shortcut(x) + self._residual(x)) / self.equal_var
class AdaIN(nn.Module):
    """Adaptive instance norm: the style vector supplies per-channel gain
    (1 + gamma) and bias beta on top of parameter-free InstanceNorm."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Broadcast the projected style over the spatial dimensions.
        gamma, beta = torch.chunk(self.fc(s)[:, :, None, None], chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
class ResUpBlk(nn.Module):
    """Residual block performing 2x nearest-neighbour upsampling on both the
    residual and the shortcut path; output rescaled to unit variance."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            up = self.conv1x1(up)
        return up

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = self.conv1(F.interpolate(h, scale_factor=2, mode='nearest'))
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
class AdainResBlk(nn.Module):
    """Residual block with AdaIN norms conditioned on a style code ``s``;
    optionally upsamples both paths by 2x (nearest)."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)  # unit-variance rescaling of the sum
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = x
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        if self.learned_sc:
            out = self.conv1x1(out)
        return out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out, s))
        return self.conv2(out)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
class Generator(nn.Module):
    """Mask-guided face generator; here the blending mask is predicted from
    the deepest (1/32) encoder features and upsampled x8 inside ``maskhead``
    to the 1/4 scale used for skip blending.

    Required kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm ("in" or "bn"), aggregator, res_mode.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        id_dim = kwargs["id_dim"]          # identity-embedding size for AdaIN conditioning
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]        # number of AdaIN bottleneck blocks
        in_channel = kwargs["in_channel"]  # base channel width
        up_mode = kwargs["up_mode"]        # NOTE(review): read but unused in this variant
        norm = kwargs["norm"].lower()
        aggregator = kwargs["aggregator"]  # NOTE(review): read but unused in this variant
        res_mode = kwargs["res_mode"]      # NOTE(review): read but unused in this variant
        padding_size= int((k_size -1)/2)   # NOTE(review): unused below
        padding_type= 'reflect'            # NOTE(review): unused below
        # Choose the norm flavour once for the output head (instance) and the
        # mask head (class, instantiated per layer below).
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
            norm_mask= nn.InstanceNorm2d
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
            norm_mask = nn.BatchNorm2d
        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()
        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        #                     nn.BatchNorm2d(64), activation)
        ### downsample (trailing comments give the resolution for a 512px input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128
        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16
        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
        ### resnet blocks: identity-conditioned bottleneck at the deepest scale
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)
        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128
        # (previous 1/4-scale mask head, kept for reference)
        # self.maskhead = nn.Sequential(
        #     nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
        #     norm_mask, # 64
        #     activation,
        #     nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
        #     nn.Sigmoid())
        # Mask head: consumes the deepest (1/32) features and upsamples x8 to
        # the 1/4 scale where the encoder skip is blended.
        self.maskhead = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel, affine=True), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//4, affine=True), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//4, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
            )
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # Final two up blocks are identity-free residual upsamplers.
        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
        # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512
        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))
        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                     nn.Conv2d(64, 3, kernel_size=7, padding=0))
        # self.__weights_init__()
    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)
    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, img, id):
        """Return (generated image, 1/4-scale blending mask).

        Args:
            img: input batch, presumably (B, 3, H, W) with H = W a power of
                 two -- TODO confirm against the data pipeline.
            id:  identity embedding of shape (B, id_dim).
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # 1/4-scale encoder features kept for blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask= self.maskhead(res)  # mask computed from PRE-bottleneck features
        # Sequential cannot forward two arguments, hence the manual loop.
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        # res = (1-mask) * self.sigma(skip) + mask * res
        res = (1-mask) * skip + mask * res  # soft blend of skip and decoded features
        res = self.up2(res) # + skip
        res = self.up1(res)
        res = self.to_rgb(res)
        return res, mask
+96 -24
View File
@@ -5,43 +5,115 @@
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 13th February 2022 2:03:21 am
# Last Modified: Monday, 28th March 2022 11:47:55 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import math
import torch
from torch import nn
import torch.nn.functional as F
# class LSTU(nn.Module):
# def __init__(
# self,
# in_channel,
# out_channel,
# latent_channel,
# scale = 4
# ):
# super().__init__()
# sig = nn.Sigmoid()
# self.relu = nn.ReLU(True)
# self.up_sample = nn.Sequential(nn.Conv2d(latent_channel, out_channel/4, kernel_size=3, stride=1, padding=1, bias=False),
# nn.BatchNorm2d(out_channel/4),
# self.relu,
# nn.Conv2d(latent_channel/4, out_channel, kernel_size=3, stride=1, padding=1),
# )
# self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(out_channel), sig)
# self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(out_channel), sig)
# self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True))
# def forward(self, encoder_in, bottleneck_in):
# h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel`
# h_bar_l = self.conv11(h_hat_l_1)
# f_l = self.forget_gate(h_hat_l_1)
# r_l = self.reset_gate (h_hat_l_1)
# h_hat_l = (1-f_l)*h_bar_l + f_l* encoder_in
# x_hat_l = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1
# return x_hat_l
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x avg-pool downsampling.

    ``normalize`` selects "in" or "bn"; the two-path sum is divided by
    sqrt(2) to keep the output at unit variance.  (This copy sits inside a
    partially rendered diff hunk; code left byte-untouched.)
    """
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        # Divisor that renormalizes the two-path sum back to unit variance.
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)
    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            # 1x1 conv adapts the shortcut when channel counts differ.
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x
    def _residual(self, x):
        # Pre-activation ordering: norm -> activation -> conv.
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x
    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x /self.equal_var # unit variance
class LSTU(nn.Module):
def __init__(
self,
in_channel,
out_channel,
latent_channel,
scale = 4
norm
):
super().__init__()
sig = nn.Sigmoid()
self.relu = nn.ReLU(True)
self.sig = nn.Sigmoid()
self.up_sample = nn.Sequential(nn.ConvTranspose2d(latent_channel, out_channel, kernel_size=4, stride=scale, padding=0, bias=False),
nn.BatchNorm2d(out_channel), sig)
self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channel), sig)
self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channel), sig)
self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True))
self.mask_head = ResBlk(in_channel, 1, normalize=norm)
# self.forget_gate = ResBlk(in_channel,in_channel, normalize=norm)
def forward(self, encoder_in, bottleneck_in):
h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel`
h_bar_l = self.conv11(h_hat_l_1)
f_l = self.forget_gate(h_hat_l_1)
r_l = self.reset_gate (h_hat_l_1)
h_hat_l = (1-f_l)*h_bar_l + f_l* encoder_in
x_hat_l = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1
return x_hat_l
def forward(self, encoder_in, decoder_in):
mask = self.sig(self.mask_head(decoder_in)) # upsample and make `channel` identical to `out_channel`
# enc_feat= self.forget_gate(encoder_in)
out = (1-mask)*encoder_in + mask * decoder_in
return out, mask
+124
View File
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Generator.py
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 29th March 2022 12:20:26 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import math
import torch
from torch import nn
import torch.nn.functional as F
# class LSTU(nn.Module):
# def __init__(
# self,
# in_channel,
# out_channel,
# latent_channel,
# scale = 4
# ):
# super().__init__()
# sig = nn.Sigmoid()
# self.relu = nn.ReLU(True)
# self.up_sample = nn.Sequential(nn.Conv2d(latent_channel, out_channel/4, kernel_size=3, stride=1, padding=1, bias=False),
# nn.BatchNorm2d(out_channel/4),
# self.relu,
# nn.Conv2d(latent_channel/4, out_channel, kernel_size=3, stride=1, padding=1),
# )
# self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(out_channel), sig)
# self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(out_channel), sig)
# self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True))
# def forward(self, encoder_in, bottleneck_in):
# h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel`
# h_bar_l = self.conv11(h_hat_l_1)
# f_l = self.forget_gate(h_hat_l_1)
# r_l = self.reset_gate (h_hat_l_1)
# h_hat_l = (1-f_l)*h_bar_l + f_l* encoder_in
# x_hat_l = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1
# return x_hat_l
class ResBlk(nn.Module):
    """Pre-activation residual block; optional 2x avg-pool downsampling and
    "in"/"bn" normalization.  The two-path sum is rescaled to unit variance."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        kind = self.normalize.lower()
        if kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = self.conv1x1(x) if self.learned_sc else x
        return F.avg_pool2d(out, 2) if self.downsample else out

    def _residual(self, x):
        out = self.norm1(x) if self.normalize else x
        out = self.conv1(self.actv(out))
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        if self.normalize:
            out = self.norm2(out)
        return self.conv2(self.actv(out))

    def forward(self, x):
        return (self._shortcut(x) + self._residual(x)) / self.equal_var
class LSTU(nn.Module):
    """Gated skip fusion: predicts a single-channel soft mask from the
    decoder features and blends encoder and decoder features with it.

    NOTE: ``norm`` is currently unused (kept for caller compatibility).
    """

    def __init__(
        self,
        in_channel,
        norm
    ):
        super().__init__()
        # self.mask_head = ResBlk(in_channel, 1, normalize=norm)
        hidden = in_channel // 2
        self.mask_head = nn.Sequential(
            nn.Conv2d(in_channel, hidden, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.LeakyReLU(0.2),
            nn.Conv2d(hidden, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )
        # self.forget_gate = ResBlk(in_channel,in_channel, normalize=norm)

    def forward(self, encoder_in, decoder_in):
        # Mask in [0, 1], broadcast over channels.
        mask = self.mask_head(decoder_in)
        fused = mask * decoder_in + (1 - mask) * encoder_in
        return fused, mask
+66
View File
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: ModulatedDWConv.py
# Created Date: Monday April 18th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 18th April 2022 10:33:48 am
# Modified By: Chen Xuanhong
# Modified from: https://github.com/bes-dev/MobileStyleGAN.pytorch
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class ModulatedDWConv2d(nn.Module):
    """Style-modulated depthwise-separable convolution (MobileStyleGAN-style
    factorization of StyleGAN2's modulated conv): a depthwise kernel followed
    by a 1x1 channel-permute kernel.  Demodulation uses a fixed random
    ``style_inv`` buffer rather than per-sample statistics."""

    def __init__(
        self,
        channels_in,
        channels_out,
        style_dim,
        kernel_size,
        demodulate=True
    ):
        super().__init__()
        # Depthwise kernel, then 1x1 pointwise ("permute") kernel.
        self.weight_dw = nn.Parameter(
            torch.randn(channels_in, 1, kernel_size, kernel_size)
        )
        self.weight_permute = nn.Parameter(
            torch.randn(channels_out, channels_in, 1, 1)
        )
        # Style -> per-input-channel gain; bias starts at 1 so an untrained
        # style leaves the input roughly unscaled.
        self.modulation = nn.Linear(style_dim, channels_in, bias=True)
        self.modulation.bias.data.fill_(1.0)
        self.demodulate = demodulate
        if self.demodulate:
            # Fixed random surrogate used by the demodulation norm.
            self.register_buffer("style_inv", torch.randn(1, 1, channels_in, 1, 1))
        # Equalized-lr style scale and "same" padding.
        self.scale = 1.0 / math.sqrt(channels_in * kernel_size ** 2)
        self.padding = kernel_size // 2

    def forward(self, x, style):
        scaled = self.get_modulation(style) * x
        out = F.conv2d(scaled, self.weight_dw, padding=self.padding,
                       groups=scaled.size(1))        # depthwise pass
        out = F.conv2d(out, self.weight_permute)     # 1x1 cross-channel mix
        if self.demodulate:
            out = self.get_demodulation(style) * out
        return out

    def get_modulation(self, style):
        # (B, channels_in, 1, 1) per-channel input gains.
        gains = self.modulation(style).view(style.size(0), -1, 1, 1)
        return self.scale * gains

    def get_demodulation(self, style):
        # Input-independent: derived from the weights and the fixed buffer.
        w = (self.weight_dw.transpose(0, 1) * self.weight_permute).unsqueeze(0)
        inv_norm = torch.rsqrt((self.scale * self.style_inv * w).pow(2).sum([2, 3, 4]) + 1e-8)
        return inv_norm.view(*inv_norm.size(), 1, 1)
+96
View File
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Nonstau_Discriminator.py
# Created Date: Monday March 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 28th March 2022 10:03:56 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResBlk(nn.Module):
    """Pre-activation residual block; optional downsampling and
    "in" / "bn" / "none" normalization.  Output rescaled to unit variance."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.var = math.sqrt(2)  # divisor for the two-path sum
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        kind = self.normalize.lower()
        if kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        elif kind == "none":
            # Falsy flag makes _residual skip normalization entirely.
            self.normalize = False
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = self.conv1x1(x) if self.learned_sc else x
        return F.avg_pool2d(out, 2) if self.downsample else out

    def _residual(self, x):
        out = self.norm1(x) if self.normalize else x
        out = self.conv1(self.actv(out))
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        if self.normalize:
            out = self.norm2(out)
        return self.conv2(self.actv(out))

    def forward(self, x):
        return (self._shortcut(x) + self._residual(x)) / self.var
class Discriminator(torch.nn.Module):
    """StarGAN v2-style discriminator: a stack of downsampling ResBlks ending
    in a 4x4 conv and a 1x1 conv producing one realness logit per image.

    Required kwargs: img_size, max_conv_dim, norm.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        img_size = kwargs["img_size"]
        num_domains = 1  # single realness output
        max_conv_dim = kwargs["max_conv_dim"]
        norm = kwargs["norm"]
        # Base width scales inversely with input resolution.
        dim_in = 2 ** 14 // img_size
        layers = [nn.Conv2d(3, dim_in, 3, 1, 1)]
        repeat_num = int(np.log2(img_size)) - 2
        for _ in range(repeat_num):
            dim_out = min(dim_in * 2, max_conv_dim)
            layers.append(ResBlk(dim_in, dim_out, normalize=norm, downsample=True))
            dim_in = dim_out
        layers.extend([
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_out, dim_out, 4, 1, 0),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_out, num_domains, 1, 1, 0),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        out = self.main(x)
        return out.view(out.size(0), -1)  # (batch, num_domains)
+107
View File
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Nonstau_Discriminator.py
# Created Date: Monday March 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 13th April 2022 3:11:40 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResBlk(nn.Module):
    """Pre-activation residual block with optional downsampling; supports
    "in", "bn" or "none" normalization and rescales the sum to unit variance."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        elif mode == "none":
            # Falsy flag disables normalization in _residual.
            self.normalize = False
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = self.conv1x1(x) if self.learned_sc else x
        return F.avg_pool2d(out, 2) if self.downsample else out

    def _residual(self, x):
        out = self.norm1(x) if self.normalize else x
        out = self.conv1(self.actv(out))
        if self.downsample:
            out = F.avg_pool2d(out, 2)
        if self.normalize:
            out = self.norm2(out)
        return self.conv2(self.actv(out))

    def forward(self, x):
        return (self._shortcut(x) + self._residual(x)) / self.var
class Discriminator(torch.nn.Module):
    """Two-stage discriminator: ``main`` produces mid-level features (exposed
    via get_feature and the second return value, e.g. for feature-matching
    losses) and ``tail`` maps them to a single realness logit.

    Required kwargs: img_size, max_conv_dim, norm.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        img_size = kwargs["img_size"]
        num_domains = 1  # single realness output
        max_conv_dim = kwargs["max_conv_dim"]
        norm = kwargs["norm"].lower()
        dim_in = 2 ** 14 // img_size  # base width scales inversely with resolution
        trunk = [nn.Conv2d(3, dim_in, 3, 1, 1)]
        repeat_num = int(np.log2(img_size)) - 2
        # All but the last two downsampling stages form the feature trunk.
        for _ in range(repeat_num - 2):
            dim_out = min(dim_in * 2, max_conv_dim)
            trunk.append(ResBlk(dim_in, dim_out, normalize=norm, downsample=True))
            dim_in = dim_out
        head = []
        for _ in range(2):  # 16
            dim_out = min(dim_in * 2, max_conv_dim)
            head.append(ResBlk(dim_in, dim_out, normalize=norm, downsample=True))
            dim_in = dim_out
        head.extend([
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_out, dim_out, 4, 1, 0),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_out, num_domains, 1, 1, 0),
        ])
        self.main = nn.Sequential(*trunk)
        self.tail = nn.Sequential(*head)

    def get_feature(self, x):
        """Mid-level features from the trunk only."""
        return self.main(x)

    def forward(self, x):
        mid = self.main(x)
        logits = self.tail(mid)
        logits = logits.view(logits.size(0), -1)  # (batch, num_domains)
        return logits, mid
+185
View File
@@ -0,0 +1,185 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: data_loader_VGGFace2HQ copy.py
# Created Date: Sunday February 6th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 15th February 2022 1:50:19 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import glob
import torch
import random
import numpy as np
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
# from StyleResize import StyleResize
class InfiniteSampler(torch.utils.data.Sampler):
    """Endless index sampler with round-robin sharding across replicas.

    With ``shuffle`` enabled, a deterministic RNG (seeded with ``seed``)
    shuffles the initial order and then keeps a sliding window of recently
    yielded positions locally re-shuffled while iteration proceeds.
    """
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size
    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))
        counter = 0
        while True:
            pos = counter % order.size
            # Every replica walks the same order; each yields only its share.
            if counter % self.num_replicas == self.rank:
                yield order[pos]
                if window >= 2:
                    swap = (pos - rnd.randint(window)) % order.size
                    order[pos], order[swap] = order[swap], order[pos]
            counter += 1
class data_prefetcher():
    """Overlaps host-to-GPU batch transfer with compute via a side CUDA stream.

    Pulls ``(src_image1, src_image2)`` batches from ``loader``, copies them
    asynchronously to GPU ``cur_gpu`` and normalizes them in place, so the
    next batch is already on-device by the time ``next()`` is called.
    """
    def __init__(self, loader, cur_gpu):
        torch.cuda.set_device(cur_gpu) # must add this line to avoid excessive use of GPU 0 by the prefetcher
        self.loader = loader
        self.dataiter = iter(loader)
        # Dedicated stream on which all copies and normalization are enqueued.
        self.stream = torch.cuda.Stream(device=cur_gpu)
        # Per-channel statistics (the commonly used ImageNet mean/std),
        # shaped (1, 3, 1, 1) to broadcast over NCHW batches.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1)
        self.cur_gpu = cur_gpu
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        # self.num_images = loader.__len__()
        # Kick off the first asynchronous load immediately.
        self.preload()
    def preload(self):
        """Fetch the next batch and enqueue its GPU transfer on the side stream.

        NOTE(review): propagates StopIteration when the loader is exhausted;
        the commented-out restart below suggests an infinite sampler is
        assumed upstream — confirm.
        """
        # try:
        self.src_image1, self.src_image2 = next(self.dataiter)
        # except StopIteration:
        #     self.dataiter = iter(self.loader)
        #     self.src_image1, self.src_image2 = next(self.dataiter)
        with torch.cuda.stream(self.stream):
            self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True)
            self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
            self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True)
            self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.next_input = self.next_input.half()
        # else:
        #     self.next_input = self.next_input.float()
        # self.next_input = self.next_input.sub_(self.mean).div_(self.std)
    def next(self):
        """Return the prefetched batch pair and immediately start the next load."""
        # Make the compute stream wait for the side-stream copies to finish.
        torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream)
        src_image1 = self.src_image1
        src_image2 = self.src_image2
        self.preload()
        return src_image1, src_image2
    # def __len__(self):
    #     """Return the number of images."""
    #     return self.num_images
class VGGFace2HQDataset(data.Dataset):
    """VGGFace2-HQ dataset: one sub-directory of images per identity.

    ``__getitem__`` returns two randomly chosen (possibly identical) images
    of the same identity, each passed through ``img_transform``.
    """
    def __init__(self,
                image_dir,
                img_transform,
                subffix='jpg',
                random_seed=1234):
        """Initialize and preprocess the VGGFace2 HQ dataset.

        Args:
            image_dir (str): root directory; each identity is a sub-directory.
            img_transform: callable applied to every loaded PIL image.
            subffix (str): image file extension to index (default ``'jpg'``).
            random_seed (int): seed used to shuffle the identity list.
        """
        self.image_dir = image_dir
        self.img_transform = img_transform
        self.subffix = subffix
        self.dataset = []
        self.random_seed = random_seed
        self.preprocess()
        self.num_images = len(self.dataset)
    def preprocess(self):
        """Index the dataset: one list of image paths per identity directory."""
        print("processing VGGFace2 HQ dataset images...")
        temp_path = os.path.join(self.image_dir,'*/')
        pathes = glob.glob(temp_path)
        self.dataset = []
        for dir_item in pathes:
            # BUGFIX: honor the configured suffix instead of hard-coding '*.jpg'.
            join_path = glob.glob(os.path.join(dir_item, '*.%s' % self.subffix))
            if not join_path:
                continue  # robustness: an empty identity would crash __getitem__
            print("processing %s"%dir_item,end='\r')
            self.dataset.append(list(join_path))
        random.seed(self.random_seed)
        random.shuffle(self.dataset)
        print('Finished preprocessing the VGGFace2 HQ dataset, total dirs number: %d...'%len(self.dataset))
    def __getitem__(self, index):
        """Return two random images of identity ``index`` (may be the same file)."""
        dir_tmp1 = self.dataset[index]
        dir_tmp1_len = len(dir_tmp1)
        filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
        filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
        image1 = self.img_transform(Image.open(filename1))
        image2 = self.img_transform(Image.open(filename2))
        return image1, image2
    def __len__(self):
        """Return the number of identities (directories)."""
        return self.num_images
def GetLoader( dataset_roots,
            rank,
            num_gpus,
            batch_size=16,
            **kwargs
            ):
    """Build the VGGFace2-HQ training pipeline wrapped in a CUDA prefetcher.

    Args:
        dataset_roots: root directory of the image tree.
        rank: rank of this process (also used as the CUDA device index).
        num_gpus: number of replicas sharing the infinite sampler.
        batch_size: samples per batch (default 16).
        **kwargs: must provide ``random_seed`` and ``dataloader_workers``.

    Returns:
        A ``data_prefetcher`` streaming normalized image pairs on GPU ``rank``.
    """
    seed = kwargs["random_seed"]
    workers = kwargs["dataloader_workers"]
    transform = T.Compose([T.ToTensor()])
    content_dataset = VGGFace2HQDataset(dataset_roots, transform, "jpg", seed)
    sampler = InfiniteSampler(dataset=content_dataset, rank=rank, num_replicas=num_gpus, seed=seed)
    loader = data.DataLoader(dataset=content_dataset, batch_size=batch_size,
                    drop_last=False, shuffle=False, num_workers=workers,
                    pin_memory=True, sampler=sampler)
    # content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
    #                 drop_last=False,shuffle=True,num_workers=num_workers,pin_memory=True)
    return data_prefetcher(loader, torch.device('cuda', rank))
def denorm(x):
    """Map a [-1, 1]-normalized tensor back to [0, 1], clamping overflow in place."""
    return ((x + 1) / 2).clamp_(0, 1)
@@ -5,7 +5,7 @@
# Created Date: Sunday February 6th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 15th February 2022 1:50:19 am
# Last Modified: Wednesday, 6th April 2022 12:53:53 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -108,7 +108,7 @@ class VGGFace2HQDataset(data.Dataset):
subffix='jpg',
random_seed=1234):
"""Initialize and preprocess the VGGFace2 HQ dataset."""
self.image_dir = image_dir
self.image_dir = image_dir["images"]
self.img_transform = img_transform
self.subffix = subffix
self.dataset = []
@@ -0,0 +1,202 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: data_loader_VGGFace2HQ copy.py
# Created Date: Sunday February 6th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 3rd April 2022 9:48:23 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import glob
import torch
import random
import numpy as np
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
import cv2
# from StyleResize import StyleResize
class InfiniteSampler(torch.utils.data.Sampler):
    """Never-ending sampler yielding dataset indices, sharded by replica rank.

    When shuffling, the order is derived deterministically from ``seed`` and
    a window of recently visited positions is continuously re-mixed.
    """
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size
    def __iter__(self):
        indices = np.arange(len(self.dataset))
        rng = np.random.RandomState(self.seed) if self.shuffle else None
        win = 0
        if rng is not None:
            rng.shuffle(indices)
            win = int(np.rint(indices.size * self.window_size))
        n = indices.size
        step = 0
        while True:
            cur = step % n
            if step % self.num_replicas == self.rank:
                yield indices[cur]
                # Keep the local neighborhood shuffled as we go.
                if win >= 2:
                    other = (cur - rng.randint(win)) % n
                    indices[cur], indices[other] = indices[other], indices[cur]
            step += 1
class data_prefetcher():
    """Overlaps host-to-GPU transfer with compute via a side CUDA stream.

    Mask-aware variant: pulls ``(src_image1, src_image2, mask)`` batches from
    ``loader``, copies them asynchronously to GPU ``cur_gpu``, normalizes the
    images and rescales the mask to [0, 1].
    """
    def __init__(self, loader, cur_gpu):
        torch.cuda.set_device(cur_gpu) # must add this line to avoid excessive use of GPU 0 by the prefetcher
        self.loader = loader
        self.dataiter = iter(loader)
        # Dedicated stream on which all copies and normalization are enqueued.
        self.stream = torch.cuda.Stream(device=cur_gpu)
        # Per-channel statistics (the commonly used ImageNet mean/std),
        # shaped (1, 3, 1, 1) to broadcast over NCHW batches.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1)
        self.cur_gpu = cur_gpu
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        # self.num_images = loader.__len__()
        # Kick off the first asynchronous load immediately.
        self.preload()
    def preload(self):
        """Fetch the next batch and enqueue its GPU transfer on the side stream.

        NOTE(review): propagates StopIteration when the loader is exhausted;
        the commented-out restart below suggests an infinite sampler is
        assumed upstream — confirm.
        """
        # try:
        self.src_image1, self.src_image2, self.mask = next(self.dataiter)
        # except StopIteration:
        #     self.dataiter = iter(self.loader)
        #     self.src_image1, self.src_image2 = next(self.dataiter)
        with torch.cuda.stream(self.stream):
            self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True)
            self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
            self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True)
            self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
            self.mask = self.mask.cuda(device= self.cur_gpu, non_blocking=True)
            # Masks arrive as 0..255 values; rescale to [0, 1].
            self.mask = self.mask/255.0
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.next_input = self.next_input.half()
        # else:
        #     self.next_input = self.next_input.float()
        # self.next_input = self.next_input.sub_(self.mean).div_(self.std)
    def next(self):
        """Return the prefetched (images, mask) batch and start the next load."""
        # Make the compute stream wait for the side-stream copies to finish.
        torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream)
        src_image1 = self.src_image1
        src_image2 = self.src_image2
        mask = self.mask
        self.preload()
        return src_image1, src_image2, mask
    # def __len__(self):
    #     """Return the number of images."""
    #     return self.num_images
class VGGFace2HQDataset(data.Dataset):
    """VGGFace2-HQ dataset with parallel segmentation masks.

    Each identity is a sub-directory of ``image_dir``; the mask of image
    ``<id>/<name>.jpg`` is expected at ``mask_dir/<id>/<name>.png``.
    """
    def __init__(self,
                image_dir,
                mask_dir,
                img_transform,
                subffix='jpg',
                random_seed=1234):
        """Initialize and preprocess the VGGFace2 HQ dataset.

        Args:
            image_dir (str): root of the image tree (one dir per identity).
            mask_dir (str): root of the parallel ``.png`` mask tree.
            img_transform: callable applied to every loaded PIL image.
            subffix (str): image file extension to index (default ``'jpg'``).
            random_seed (int): seed used to shuffle the identity list.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.img_transform = img_transform
        self.subffix = subffix
        self.dataset = []
        self.random_seed = random_seed
        self.preprocess()
        self.num_images = len(self.dataset)
    def preprocess(self):
        """Index image/mask path pairs, grouped by identity directory."""
        print("processing VGGFace2 HQ dataset images...")
        temp_path = os.path.join(self.image_dir,'*/')
        pathes = glob.glob(temp_path)
        self.dataset = []
        for dir_item in pathes:
            # BUGFIX: honor the configured suffix instead of hard-coding '*.jpg'.
            join_path = glob.glob(os.path.join(dir_item, '*.%s' % self.subffix))
            if not join_path:
                continue  # robustness: an empty identity would crash __getitem__
            print("processing %s"%dir_item,end='\r')
            # BUGFIX: ``os.path.dirname(join_path[1])`` raised IndexError for
            # single-image identities; the identity name is simply the
            # directory's own basename.
            dir_name = os.path.join(self.mask_dir, os.path.basename(os.path.normpath(dir_item)))
            temp_list = []
            for item in join_path:
                img_name = os.path.basename(item)
                img_name, _ = os.path.splitext(img_name)
                temp_list.append({
                    "i":item,
                    "m":os.path.join(dir_name, img_name + ".png")
                })
            self.dataset.append(temp_list)
        random.seed(self.random_seed)
        random.shuffle(self.dataset)
        print('Finished preprocessing the VGGFace2 HQ dataset, total dirs number: %d...'%len(self.dataset))
    def __getitem__(self, index):
        """Return two random same-identity images plus the first image's mask."""
        dir_tmp1 = self.dataset[index]
        dir_tmp1_len = len(dir_tmp1)
        filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
        filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
        image1 = self.img_transform(Image.open(filename1["i"]))
        image2 = self.img_transform(Image.open(filename2["i"]))
        # Mask is read as single-channel uint8 (flag 0 = grayscale);
        # normalization to [0, 1] happens later in the prefetcher.
        mask = torch.from_numpy(cv2.imread(filename1["m"],0)).unsqueeze(0)
        return image1, image2, mask
    def __len__(self):
        """Return the number of identities (directories)."""
        return self.num_images
def GetLoader( dataset_roots,
            rank,
            num_gpus,
            batch_size=16,
            **kwargs
            ):
    """Build the mask-aware VGGFace2-HQ pipeline wrapped in a CUDA prefetcher.

    Args:
        dataset_roots (dict): provides ``"images"`` and ``"masks"`` roots.
        rank: rank of this process (also used as the CUDA device index).
        num_gpus: number of replicas sharing the infinite sampler.
        batch_size: samples per batch (default 16).
        **kwargs: must provide ``random_seed`` and ``dataloader_workers``.

    Returns:
        A ``data_prefetcher`` streaming (image1, image2, mask) batches on GPU ``rank``.
    """
    image_root = dataset_roots["images"]
    mask_root = dataset_roots["masks"]
    seed = kwargs["random_seed"]
    workers = kwargs["dataloader_workers"]
    transform = T.Compose([T.ToTensor()])
    content_dataset = VGGFace2HQDataset(image_root, mask_root, transform, "jpg", seed)
    sampler = InfiniteSampler(dataset=content_dataset, rank=rank, num_replicas=num_gpus, seed=seed)
    loader = data.DataLoader(dataset=content_dataset, batch_size=batch_size,
                    drop_last=False, shuffle=False, num_workers=workers,
                    pin_memory=True, sampler=sampler)
    # content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
    #                 drop_last=False,shuffle=True,num_workers=num_workers,pin_memory=True)
    return data_prefetcher(loader, torch.device('cuda', rank))
def denorm(x):
    """Invert the [-1, 1] normalization back to the [0, 1] image range."""
    out = x.add(1).div(2)
    return out.clamp_(0, 1)
+40
View File
@@ -0,0 +1,40 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: dataset.check.py
# Created Date: Sunday April 3rd 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 3rd April 2022 2:57:48 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import glob
from utilities.json_config import readConfig, writeConfig
# dataset = "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan"
# mask_dir= "G:/VGGFace2-HQ/VGGface2_HQ_original_aligned_mask"
# Sanity-check that every aligned face image has a matching segmentation
# mask; missing mask paths are printed.
# NOTE(review): savePath is declared but never written — presumably intended
# as a failure log file; confirm before wiring it up.
savePath = "./vggface2hq_failed.txt"
env_config = readConfig('env/env.json')
env_config = env_config["path"]
dataset = env_config["dataset_paths"]["vggface2_hq"]["images"]
mask_dir = env_config["dataset_paths"]["vggface2_hq"]["masks"]
temp_path = os.path.join(dataset,'*/')
pathes = glob.glob(temp_path)
for dir_item in pathes:
    join_path = glob.glob(os.path.join(dir_item,'*.jpg'))
    print("processing %s"%dir_item,end='\r')
    # BUGFIX: was os.path.dirname(join_path[1]), which raised IndexError for
    # directories with fewer than two images; the identity name is just the
    # directory's own basename.
    dir_name = os.path.join(mask_dir, os.path.basename(os.path.normpath(dir_item)))
    for item in join_path:
        img_name = os.path.basename(item)
        img_name, _ = os.path.splitext(img_name)
        mask_name = os.path.join(dir_name, img_name + ".png")
        if not os.path.exists(mask_name):
            print(mask_name)
+113
View File
@@ -0,0 +1,113 @@
张雨绮 --> 1
刘亦菲 --> 2
周杰伦 --> 3
关晓彤 --> 4
林志玲 --> 13
103许岚解压密码coshunter.top --> 18
19 --> [YOUMI尤蜜荟] 2020.09.03 VOL.521 妲己_Toxic [61P508MB]
金晨 --> 20
容祖儿 --> 21
张歆艺 --> 22
杨颖 --> 23
倪妮 --> 24
G:/4K/90 佟丽娅 --> H:/face_data/VGGFace2_HQ/27
273 赵粤 --> 28
G:/4K/291-不支持在线打开,请下载后解压玥儿玥er --> H:/face_data/VGGFace2_HQ/29
G:/4K/114 张紫宁 --> H:/face_data/VGGFace2_HQ/30
G:/4K/120 舒淇 --> H:/face_data/VGGFace2_HQ/31
G:/4K/124 张子枫 --> H:/face_data/VGGFace2_HQ/32
G:/4K/173 王心凌 --> H:/face_data/VGGFace2_HQ/33
G:/4K/191 郑爽 --> H:/face_data/VGGFace2_HQ/34
G:/4K/286尤妮丝b/防何蟹文件夹(勿在线解压) 图包 02/防何蟹文件夹(勿在线解压) 图包 02 --> H:/face_data/VGGFace2_HQ/35
G:/4K/辛芷蕾 --> H:/face_data/VGGFace2_HQ/36
G:/4K/197 邓家佳 --> H:/face_data/VGGFace2_HQ/37
G:/4K/195 袁姗姗 --> H:/face_data/VGGFace2_HQ/38
G:/4K/249 希林娜依.高 --> H:/face_data/VGGFace2_HQ/39
G:/4K/素材合/母其弥雅原片素材 --> H:/face_data/VGGFace2_HQ/40
G:/4K/素材合/张含韵系列 --> H:/face_data/VGGFace2_HQ/41
G:/4K/231胡静 --> H:/face_data/VGGFace2_HQ/42
G:/4K/259 周洁琼 --> H:/face_data/VGGFace2_HQ/43
G:/4K/190 李晟 --> H:/face_data/VGGFace2_HQ/44
G:/4K/193 张檬 --> H:/face_data/VGGFace2_HQ/45
G:/4K/174 潘之琳 --> H:/face_data/VGGFace2_HQ/46
G:/4K/196 陈都灵 --> H:/face_data/VGGFace2_HQ/47
G:/4K/236 李亚男 --> H:/face_data/VGGFace2_HQ/48
G:/4K/221 江疏影 --> H:/face_data/VGGFace2_HQ/49
G:/4K/253 刘羽琦 --> H:/face_data/VGGFace2_HQ/50
G:/4K/282 李艺彤 --> H:/face_data/VGGFace2_HQ/51
G:/4K/沈梦辰系列 --> H:/face_data/VGGFace2_HQ/52
G:/4K/293 张怡 --> H:/face_data/VGGFace2_HQ/53
G:/4K/私多个文件/11.7 --> H:/face_data/VGGFace2_HQ/54
G:/4K/私多个文件/2016.3.31木子私房/原片 --> H:/face_data/VGGFace2_HQ/55
G:/4K/私多个文件/李可人 原片 --> H:/face_data/VGGFace2_HQ/56
G:/4K/私多个文件/刘雨欣 --> H:/face_data/VGGFace2_HQ/57
G:/4K/私多个文件/新建文件夹 --> H:/face_data/VGGFace2_HQ/58
G:/4K/私多个文件/新雯 --> H:/face_data/VGGFace2_HQ/58
G:/4K/私多个文件/原片 --> H:/face_data/VGGFace2_HQ/59
G:/4K/私多个文件/原片22 --> H:/face_data/VGGFace2_HQ/60
G:/4K/私多个文件/原图 --> H:/face_data/VGGFace2_HQ/61
G:/4K/徐璐系列 --> H:/face_data/VGGFace2_HQ/62
G:/4K/徐冬冬 --> H:/face_data/VGGFace2_HQ/63
G:/4K/徐静蕾 --> H:/face_data/VGGFace2_HQ/64
G:/4K/张靓颖 --> H:/face_data/VGGFace2_HQ/65
G:/4K/私多个文件/原片22 --> H:/face_data/hg
G:/4K/私多个文件/原片22 --> H:/face_data/hg1
G:/4K/张靓颖 --> H:/face_data/VGGFace2_HQ/65
G:/4K/凉高定妆 --> H:/face_data/VGGFace2_HQ/66
G:/4K/林志玲-初整 --> H:/face_data/VGGFace2_HQ/67
G:/4K/000未整理/张小斐 --> H:/face_data/VGGFace2_HQ/68
G:/4K/000未整理/马思纯系列1 --> H:/face_data/VGGFace2_HQ/69
G:/4K/000未整理/王丽坤 --> H:/face_data/VGGFace2_HQ/70
G:/4K/000未整理/刘涛全部 --> H:/face_data/VGGFace2_HQ/71
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/张子枫 --> H:/face_data/VGGFace2_HQ/72
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/李小璐 --> H:/face_data/VGGFace2_HQ/73
G:/4K/295-不支持在线打开,请下载后解压镜酱 --> H:/face_data/VGGFace2_HQ/74
G:/4K/297-不支持在线打开,请下载后解压一只云烧 --> H:/face_data/VGGFace2_HQ/74
G:/4K/295-不支持在线打开,请下载后解压镜酱 --> H:/face_data/VGGFace2_HQ/75
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/海报艺人可用素材/李世鹏可用 --> H:/face_data/VGGFace2_HQ/76
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/海报艺人可用素材/李菲儿可用 --> H:/face_data/VGGFace2_HQ/77
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/程砚秋 --> H:/face_data/VGGFace2_HQ/78
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/白舟夏绿 --> H:/face_data/VGGFace2_HQ/79
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/果果白舟夏绿 --> H:/face_data/VGGFace2_HQ/80
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/陶西安谧 --> H:/face_data/VGGFace2_HQ/81
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/祖儿紫枫 --> H:/face_data/VGGFace2_HQ/82
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/祖儿紫枫2 --> H:/face_data/VGGFace2_HQ/83
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照0905-0912 --> H:/face_data/VGGFace2_HQ/84
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照0913-0922 --> H:/face_data/VGGFace2_HQ/85
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照0923-1004 --> H:/face_data/VGGFace2_HQ/86
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1005-1010 --> H:/face_data/VGGFace2_HQ/87
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1011-1017 --> H:/face_data/VGGFace2_HQ/88
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1018-1101 --> H:/face_data/VGGFace2_HQ/89
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1102-1114 --> H:/face_data/VGGFace2_HQ/90
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1115-1122 --> H:/face_data/VGGFace2_HQ/91
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1123-1204 --> H:/face_data/VGGFace2_HQ/92
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1205-1212 --> H:/face_data/VGGFace2_HQ/93
G:/4K/000未整理/黄奕1 --> H:/face_data/VGGFace2_HQ/94
G:/4K/000未整理/阚清子 --> H:/face_data/VGGFace2_HQ/95
G:/4K/000未整理/何穗热裤李小璐等 --> H:/face_data/VGGFace2_HQ/96
G:/4K/000未整理/景甜热裤泳池 --> H:/face_data/VGGFace2_HQ/97
H:/face_data/VGGFace2_HQ/77 --> H:/face_data/VGGFace2_HQ/98
G:/4K/16 吴映洁空姐制服 --> H:/face_data/VGGFace2_HQ/98
G:/4K/2 --> H:/face_data/VGGFace2_HQ/99
G:/4K/108 张韶涵 --> H:/face_data/VGGFace2_HQ\108 张韶涵
G:/4K/110 张小斐 --> H:/face_data/VGGFace2_HQ\110 张小斐
G:/4K/111 张小斐 --> H:/face_data/VGGFace2_HQ\111 张小斐
G:/4K/119 宋茜 --> H:/face_data/VGGFace2_HQ\119 宋茜
G:/4K/120 舒淇 --> H:/face_data/VGGFace2_HQ\120 舒淇
G:/4K/125 舒淇合集 --> H:/face_data/VGGFace2_HQ\125 舒淇合集
G:/4K/各地素人拍摄/新建文件夹 (6) --> H:/face_data/VGGFace2_HQ\新建文件夹 (6)
G:/4K/各地素人拍摄/原片(6) --> H:/face_data/VGGFace2_HQ\原片(6)
G:/4K/各地素人拍摄/20161113舞剧系拍摄jpg --> H:/face_data/VGGFace2_HQ\20161113舞剧系拍摄jpg
G:/4K/各地素人拍摄/155 --> H:/face_data/VGGFace2_HQ\155
G:/4K/各地素人拍摄/17 --> H:/face_data/VGGFace2_HQ\17
G:/4K/249刘飞er解压密码coshunter.com或coshunter.top --> H:/face_data/VGGFace2_HQ\249刘飞er解压密码coshunter.com或coshunter
H:/face_data/cw --> H:/face_data/VGGFace2_HQ\cw
H:/face_data/cw --> H:/face_data/VGGFace2_HQ\cw
G:/4K/柳岩棚拍 --> H:/face_data/VGGFace2_HQ\柳岩棚拍
G:/4K/216 卢靖姗 --> H:/face_data/VGGFace2_HQ\216 卢靖姗
G:/4K/217 卢靖姗 --> H:/face_data/VGGFace2_HQ\217 卢靖姗
G:/4K/白百合 --> H:/face_data/VGGFace2_HQ\白百合
G:/4K/白百何 --> H:/face_data/VGGFace2_HQ\白百何
G:/4K/张雨绮 --> H:/face_data/VGGFace2_HQ\张雨绮
G:/4K/179 冯薪朵 --> H:/face_data/VGGFace2_HQ\179 冯薪朵
+44 -13
View File
@@ -5,7 +5,7 @@
# Created Date: Tuesday February 1st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 2nd February 2022 4:13:28 pm
# Last Modified: Sunday, 24th April 2022 2:01:47 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -26,6 +26,8 @@ import tkinter.ttk as ttk
import subprocess
from pathlib import Path
import numpy as np
from insightface_func.face_detect_crop_multi import Face_detect_crop
class TextRedirector(object):
@@ -113,6 +115,7 @@ class Application(tk.Frame):
label_frame.columnconfigure(0, weight=1)
label_frame.columnconfigure(1, weight=1)
label_frame.columnconfigure(2, weight=1)
label_frame.columnconfigure(3, weight=1)
tk.Label(label_frame, text="Crop Size:",font=font_list,justify="left")\
.grid(row=0,column=0,sticky=tk.EW)
@@ -123,6 +126,9 @@ class Application(tk.Frame):
tk.Label(label_frame, text="Target Format:",font=font_list,justify="left")\
.grid(row=0,column=2,sticky=tk.EW)
tk.Label(label_frame, text="Blurry Thredhold:",font=font_list,justify="left")\
.grid(row=0,column=3,sticky=tk.EW)
#################################################################################################
test_frame = tk.Frame(self.master)
@@ -151,8 +157,10 @@ class Application(tk.Frame):
self.format_com["value"] = ["png","jpg"]
self.format_com.current(0)
self.thredhold = tkinter.StringVar()
tk.Entry(test_frame, textvariable= self.thredhold, font=font_list)\
.grid(row=0,column=3,sticky=tk.EW)
self.thredhold.set("70")
#################################################################################################
scale_frame = tk.Frame(self.master)
scale_frame.pack(fill="both", padx=5,pady=5)
@@ -200,8 +208,10 @@ class Application(tk.Frame):
def select_task(self):
path = askdirectory()
print("Selected source directory: %s"%path)
self.img_path.set(path)
if os.path.isdir(path):
print("Selected source directory: %s"%path)
self.img_path.set(path)
def Select_Target(self):
thread_update = threading.Thread(target=self.select_target_task)
@@ -209,8 +219,9 @@ class Application(tk.Frame):
def select_target_task(self):
path = askdirectory()
print("Selected target directory: %s"%path)
self.save_path.set(path)
if os.path.isdir(path):
print("Selected target directory: %s"%path)
self.save_path.set(path)
def Crop(self):
thread_update = threading.Thread(target=self.crop_task)
@@ -218,39 +229,59 @@ class Application(tk.Frame):
def crop_task(self):
mode = self.align_com.get()
if mode == "VGGFace":
mode = "None"
crop_size = int(self.test_com.get())
path = self.img_path.get()
tg_path = self.save_path.get()
blur_t = self.thredhold.get()
basepath = os.path.splitext(os.path.basename(path))[0]
tg_path = os.path.join("H:/face_data/VGGFace2_HQ",basepath)
print("target path: ",tg_path)
if not os.path.exists(tg_path):
os.makedirs(tg_path)
tg_format = self.format_com.get()
min_scale = float(self.min_scale.get())
blur_t = 100.0
font = cv2.FONT_HERSHEY_SIMPLEX
blur_t = float(blur_t)
print("Blurry thredhold %f"%blur_t)
self.detect.prepare(ctx_id = 0, det_thresh=0.6,\
det_size=(640,640),mode = mode,crop_size=crop_size,ratio=min_scale)
log_file = "./dataset_readme.txt"
with open(log_file,'a+') as logf: # ,encoding='UTF-8'
logf.writelines("%s --> %s\n"%(path,tg_path))
if path and tg_path:
imgs_list = []
if os.path.isdir(path):
print("Input a dir....")
imgs = glob.glob(os.path.join(path,"*"))
for item in imgs:
# imgs = glob.glob(os.path.join(path,"**"))
for item in glob.iglob(os.path.join(path,"**"),recursive=True):
imgs_list.append(item)
# print(imgs_list)
index = 0
for img in imgs_list:
print(img)
attr_img_ori= cv2.imread(img)
try:
attr_img_ori = cv2.imdecode(np.fromfile(img, dtype=np.uint8),-1)
except:
print("Illegal file!")
continue
# attr_img_ori= cv2.imread(img)
try:
attr_img_align_crop, _ = self.detect.get(attr_img_ori)
sub_index = 0
if len(attr_img_align_crop) < 1:
print("Small face")
for face_i in attr_img_align_crop:
imageVar = cv2.Laplacian(face_i, cv2.CV_64F).var()
f_path =os.path.join(tg_path, str(index).zfill(6)+"_%d.%s"%(sub_index,tg_format))
# print("save path: ",f_path)
if imageVar < blur_t:
print("Over blurry image!")
continue
# face_i = cv2.putText(face_i, '%.1f'%imageVar,(50, 50), font, 0.8, (15, 9, 255), 2)
cv2.imwrite(f_path,face_i)
# cv2.imwrite(f_path,face_i)
cv2.imencode('.png',face_i)[1].tofile(f_path)
sub_index += 1
index += 1
except:
+279
View File
@@ -0,0 +1,279 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: face_crop.py
# Created Date: Tuesday February 1st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 22nd April 2022 8:43:40 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import sys
import glob
import json
import tkinter
from tkinter.filedialog import askdirectory
import threading
import tkinter as tk
import tkinter.ttk as ttk
import subprocess
from pathlib import Path
from insightface_func.face_detect_crop_multi import Face_detect_crop
class TextRedirector(object):
    """File-like adapter that routes ``write`` calls into a Tk text widget.

    Assigned to ``sys.stdout`` so that ``print`` output appears in the GUI
    log pane; each chunk is inserted with the configured tag.
    """
    def __init__(self, widget, tag="stdout"):
        self.widget = widget  # target tk.Text widget
        self.tag = tag        # tag applied to every inserted chunk
    def write(self, text):
        # Parameter renamed from ``str`` — it shadowed the builtin.
        # Unlock the read-only widget, append, re-lock, then scroll so the
        # newest output stays visible.
        self.widget.configure(state="normal")
        self.widget.insert("end", text, (self.tag,))
        self.widget.configure(state="disabled")
        self.widget.see(tk.END)
    def flush(self):
        """No buffering here; required only to satisfy the file protocol."""
        pass
#############################################################
# Main Class
#############################################################
class Application(tk.Frame):
def __init__(self, master=None):
tk.Frame.__init__(self, master,bg='black')
# self.font_size = 16
self.font_list = ("Times New Roman",14)
self.padx = 5
self.pady = 5
self.window_init()
def __label_text__(self, usr, root):
return "User Name: %s\nWorkspace: %s"%(usr, root)
def window_init(self):
cwd = os.getcwd()
self.master.title('Face Crop - %s'%cwd)
# self.master.iconbitmap('./utilities/_logo.ico')
self.master.geometry("{}x{}".format(640, 600))
font_list = self.font_list
#################################################################################################
list_frame = tk.Frame(self.master)
list_frame.pack(fill="both", padx=5,pady=5)
list_frame.columnconfigure(0, weight=1)
list_frame.columnconfigure(1, weight=1)
list_frame.columnconfigure(2, weight=1)
self.img_path = tkinter.StringVar()
tk.Label(list_frame, text="Image/Video Path:",font=font_list,justify="left")\
.grid(row=0,column=0,sticky=tk.EW)
tk.Entry(list_frame, textvariable= self.img_path, font=font_list)\
.grid(row=0,column=1,sticky=tk.EW)
tk.Button(list_frame, text = "Select Path", font=font_list,
command = self.Select, bg='#F4A460', fg='#F5F5F5')\
.grid(row=0,column=2,sticky=tk.EW)
#################################################################################################
list_frame1 = tk.Frame(self.master)
list_frame1.pack(fill="both", padx=5,pady=5)
list_frame1.columnconfigure(0, weight=1)
list_frame1.columnconfigure(1, weight=1)
list_frame1.columnconfigure(2, weight=1)
self.save_path = tkinter.StringVar()
tk.Label(list_frame1, text="Target Path:",font=font_list,justify="left")\
.grid(row=0,column=0,sticky=tk.EW)
tk.Entry(list_frame1, textvariable= self.save_path, font=font_list)\
.grid(row=0,column=1,sticky=tk.EW)
tk.Button(list_frame1, text = "Select Path", font=font_list,
command = self.Select_Target, bg='#F4A460', fg='#F5F5F5')\
.grid(row=0,column=2,sticky=tk.EW)
#################################################################################################
label_frame = tk.Frame(self.master)
label_frame.pack(fill="both", padx=5,pady=5)
label_frame.columnconfigure(0, weight=1)
label_frame.columnconfigure(1, weight=1)
label_frame.columnconfigure(2, weight=1)
tk.Label(label_frame, text="Crop Size:",font=font_list,justify="left")\
.grid(row=0,column=0,sticky=tk.EW)
tk.Label(label_frame, text="Align Mode:",font=font_list,justify="left")\
.grid(row=0,column=1,sticky=tk.EW)
tk.Label(label_frame, text="Target Format:",font=font_list,justify="left")\
.grid(row=0,column=2,sticky=tk.EW)
#################################################################################################
test_frame = tk.Frame(self.master)
test_frame.pack(fill="both", padx=5,pady=5)
test_frame.columnconfigure(0, weight=1)
test_frame.columnconfigure(1, weight=1)
test_frame.columnconfigure(2, weight=1)
self.test_var = tkinter.StringVar()
self.test_com = ttk.Combobox(test_frame, textvariable=self.test_var)
self.test_com.grid(row=0,column=0,sticky=tk.EW)
self.test_com["value"] = [256,512,768,1024]
self.test_com.current(1)
self.align_var = tkinter.StringVar()
self.align_com = ttk.Combobox(test_frame, textvariable=self.align_var)
self.align_com.grid(row=0,column=1,sticky=tk.EW)
self.align_com["value"] = ["VGGFace","ffhq"]
self.align_com.current(0)
self.format_var = tkinter.StringVar()
self.format_com = ttk.Combobox(test_frame, textvariable=self.format_var)
self.format_com.grid(row=0,column=2,sticky=tk.EW)
self.format_com["value"] = ["png","jpg"]
self.format_com.current(0)
#################################################################################################
scale_frame = tk.Frame(self.master)
scale_frame.pack(fill="both", padx=5,pady=5)
scale_frame.columnconfigure(0, weight=2)
label_frame.columnconfigure(1, weight=1)
# label_frame.columnconfigure(2, weight=1)
tk.Label(scale_frame, text="Min Size:",font=font_list,justify="left")\
.grid(row=0,column=0,sticky=tk.EW)
self.min_scale = tkinter.StringVar()
tk.Scale(scale_frame, from_=0.5, to=2.0, length=500, orient=tk.HORIZONTAL, variable= self.min_scale,\
font=font_list, resolution=0.1).grid(row=0,column=1,sticky=tk.EW)
#################################################################################################
test_frame1 = tk.Frame(self.master)
test_frame1.pack(fill="both", padx=5,pady=5)
test_frame1.columnconfigure(0, weight=1)
# test_frame1.columnconfigure(1, weight=1)
test_update_button = tk.Button(test_frame1, text = "Crop",
font=font_list, command = self.Crop, bg='#F4A460', fg='#F5F5F5')
test_update_button.grid(row=0,column=0,sticky=tk.EW)
#################################################################################################
text = tk.Text(self.master, wrap="word")
text.pack(fill="both",expand="yes", padx=5,pady=5)
sys.stdout = TextRedirector(text, "stdout")
self.init_algorithm()
self.master.protocol("WM_DELETE_WINDOW", self.on_closing)
def init_algorithm(self):
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
# def __scaning_logs__(self):
def Select(self):
thread_update = threading.Thread(target=self.select_task)
thread_update.start()
def select_task(self):
path = askdirectory()
print("Selected source directory: %s"%path)
self.img_path.set(path)
def Select_Target(self):
thread_update = threading.Thread(target=self.select_target_task)
thread_update.start()
def select_target_task(self):
path = askdirectory()
print("Selected target directory: %s"%path)
self.save_path.set(path)
def Crop(self):
thread_update = threading.Thread(target=self.crop_task)
thread_update.start()
def crop_task(self):
mode = self.align_com.get()
crop_size = int(self.test_com.get())
path = self.img_path.get()
tg_path = self.save_path.get()
if not os.path.exists(tg_path):
os.makedirs(tg_path)
tg_format = self.format_com.get()
min_scale = float(self.min_scale.get())
blur_t = 10.0
font = cv2.FONT_HERSHEY_SIMPLEX
self.detect.prepare(ctx_id = 0, det_thresh=0.6,\
det_size=(640,640),mode = mode,crop_size=crop_size,ratio=min_scale)
if path and tg_path:
imgs_list = []
if os.path.isdir(path):
print("Input a dir....")
imgs = glob.glob(os.path.join(path,"*"))
for item in imgs:
imgs_list.append(item)
# print(imgs_list)
index = 0
for img in imgs_list:
print(img)
attr_img_ori= cv2.imread(img)
try:
attr_img_align_crop, _ = self.detect.get(attr_img_ori)
sub_index = 0
if len(attr_img_align_crop) < 1:
print("Small face")
for face_i in attr_img_align_crop:
imageVar = cv2.Laplacian(face_i, cv2.CV_64F).var()
f_path =os.path.join(tg_path, str(index).zfill(6)+"_%d.%s"%(sub_index,tg_format))
if imageVar < blur_t:
print("Over blurry image!")
continue
# face_i = cv2.putText(face_i, '%.1f'%imageVar,(50, 50), font, 0.8, (15, 9, 255), 2)
cv2.imwrite(f_path,face_i)
sub_index += 1
index += 1
except:
print("Detect no face!")
continue
else:
print("Input an image....")
imgs_list.append(path)
print("Process finished!")
else:
print("Pathes are invalid!")
def on_closing(self):
    """WM_DELETE_WINDOW handler: tear down the Tk root window."""
    # self.__save_config__()  # config persistence intentionally disabled
    self.master.destroy()
if __name__ == "__main__":
    # Build the GUI and hand control to the Tk event loop.
    gui = Application()
    gui.mainloop()
+3 -3
View File
@@ -5,7 +5,7 @@
# Created Date: Tuesday February 1st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 2nd February 2022 11:17:04 pm
# Last Modified: Friday, 15th April 2022 10:07:15 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -27,7 +27,7 @@ def getParameters():
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--save_path', type=str, default="./output/",
help="The root path for saving cropped images")
parser.add_argument('-v', '--video', type=str, default="G:\\4K\\05.mp4",
parser.add_argument('-v', '--video', type=str, default="G:\\4K\\Faith.Makes.Great.2021\\40.mp4",
help="The path for input video")
parser.add_argument('-c', '--crop_size', type=int, default=512,
help="expected image resolution")
@@ -39,7 +39,7 @@ def getParameters():
choices=['jpg', 'png'],help="target file format")
parser.add_argument('-i', '--interval', type=int, default=20,
help="number of frames interval")
parser.add_argument('-b', '--blur', type=float, default=10.0,
parser.add_argument('-b', '--blur', type=float, default=20.0,
help="blur degree")
return parser.parse_args()
@@ -0,0 +1,7 @@
# Pre-trained Models and Other Data
Download pre-trained models and other data. Put them in this folder.
1. [Pretrained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
+7
View File
@@ -0,0 +1,7 @@
# flake8: noqa
from .archs import *
from .data import *
from .models import *
from .utils import *
# from .version import *
+10
View File
@@ -0,0 +1,10 @@
import importlib
from basicsr.utils import scandir
from os import path as osp
# automatically scan and import arch modules for registry
# scan all the files that end with '_arch.py' under the archs folder
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
# (importing each module runs its @ARCH_REGISTRY.register() decorators, so the
# architectures become available in basicsr's registry as a side effect)
_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]
+245
View File
@@ -0,0 +1,245 @@
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
def conv3x3(inplanes, outplanes, stride=1):
    """A simple wrapper for 3x3 convolution with padding.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
    """
    return nn.Conv2d(
        inplanes,
        outplanes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Basic residual block used in the ResNetArcFace architecture.

    Two 3x3 conv + batch-norm stages with a ReLU between them, plus an
    identity (or downsampled) shortcut added before the final ReLU.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """

    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch: identity unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + shortcut)
class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv -> BN, optionally followed by a
    squeeze-and-excitation gate, with a shortcut added before the final PReLU.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """

    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        # Shortcut branch: identity unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.bn0(x)
        out = self.prelu(self.bn1(self.conv1(out)))
        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se(out)
        return self.prelu(out + shortcut)
class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    1x1 reduce -> 3x3 -> 1x1 expand (by ``expansion``), each followed by batch
    norm, with a shortcut connection added before the final ReLU.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """

    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch: identity unless a downsample module was supplied
        # (required whenever stride != 1 or channel counts differ).
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)
class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        # Squeeze: global average pool to one scalar per channel.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP producing a per-channel gate in (0, 1).
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
            nn.Sigmoid())

    def forward(self, x):
        batch, chans = x.size(0), x.size(1)
        gate = self.avg_pool(x).view(batch, chans)
        gate = self.fc(gate).view(batch, chans, 1, 1)
        # Re-weight each input channel by its learned gate.
        return x * gate
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    Args:
        block (str): Block used in the ArcFace architecture.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        # Map the string name onto the actual block class.
        if block == 'IRBlock':
            block = IRBlock
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()
        # Stem: single-channel (grayscale) input, then 2x max-pool.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Four residual stages; spatial size halves from layer2 onwards.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        # Head assumes an 8x8 final feature map — TODO confirm 128x128 inputs.
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)
        # Weight initialization.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` residual blocks; the first one may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut to match spatial size and channel count.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        stage = [block(self.inplanes, planes, stride, downsample, use_se=self.use_se)]
        self.inplanes = planes
        stage.extend(block(self.inplanes, planes, use_se=self.use_se) for _ in range(1, num_blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        feat = self.maxpool(self.prelu(self.bn1(self.conv1(x))))
        feat = self.layer1(feat)
        feat = self.layer2(feat)
        feat = self.layer3(feat)
        feat = self.layer4(feat)
        feat = self.dropout(self.bn4(feat))
        feat = feat.view(feat.size(0), -1)
        return self.bn5(self.fc5(feat))
@@ -0,0 +1,312 @@
import math
import random
import torch
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from .gfpganv1_arch import ResUpBlock
from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
StyleGAN2GeneratorBilinear)
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
    It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """
    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorBilinearSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow)
        # When True, SFT conditions modulate only the second half of the
        # channels; the first half passes through unchanged (see forward()).
        self.sft_half = sft_half
    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorBilinearSFT.
        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        # NOTE(review): 'latent' is assigned only for one or two style codes;
        # more than two would raise NameError below — callers appear to pass
        # at most two (style mixing). Confirm upstream.
        if len(styles) == 1:
            inject_index = self.num_latent
            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                # conditions come in (scale, shift) pairs:
                # scale = conditions[i - 1], shift = conditions[i]
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2
        image = skip
        if return_latents:
            return image, latent
        else:
            return image, None
@ARCH_REGISTRY.register()
class GFPGANBilinear(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
    It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: GFPGANv1Clean.
    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """
    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):
        super(GFPGANBilinear, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat
        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # Channel width per feature-map resolution (keys are spatial sizes).
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }
        self.log_size = int(math.log(out_size, 2))
        # out_size is expected to be a power of two, so this equals out_size.
        first_out_size = 2**(int(math.log(out_size, 2)))
        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels))
            in_channels = out_channels
        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels
        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
        # different_w: one w per decoder layer; otherwise a single shared w.
        if different_w:
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat
        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)
        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False
        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            # The scale branch's last conv is bias-initialized to 1 — presumably
            # so the SFT modulation starts near identity; verify against paper.
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANBilinear.
        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []
        # encoder
        feat = self.conv_body_first(x)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            # insert at the front so skips are stored deepest-resolution last,
            # matching the decoder's shallow-to-deep iteration order
            unet_skips.insert(0, feat)
        feat = self.final_conv(feat)
        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))
        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)
        return image, out_rgbs
+439
View File
@@ -0,0 +1,439 @@
import math
import random
import torch
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
StyleGAN2Generator)
from basicsr.ops.fused_act import FusedLeakyReLU
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
            applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """
    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 resample_kernel=(1, 3, 3, 1),
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            resample_kernel=resample_kernel,
            lr_mlp=lr_mlp,
            narrow=narrow)
        # When True, SFT conditions modulate only the second half of the
        # channels; the first half passes through unchanged (see forward()).
        self.sft_half = sft_half
    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorSFT.
        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        # NOTE(review): 'latent' is assigned only for one or two style codes;
        # more than two would raise NameError below — callers appear to pass
        # at most two (style mixing). Confirm upstream.
        if len(styles) == 1:
            inject_index = self.num_latent
            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                # conditions come in (scale, shift) pairs:
                # scale = conditions[i - 1], shift = conditions[i]
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2
        image = skip
        if return_latents:
            return image, latent
        else:
            return image, None
class ConvUpLayer(nn.Module):
    """Convolutional upsampling layer. It uses bilinear upsampler + Conv.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1
        padding (int): Zero-padding added to both sides of the input. Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        activate (bool): Whether use activation. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=True,
                 bias_init_val=0,
                 activate=True):
        super(ConvUpLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # self.scale is used to scale the convolution weights, which is related to the common initializations.
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        # A standalone bias is only used when there is no fused activation
        # (FusedLeakyReLU carries its own bias term).
        if bias and not activate:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)
        # activation
        if not activate:
            self.activation = None
        elif bias:
            self.activation = FusedLeakyReLU(out_channels)
        else:
            self.activation = ScaledLeakyReLU(0.2)

    def forward(self, x):
        # 2x bilinear upsample, then a convolution with equalized-lr scaling.
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        out = F.conv2d(
            upsampled,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )
        if self.activation is not None:
            out = self.activation(out)
        return out
class ResUpBlock(nn.Module):
    """Residual block with upsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(ResUpBlock, self).__init__()
        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
        # 1x1 upsampling shortcut so the residual matches the main branch.
        self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)

    def forward(self, x):
        main = self.conv2(self.conv1(x))
        shortcut = self.skip(x)
        # Scaled residual sum as in StyleGAN2 (keeps output variance stable).
        return (main + shortcut) / math.sqrt(2)
@ARCH_REGISTRY.register()
class GFPGANv1(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
            applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """
    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):
        super(GFPGANv1, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat
        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # Channel width per feature-map resolution (keys are spatial sizes).
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }
        self.log_size = int(math.log(out_size, 2))
        # out_size is expected to be a power of two, so this equals out_size.
        first_out_size = 2**(int(math.log(out_size, 2)))
        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
            in_channels = out_channels
        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels
        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
        # different_w: one w per decoder layer; otherwise a single shared w.
        if different_w:
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat
        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            resample_kernel=resample_kernel,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)
        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False
        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            # The scale branch's last conv is bias-initialized to 1 — presumably
            # so the SFT modulation starts near identity; verify against paper.
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANv1.
        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []
        # encoder
        feat = self.conv_body_first(x)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            # insert at the front so skips are stored deepest-resolution last,
            # matching the decoder's shallow-to-deep iteration order
            unet_skips.insert(0, feat)
        feat = self.final_conv(feat)
        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))
        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)
        return image, out_rgbs
@ARCH_REGISTRY.register()
class FacialComponentDiscriminator(nn.Module):
    """Facial component (eyes, mouth, noise) discriminator used in GFPGAN."""

    def __init__(self):
        super(FacialComponentDiscriminator, self).__init__()
        # VGG-style backbone with a fixed model size.
        blur_kernel = (1, 3, 3, 1)
        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=blur_kernel, bias=True, activate=True)
        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=blur_kernel, bias=True, activate=True)
        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=blur_kernel, bias=True, activate=True)
        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=blur_kernel, bias=True, activate=True)
        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=blur_kernel, bias=True, activate=True)
        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)

    def forward(self, x, return_feats=False):
        """Forward function for FacialComponentDiscriminator.

        Args:
            x (Tensor): Input images.
            return_feats (bool): Whether to also return intermediate features. Default: False.
        """
        intermediate = []
        feat = self.conv3(self.conv2(self.conv1(x)))
        if return_feats:
            intermediate.append(feat.clone())
        feat = self.conv5(self.conv4(feat))
        if return_feats:
            intermediate.append(feat.clone())
        out = self.final_conv(feat)
        return (out, intermediate) if return_feats else (out, None)
# ---- diff hunk marker (@@ -0,0 +1,324 @@): a new module's content begins below ----
import math
import random
import torch
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
from .stylegan2_clean_arch import StyleGAN2GeneratorClean
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super(StyleGAN2GeneratorCSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)
        # when True, SFT modulates only the second half of the feature channels
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorCSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators; ordered as
                alternating (scale, shift) pairs, one pair per resolution level.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate each style toward the truncation latent
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent
            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])
        # each loop iteration consumes two style convs, two noises and one to_rgb
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    # conditions[i - 1] is the scale, conditions[i] is the shift
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2
        image = skip
        if return_latents:
            return image, latent
        else:
            return image, None
class ResBlock(nn.Module):
    """Residual block with bilinear upsampling/downsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.

    Raises:
        ValueError: If ``mode`` is neither 'down' nor 'up'.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        # 1x1 bias-free conv so the skip path matches the output channel count
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2
        else:
            # Fail fast: previously an unknown mode left scale_factor unset and
            # only crashed with AttributeError on the first forward pass.
            raise ValueError(f"Wrong mode value in ResBlock: {mode}. Supported ones are: ['down', 'up'].")

    def forward(self, x):
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # resample both the residual and the skip branches by the same factor
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        return out + skip
@ARCH_REGISTRY.register()
class GFPGANv1Clean(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):
        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # channel numbers for each resolution level of the U-Net
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # to RGB: one 1x1 projection per upsample level
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))

        if different_w:
            # one w per decoder layer (num_latent = log2(out_size) * 2 - 2)
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                # SFT is applied to half of the decoder channels only
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANv1Clean.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            # insert at the front so skips end up ordered deepest-first
            unet_skips.insert(0, feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            # clone() decouples the stored conditions from later ops on feat
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
# ---- diff hunk marker (@@ -0,0 +1,613 @@): a new module's content begins below ----
import math
import random
import torch
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class NormStyleCode(nn.Module):
    """Normalize style codes to unit RMS along the channel dimension."""

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        mean_square = torch.mean(x**2, dim=1, keepdim=True)
        # epsilon keeps the rsqrt finite for an all-zero code
        return x * torch.rsqrt(mean_square + 1e-8)
class EqualLinear(nn.Module):
    """Equalized Linear as StyleGAN2.

    The weight is stored divided by ``lr_mul`` and rescaled at run time by
    ``scale`` (which folds ``lr_mul`` back in), giving an equalized learning rate.

    Args:
        in_channels (int): Size of each sample.
        out_channels (int): Size of each output sample.
        bias (bool): If set to ``False``, the layer will not learn an additive
            bias. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        lr_mul (float): Learning rate multiplier. Default: 1.
        activation (None | str): The activation after ``linear`` operation.
            Supported: 'fused_lrelu', None. Default: None.

    Raises:
        ValueError: If ``activation`` is not 'fused_lrelu' or None.
    """

    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
        super(EqualLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lr_mul = lr_mul
        self.activation = activation
        if self.activation not in ['fused_lrelu', None]:
            # Bug fix: the two message fragments used to be concatenated with no
            # separator, producing e.g. "...: reluSupported ones are...".
            raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
                             "Supported ones are: ['fused_lrelu', None].")
        self.scale = (1 / math.sqrt(in_channels)) * lr_mul

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        if self.bias is None:
            bias = None
        else:
            # undo the lr_mul division applied at initialization
            bias = self.bias * self.lr_mul
        if self.activation == 'fused_lrelu':
            out = F.linear(x, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(x, self.weight * self.scale, bias=bias)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, bias={self.bias is not None})')
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
        interpolation_mode (str): Interpolation mode used when resampling.
            Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8,
                 interpolation_mode='bilinear'):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps
        self.interpolation_mode = interpolation_mode
        if self.interpolation_mode == 'nearest':
            # 'nearest' interpolation does not accept align_corners
            self.align_corners = None
        else:
            self.align_corners = False

        # equalized learning-rate scale applied at run time
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
        # modulation inside each modulated conv
        self.modulation = EqualLinear(
            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)

        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            # per-sample, per-filter normalization as in StyleGAN2
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)

        b, c, h, w = x.shape
        # fold the batch into channels so a grouped conv applies one modulated
        # filter bank per sample
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
class StyleConv(nn.Module):
    """Style conv: modulated convolution, noise injection, fused activation.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        interpolation_mode (str): Interpolation mode for resampling. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 interpolation_mode='bilinear'):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
            interpolation_mode=interpolation_mode)
        # learnable scalar controlling the injected noise strength
        self.weight = nn.Parameter(torch.zeros(1))
        self.activate = FusedLeakyReLU(out_channels)

    def forward(self, x, style, noise=None):
        # modulated convolution
        feat = self.modulated_conv(x, style)
        # draw fresh per-pixel noise when none is supplied
        if noise is None:
            batch, _, height, width = feat.shape
            noise = feat.new_empty(batch, 1, height, width).normal_()
        feat = feat + self.weight * noise
        # activation (the bias is applied inside FusedLeakyReLU)
        return self.activate(feat)
class ToRGB(nn.Module):
    """To RGB from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample the skip branch. Default: True.
        interpolation_mode (str): Interpolation mode for upsampling. Default: 'bilinear'.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.interpolation_mode = interpolation_mode
        # 'nearest' interpolation does not accept align_corners
        self.align_corners = None if interpolation_mode == 'nearest' else False
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Project features to RGB and accumulate onto the skip image.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb
        if self.upsample:
            skip = F.interpolate(
                skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
        return rgb + skip
class ConstantInput(nn.Module):
    """Learned constant input feature map.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        # replicate the single learned map across the requested batch size
        return self.weight.repeat(batch, 1, 1, 1)
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorBilinear(nn.Module):
    """StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): Narrow ratio for channels. Default: 1.0.
        interpolation_mode (str): Interpolation mode used by the resampling
            layers. Default: 'bilinear'.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 interpolation_mode='bilinear'):
        super(StyleGAN2GeneratorBilinear, self).__init__()
        # Style MLP layers: z -> w mapping network
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.append(
                EqualLinear(
                    num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
                    activation='fused_lrelu'))
        self.style_mlp = nn.Sequential(*style_mlp_layers)

        # channel numbers for each resolution level
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)

        self.log_size = int(math.log(out_size, 2))
        # one initial conv plus two convs per upsample level
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise: one fixed buffer per style conv layer, sized for its resolution
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample',
                    interpolation_mode=interpolation_mode))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                    interpolation_mode=interpolation_mode))
            self.to_rgbs.append(
                ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        """Map style code(s) ``x`` through the style MLP."""
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        """Estimate the mean latent by averaging ``num_latent`` random samples."""
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2Generator.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor.
                Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate each style toward the truncation latent
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # each iteration consumes two style convs, two noises and one to_rgb
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
class ScaledLeakyReLU(nn.Module):
    """LeakyReLU rescaled by sqrt(2).

    Args:
        negative_slope (float): Negative slope. Default: 0.2.
    """

    def __init__(self, negative_slope=0.2):
        super(ScaledLeakyReLU, self).__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        # sqrt(2) gain restores the signal magnitude after the leaky rectification
        return math.sqrt(2) * F.leaky_relu(x, negative_slope=self.negative_slope)
class EqualConv2d(nn.Module):
    """Equalized 2D convolution as in StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1
        padding (int): Zero-padding added to both sides of the input.
            Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output.
            Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
        super(EqualConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # equalized learning rate: He scaling applied at run time, not at init
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        return F.conv2d(
            x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size},'
                f' stride={self.stride}, padding={self.padding}, '
                f'bias={self.bias is not None})')
class ConvLayer(nn.Sequential):
    """Conv Layer used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Kernel size.
        downsample (bool): Whether downsample by a factor of 2.
            Default: False.
        bias (bool): Whether with bias. Default: True.
        activate (bool): Whether use activation. Default: True.
        interpolation_mode (str): Interpolation mode used for downsampling.
            Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 bias=True,
                 activate=True,
                 interpolation_mode='bilinear'):
        layers = []
        self.interpolation_mode = interpolation_mode
        # downsample
        if downsample:
            if self.interpolation_mode == 'nearest':
                # 'nearest' interpolation does not accept align_corners
                self.align_corners = None
            else:
                self.align_corners = False
            layers.append(
                torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
        stride = 1
        self.padding = kernel_size // 2
        # conv
        # Note: when both `bias` and `activate` are set, the conv runs bias-free
        # (`bias and not activate` is False) and the bias is instead applied
        # inside FusedLeakyReLU below.
        layers.append(
            EqualConv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
                and not activate))
        # activation
        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channels))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super(ConvLayer, self).__init__(*layers)
class ResBlock(nn.Module):
    """Residual block used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        interpolation_mode (str): Interpolation mode for downsampling. Default: 'bilinear'.
    """

    def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
        super(ResBlock, self).__init__()
        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        # main branch: conv then downsample to the output channel count
        self.conv2 = ConvLayer(
            in_channels,
            out_channels,
            3,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=True,
            activate=True)
        # shortcut: 1x1 downsampling conv, no bias, no activation
        self.skip = ConvLayer(
            in_channels,
            out_channels,
            1,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=False,
            activate=False)

    def forward(self, x):
        residual = self.conv2(self.conv1(x))
        shortcut = self.skip(x)
        # divide by sqrt(2) so the summed activations keep their variance
        return (residual + shortcut) / math.sqrt(2)
# ---- diff hunk marker (@@ -0,0 +1,368 @@): a new module's content begins below ----
import math
import random
import torch
from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class NormStyleCode(nn.Module):
    """Normalize the style codes."""

    def forward(self, x):
        """Normalize style codes ``x`` of shape (b, c) to unit RMS.

        Returns:
            Tensor: Normalized tensor.
        """
        mean_sq = torch.mean(x**2, dim=1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + 1e-8)  # epsilon guards the zero vector
        return x * inv_rms
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        # unlike the equalized-lr variant, the fan-in scaling here is folded
        # into the parameter at initialization instead of applied at run time
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            # per-sample, per-filter normalization as in StyleGAN2
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        # upsample or downsample if necessary
        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        # fold the batch into channels so a grouped conv applies one modulated
        # filter bank per sample
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
class StyleConv(nn.Module):
    """Style-modulated conv block used in StyleGAN2.

    Wraps a ModulatedConv2d with noise injection, a learnable bias and a
    leaky-ReLU activation.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
        # learnable scale for the injected noise
        self.weight = nn.Parameter(torch.zeros(1))
        # per-channel bias added after noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        """Modulated conv -> noise injection -> bias -> leaky ReLU."""
        out = self.modulated_conv(x, style) * 2**0.5  # for conversion
        if noise is None:
            # fresh per-sample Gaussian noise, shared across channels
            batch, _, height, width = out.shape
            noise = out.new_empty(batch, 1, height, width).normal_()
        return self.activate(out + self.weight * noise + self.bias)
class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample the skip image. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        # 1x1 modulated conv (no demodulation) projecting features to 3 channels
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Project features to RGB and accumulate onto the skip image.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb
        if self.upsample:
            skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        return rgb + skip
class ConstantInput(nn.Module):
    """Learned constant input tensor (the root of the synthesis network).

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        # a single learnable feature map, broadcast over the batch at forward time
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        """Return the constant tensor repeated `batch` times along dim 0."""
        return self.weight.repeat(batch, 1, 1, 1)
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator (no custom CUDA ops).

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers: normalize the z code, then num_mlp Linear+LeakyReLU pairs
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [nn.Linear(num_style_feat, num_style_feat, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
        # channel list: feature width per resolution, scaled by `narrow`
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels
        # learned constant input at 4x4, plus the first conv / ToRGB pair
        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
        self.log_size = int(math.log(out_size, 2))
        # one conv at 4x4 plus two convs per resolution doubling
        self.num_layers = (self.log_size - 2) * 2 + 1
        # number of w latents the synthesis network consumes
        self.num_latent = self.log_size * 2 - 2
        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()
        in_channels = channels['4']
        # noise: fixed per-layer noise buffers, used when randomize_noise=False
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs: per resolution an upsampling conv, a plain
        # conv and a ToRGB head
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make fresh noise tensors (one 4x4 plus two per resolution) for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]
        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
        return noises

    def get_latent(self, x):
        """Map a style code to the latent space via the style MLP."""
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        """Return the mean latent over `num_latent` random style codes (for truncation)."""
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
                NOTE(review): must be provided when truncation < 1, otherwise this raises.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple: (image, latent) when return_latents is True, else (image, None).
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate each style towards the truncation latent
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        # NOTE(review): `latent` is only assigned for 1 or 2 styles; more
        # styles would raise NameError below — callers pass at most two.
        if len(styles) == 1:
            inject_index = self.num_latent
            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)
        # main generation: constant input, then alternating conv/conv/ToRGB
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2
        image = skip
        if return_latents:
            return image, latent
        else:
            return image, None
+10
View File
@@ -0,0 +1,10 @@
import importlib
from basicsr.utils import scandir
from os import path as osp

# Auto-discover every '*_dataset.py' sibling module and import it, so the
# dataset classes defined there register themselves with DATASET_REGISTRY.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = []
for v in scandir(data_folder):
    if v.endswith('_dataset.py'):
        dataset_filenames.append(osp.splitext(osp.basename(v))[0])
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
@@ -0,0 +1,230 @@
import cv2
import math
import numpy as np
import os.path as osp
import torch
import torch.utils.data as data
from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
    """FFHQ dataset for GFPGAN.

    It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
        Please see more options in the codes.
    """

    def __init__(self, opt):
        super(FFHQDegradationDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__ so it also
        # works inside dataloader worker processes
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']
        self.out_size = opt['out_size']
        self.crop_components = opt.get('crop_components', False)  # facial components
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether enlarge eye regions
        if self.crop_components:
            # load component list from a pre-process pth files
            self.components_list = torch.load(opt.get('component_path'))
        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            # lmdb keys are the file stems listed in meta_info.txt
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend: scan file list from a folder
            self.paths = paths_from_folder(self.gt_folder)
        # degradation configurations
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.downsample_range = opt['downsample_range']
        self.noise_range = opt['noise_range']
        self.jpeg_range = opt['jpeg_range']
        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob')
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        # to gray
        self.gray_prob = opt.get('gray_prob')
        logger = get_root_logger()
        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        # shift is configured on the 0-255 scale; images here are float in [0, 1]
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
        # apply the four jitters in a random order, each with a random factor
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)
            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)
            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)
            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_coordinates(self, index, status):
        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
        # components are keyed by the zero-padded image index
        components_bbox = self.components_list[f'{index:08d}']
        if status[0]:  # hflip
            # exchange right and left eye
            tmp = components_bbox['left_eye']
            components_bbox['left_eye'] = components_bbox['right_eye']
            components_bbox['right_eye'] = tmp
            # modify the width coordinate to mirror the box centers
            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
        # get coordinates: each entry is (center_x, center_y, half_len)
        locations = []
        for part in ['left_eye', 'right_eye', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations.append(loc)
        return locations

    def __getitem__(self, index):
        # NOTE: the degradation steps below consume random numbers in a fixed
        # order; do not reorder them or reproducibility under a seed changes.
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
        # load gt image
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)
        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
        h, w, _ = img_gt.shape
        # get facial component coordinates
        if self.crop_components:
            locations = self.get_component_coordinates(index, status)
            loc_left_eye, loc_right_eye, loc_mouth = locations
        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = degradations.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
            if self.opt.get('gt_gray'):  # whether convert GT to gray images
                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
        # round and clip, emulating an 8-bit image
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        normalize(img_lq, self.mean, self.std, inplace=True)
        if self.crop_components:
            return_dict = {
                'lq': img_lq,
                'gt': img_gt,
                'gt_path': gt_path,
                'loc_left_eye': loc_left_eye,
                'loc_right_eye': loc_right_eye,
                'loc_mouth': loc_mouth
            }
            return return_dict
        else:
            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
+10
View File
@@ -0,0 +1,10 @@
import importlib
from basicsr.utils import scandir
from os import path as osp

# Auto-discover every '*_model.py' sibling module and import it, so the model
# classes defined there register themselves with MODEL_REGISTRY.
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = []
for v in scandir(model_folder):
    if v.endswith('_model.py'):
        model_filenames.append(osp.splitext(osp.basename(v))[0])
# import all the model modules
_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
+579
View File
@@ -0,0 +1,579 @@
import math
import os.path as osp
import torch
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.losses import r1_penalty
from basicsr.metrics import calculate_metric
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F
from torchvision.ops import roi_align
from tqdm import tqdm
@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
"""The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
def __init__(self, opt):
super(GFPGANModel, self).__init__(opt)
self.idx = 0 # it is used for saving data for check
# define network
self.net_g = build_network(opt['network_g'])
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_g', 'params')
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
if self.is_train:
self.init_training_settings()
    def init_training_settings(self):
        """Build discriminators, the EMA generator, all losses and the optimizers.

        Called from __init__ only when self.is_train is True.
        """
        train_opt = self.opt['train']
        # ----------- define net_d ----------- #
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
        # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
        # net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
        self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
        else:
            self.model_ema(0)  # copy net_g weight
        self.net_g.train()
        self.net_d.train()
        self.net_g_ema.eval()
        # ----------- facial component networks ----------- #
        # facial discriminators are used only when all three are configured
        if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
            self.use_facial_disc = True
        else:
            self.use_facial_disc = False
        if self.use_facial_disc:
            # left eye
            self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
            self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
            self.print_network(self.net_d_left_eye)
            load_path = self.opt['path'].get('pretrain_network_d_left_eye')
            if load_path is not None:
                self.load_network(self.net_d_left_eye, load_path, True, 'params')
            # right eye
            self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
            self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
            self.print_network(self.net_d_right_eye)
            load_path = self.opt['path'].get('pretrain_network_d_right_eye')
            if load_path is not None:
                self.load_network(self.net_d_right_eye, load_path, True, 'params')
            # mouth
            self.net_d_mouth = build_network(self.opt['network_d_mouth'])
            self.net_d_mouth = self.model_to_device(self.net_d_mouth)
            self.print_network(self.net_d_mouth)
            load_path = self.opt['path'].get('pretrain_network_d_mouth')
            if load_path is not None:
                self.load_network(self.net_d_mouth, load_path, True, 'params')
            self.net_d_left_eye.train()
            self.net_d_right_eye.train()
            self.net_d_mouth.train()
            # ----------- define facial component gan loss ----------- #
            self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
        # ----------- define losses ----------- #
        # pixel loss
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None
        # perceptual loss
        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None
        # L1 loss is used in pyramid loss, component style loss and identity loss
        self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
        # gan loss (wgan)
        self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
        # ----------- define identity loss ----------- #
        if 'network_identity' in self.opt:
            self.use_identity = True
        else:
            self.use_identity = False
        if self.use_identity:
            # define identity network; frozen, used only as a feature extractor
            self.network_identity = build_network(self.opt['network_identity'])
            self.network_identity = self.model_to_device(self.network_identity)
            self.print_network(self.network_identity)
            load_path = self.opt['path'].get('pretrain_network_identity')
            if load_path is not None:
                self.load_network(self.network_identity, load_path, True, None)
            self.network_identity.eval()
            for param in self.network_identity.parameters():
                param.requires_grad = False
        # regularization weights
        self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
        self.net_d_reg_every = train_opt['net_d_reg_every']
        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()
def setup_optimizers(self):
train_opt = self.opt['train']
# ----------- optimizer g ----------- #
net_g_reg_ratio = 1
normal_params = []
for _, param in self.net_g.named_parameters():
normal_params.append(param)
optim_params_g = [{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_g']['lr']
}]
optim_type = train_opt['optim_g'].pop('type')
lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
self.optimizers.append(self.optimizer_g)
# ----------- optimizer d ----------- #
net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
normal_params = []
for _, param in self.net_d.named_parameters():
normal_params.append(param)
optim_params_d = [{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_d']['lr']
}]
optim_type = train_opt['optim_d'].pop('type')
lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
self.optimizers.append(self.optimizer_d)
# ----------- optimizers for facial component networks ----------- #
if self.use_facial_disc:
# setup optimizers for facial component discriminators
optim_type = train_opt['optim_component'].pop('type')
lr = train_opt['optim_component']['lr']
# left eye
self.optimizer_d_left_eye = self.get_optimizer(
optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
self.optimizers.append(self.optimizer_d_left_eye)
# right eye
self.optimizer_d_right_eye = self.get_optimizer(
optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
self.optimizers.append(self.optimizer_d_right_eye)
# mouth
self.optimizer_d_mouth = self.get_optimizer(
optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
self.optimizers.append(self.optimizer_d_mouth)
def feed_data(self, data):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
if 'loc_left_eye' in data:
# get facial component locations, shape (batch, 4)
self.loc_left_eyes = data['loc_left_eye']
self.loc_right_eyes = data['loc_right_eye']
self.loc_mouths = data['loc_mouth']
# uncomment to check data
# import torchvision
# if self.opt['rank'] == 0:
# import os
# os.makedirs('tmp/gt', exist_ok=True)
# os.makedirs('tmp/lq', exist_ok=True)
# print(self.idx)
# torchvision.utils.save_image(
# self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
# torchvision.utils.save_image(
# self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
# self.idx = self.idx + 1
def construct_img_pyramid(self):
"""Construct image pyramid for intermediate restoration loss"""
pyramid_gt = [self.gt]
down_img = self.gt
for _ in range(0, self.log_size - 3):
down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
pyramid_gt.insert(0, down_img)
return pyramid_gt
    def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
        """Crop eye/mouth patches from both GT and restored output via roi_align.

        Args:
            eye_out_size (int): Eye patch size at the 512 reference scale. Default: 80.
            mouth_out_size (int): Mouth patch size at the 512 reference scale. Default: 120.

        Sets self.left_eyes_gt / self.right_eyes_gt / self.mouths_gt (from
        self.gt) and self.left_eyes / self.right_eyes / self.mouths (from
        self.output), using the locations stored by feed_data.
        """
        # scale patch sizes for generators larger than 512
        face_ratio = int(self.opt['network_g']['out_size'] / 512)
        eye_out_size *= face_ratio
        mouth_out_size *= face_ratio
        rois_eyes = []
        rois_mouths = []
        for b in range(self.loc_left_eyes.size(0)):  # loop for batch size
            # left eye and right eye; first roi column is the batch index
            img_inds = self.loc_left_eyes.new_full((2, 1), b)
            bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0)  # shape: (2, 4)
            rois = torch.cat([img_inds, bbox], dim=-1)  # shape: (2, 5)
            rois_eyes.append(rois)
            # mouth
            img_inds = self.loc_left_eyes.new_full((1, 1), b)
            rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1)  # shape: (1, 5)
            rois_mouths.append(rois)
        rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
        rois_mouths = torch.cat(rois_mouths, 0).to(self.device)
        # real images; even rows are left eyes, odd rows are right eyes
        all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
        self.left_eyes_gt = all_eyes[0::2, :, :, :]
        self.right_eyes_gt = all_eyes[1::2, :, :, :]
        self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
        # output (restored) images, cropped at the same locations
        all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
        self.left_eyes = all_eyes[0::2, :, :, :]
        self.right_eyes = all_eyes[1::2, :, :, :]
        self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).
Returns:
torch.Tensor: Gram matrix.
"""
n, c, h, w = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram
def gray_resize_for_identity(self, out, size=128):
out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
out_gray = out_gray.unsqueeze(1)
out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
return out_gray
def optimize_parameters(self, current_iter):
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
# do not update facial component net_d
if self.use_facial_disc:
for p in self.net_d_left_eye.parameters():
p.requires_grad = False
for p in self.net_d_right_eye.parameters():
p.requires_grad = False
for p in self.net_d_mouth.parameters():
p.requires_grad = False
# image pyramid loss weight
pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0)
if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')):
pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error
if pyramid_loss_weight > 0:
self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
pyramid_gt = self.construct_img_pyramid()
else:
self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)
# get roi-align regions
if self.use_facial_disc:
self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
l_g_total = 0
loss_dict = OrderedDict()
if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# image pyramid loss
if pyramid_loss_weight > 0:
for i in range(0, self.log_size - 2):
l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
l_g_total += l_pyramid
loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid
# perceptual loss
if self.cri_perceptual:
l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
if l_g_percep is not None:
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if l_g_style is not None:
l_g_total += l_g_style
loss_dict['l_g_style'] = l_g_style
# gan loss
fake_g_pred = self.net_d(self.output)
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan'] = l_g_gan
# facial component loss
if self.use_facial_disc:
# left eye
fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan_left_eye'] = l_g_gan
# right eye
fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan_right_eye'] = l_g_gan
# mouth
fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
l_g_total += l_g_gan
loss_dict['l_g_gan_mouth'] = l_g_gan
if self.opt['train'].get('comp_style_weight', 0) > 0:
# get gt feat
_, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
_, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
_, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)
def _comp_style(feat, feat_gt, criterion):
return criterion(self._gram_mat(feat[0]), self._gram_mat(
feat_gt[0].detach())) * 0.5 + criterion(
self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))
# facial component style loss
comp_style_loss = 0
comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
l_g_total += comp_style_loss
loss_dict['l_g_comp_style_loss'] = comp_style_loss
# identity loss
if self.use_identity:
identity_weight = self.opt['train']['identity_weight']
# get gray images and resize
out_gray = self.gray_resize_for_identity(self.output)
gt_gray = self.gray_resize_for_identity(self.gt)
identity_gt = self.network_identity(gt_gray).detach()
identity_out = self.network_identity(out_gray)
l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
l_g_total += l_identity
loss_dict['l_identity'] = l_identity
l_g_total.backward()
self.optimizer_g.step()
# EMA
self.model_ema(decay=0.5**(32 / (10 * 1000)))
# ----------- optimize net_d ----------- #
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
if self.use_facial_disc:
for p in self.net_d_left_eye.parameters():
p.requires_grad = True
for p in self.net_d_right_eye.parameters():
p.requires_grad = True
for p in self.net_d_mouth.parameters():
p.requires_grad = True
self.optimizer_d_left_eye.zero_grad()
self.optimizer_d_right_eye.zero_grad()
self.optimizer_d_mouth.zero_grad()
fake_d_pred = self.net_d(self.output.detach())
real_d_pred = self.net_d(self.gt)
l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d'] = l_d
# In WGAN, real_score should be positive and fake_score should be negative
loss_dict['real_score'] = real_d_pred.detach().mean()
loss_dict['fake_score'] = fake_d_pred.detach().mean()
l_d.backward()
# regularization loss
if current_iter % self.net_d_reg_every == 0:
self.gt.requires_grad = True
real_pred = self.net_d(self.gt)
l_d_r1 = r1_penalty(real_pred, self.gt)
l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
loss_dict['l_d_r1'] = l_d_r1.detach().mean()
l_d_r1.backward()
self.optimizer_d.step()
# optimize facial component discriminators
if self.use_facial_disc:
# left eye
fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
l_d_left_eye = self.cri_component(
real_d_pred, True, is_disc=True) + self.cri_gan(
fake_d_pred, False, is_disc=True)
loss_dict['l_d_left_eye'] = l_d_left_eye
l_d_left_eye.backward()
# right eye
fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
l_d_right_eye = self.cri_component(
real_d_pred, True, is_disc=True) + self.cri_gan(
fake_d_pred, False, is_disc=True)
loss_dict['l_d_right_eye'] = l_d_right_eye
l_d_right_eye.backward()
# mouth
fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
l_d_mouth = self.cri_component(
real_d_pred, True, is_disc=True) + self.cri_gan(
fake_d_pred, False, is_disc=True)
loss_dict['l_d_mouth'] = l_d_mouth
l_d_mouth.backward()
self.optimizer_d_left_eye.step()
self.optimizer_d_right_eye.step()
self.optimizer_d_mouth.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
    def test(self):
        """Run the generator on ``self.lq`` without gradients; store the result in ``self.output``.

        Prefers the EMA generator (``net_g_ema``) when available. Otherwise falls
        back to ``net_g``, switching it to eval mode for inference and restoring
        train mode afterwards so ongoing training is unaffected.
        """
        with torch.no_grad():
            if hasattr(self, 'net_g_ema'):
                self.net_g_ema.eval()
                self.output, _ = self.net_g_ema(self.lq)
            else:
                logger = get_root_logger()
                logger.warning('Do not have self.net_g_ema, use self.net_g.')
                self.net_g.eval()
                self.output, _ = self.net_g(self.lq)
                # restore training mode only for net_g; net_g_ema stays in eval permanently
                self.net_g.train()
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
if self.opt['rank'] == 0:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run single-process validation: infer each sample, optionally save images, accumulate metrics.

        Args:
            dataloader: validation dataloader; its dataset must expose ``opt['name']``.
            current_iter (int): current training iteration (used in file names and logs).
            tb_logger: tensorboard logger or None.
            save_img (bool): whether to write restored images to disk.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        use_pbar = self.opt['val'].get('pbar', False)

        if with_metrics:
            if not hasattr(self, 'metric_results'):  # only execute in the first run
                self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
            # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
            self._initialize_best_metric_results(dataset_name)
            # zero self.metric_results
            self.metric_results = {metric: 0 for metric in self.metric_results}

        metric_data = dict()
        if use_pbar:
            pbar = tqdm(total=len(dataloader), unit='image')

        for idx, val_data in enumerate(dataloader):
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            # network outputs are in [-1, 1]; tensor2img maps them back to uint8 images
            sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
            metric_data['img'] = sr_img
            if hasattr(self, 'gt'):
                gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
                metric_data['img2'] = gt_img
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt['val']['metrics'].items():
                    self.metric_results[name] += calculate_metric(metric_data, opt_)
            if use_pbar:
                pbar.update(1)
                pbar.set_description(f'Test {img_name}')
        if use_pbar:
            pbar.close()

        if with_metrics:
            for metric in self.metric_results.keys():
                # average the accumulated metric over the number of validation images
                self.metric_results[metric] /= (idx + 1)
                # update the best metric result
                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
log_str = f'Validation {dataset_name}\n'
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}'
if hasattr(self, 'best_metric_results'):
log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
log_str += '\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric, value in self.metric_results.items():
tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
def save(self, epoch, current_iter):
# save net_g and net_d
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
self.save_network(self.net_d, 'net_d', current_iter)
# save component discriminators
if self.use_facial_disc:
self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
# save training state
self.save_training_state(epoch, current_iter)
+11
View File
@@ -0,0 +1,11 @@
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline

# Importing these packages registers the GFPGAN architectures, datasets and
# models with basicsr's registries as an import side effect; the names are
# intentionally unused afterwards (hence the noqa above).
import gfpgan.archs
import gfpgan.data
import gfpgan.models

if __name__ == '__main__':
    # Project root is two levels above this file; basicsr resolves experiment
    # and config paths relative to it.
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
+143
View File
@@ -0,0 +1,143 @@
import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class GFPGANer():
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
    GFPGAN is used to restored the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces will be pasted back to the upsample background image.

    Args:
        model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Option: clean | bilinear | original. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.

    Raises:
        ValueError: If ``arch`` is not one of 'clean', 'bilinear' or 'original'.
    """

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # initialize the GFP-GAN generator
        if arch == 'clean':
            self.gfpgan = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        elif arch == 'bilinear':
            self.gfpgan = GFPGANBilinear(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        elif arch == 'original':
            self.gfpgan = GFPGANv1(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=True,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        else:
            # Fix: an unknown arch used to fall through silently and crash later
            # with AttributeError on self.gfpgan; fail fast with a clear message.
            raise ValueError(f"Unknown GFPGAN architecture: '{arch}'. Expected 'clean', 'bilinear' or 'original'.")

        # initialize face helper (detection, 5-landmark alignment, paste-back)
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=self.device)

        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
        # Fix: load to CPU first so CUDA-saved checkpoints also work on CPU-only
        # machines; the model is moved to self.device afterwards.
        loadnet = torch.load(model_path, map_location='cpu')
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
        """Restore faces in ``img`` (BGR image as read by cv2).

        Args:
            img (ndarray): Input image.
            has_aligned (bool): The input is already an aligned face (resized to 512x512).
            only_center_face (bool): Only restore the face closest to the image center.
            paste_back (bool): Paste restored faces back onto the (upsampled) background.

        Returns:
            tuple: ``(cropped_faces, restored_faces, restored_img)``; ``restored_img``
            is None when ``paste_back`` is False or ``has_aligned`` is True.
        """
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data: BGR uint8 -> normalized RGB tensor in [-1, 1]
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                # best-effort: keep the unrestored crop rather than aborting the batch
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
+5
View File
@@ -0,0 +1,5 @@
# GENERATED VERSION FILE
# TIME: Wed Mar 30 13:34:44 2022
__version__ = '1.3.2'
__gitsha__ = 'unknown'
version_info = (1, 3, 2)
+3
View File
@@ -0,0 +1,3 @@
# Weights
Put the downloaded weights to this folder.
@@ -0,0 +1,164 @@
import argparse
import math
import torch
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
    """Convert a GFPGAN (bilinear) checkpoint into the 'clean' architecture layout.

    The clean architecture drops StyleGAN2's equalized-learning-rate wrappers and
    fused activations, so every weight that was scaled at runtime in the original
    model must be pre-multiplied by the corresponding scale here, and some keys
    must be renamed to match the clean module structure.

    Args:
        checkpoint_bilinear (dict): state_dict of the original (bilinear) model.
        checkpoint_clean (dict): state_dict of the target clean model; filled in place.

    Returns:
        dict: ``checkpoint_clean`` with the converted weights.
    """
    for ori_k, ori_v in checkpoint_bilinear.items():
        if 'stylegan_decoder' in ori_k:
            if 'style_mlp' in ori_k:  # style_mlp_layers
                lr_mul = 0.01
                prefix, name, idx, var = ori_k.split('.')
                # the clean MLP interleaves activation modules, so linear indices double
                idx = (int(idx) * 2) - 1
                crt_k = f'{prefix}.{name}.{idx}.{var}'
                if var == 'weight':
                    _, c_in = ori_v.size()
                    # equalized-lr scale baked into the weight; 2**0.5 from FusedLeakyReLU
                    scale = (1 / math.sqrt(c_in)) * lr_mul
                    crt_v = ori_v * scale * 2**0.5
                else:
                    crt_v = ori_v * lr_mul * 2**0.5
                checkpoint_clean[crt_k] = crt_v
            elif 'modulation' in ori_k:  # modulation in StyleConv
                lr_mul = 1
                crt_k = ori_k
                var = ori_k.split('.')[-1]
                if var == 'weight':
                    _, c_in = ori_v.size()
                    scale = (1 / math.sqrt(c_in)) * lr_mul
                    crt_v = ori_v * scale
                else:
                    crt_v = ori_v * lr_mul
                checkpoint_clean[crt_k] = crt_v
            elif 'style_conv' in ori_k:
                # StyleConv in style_conv1 and style_convs
                if 'activate' in ori_k:  # FusedLeakyReLU
                    # eg. style_conv1.activate.bias
                    # eg. style_convs.13.activate.bias
                    split_rlt = ori_k.split('.')
                    if len(split_rlt) == 4:
                        prefix, name, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{var}'
                    elif len(split_rlt) == 5:
                        prefix, name, idx, _, var = split_rlt
                        crt_k = f'{prefix}.{name}.{idx}.{var}'
                    crt_v = ori_v * 2**0.5  # 2**0.5 used in FusedLeakyReLU
                    c = crt_v.size(0)
                    # clean arch adds the bias as a (1, C, 1, 1) tensor
                    checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1)
                elif 'modulated_conv' in ori_k:
                    # eg. style_conv1.modulated_conv.weight
                    # eg. style_convs.13.modulated_conv.weight
                    _, c_out, c_in, k1, k2 = ori_v.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * scale
                elif 'weight' in ori_k:
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * 2**0.5
            elif 'to_rgb' in ori_k:  # StyleConv in to_rgb1 and to_rgbs
                if 'modulated_conv' in ori_k:
                    # eg. to_rgb1.modulated_conv.weight
                    # eg. to_rgbs.5.modulated_conv.weight
                    _, c_out, c_in, k1, k2 = ori_v.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v * scale
                else:
                    crt_k = ori_k
                    checkpoint_clean[crt_k] = ori_v
            else:
                # remaining decoder keys (e.g. constant input, noise) copy through unchanged
                crt_k = ori_k
                checkpoint_clean[crt_k] = ori_v
            # end of 'stylegan_decoder'
        elif 'conv_body_first' in ori_k or 'final_conv' in ori_k:
            # key name
            name, _, var = ori_k.split('.')
            crt_k = f'{name}.{var}'
            # weight and bias
            if var == 'weight':
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
            else:
                checkpoint_clean[crt_k] = ori_v * 2**0.5
        elif 'conv_body' in ori_k:
            if 'conv_body_up' in ori_k:
                # the clean up-blocks wrap conv2/skip in a Sequential (index 1)
                ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight')
                ori_k = ori_k.replace('skip.weight', 'skip.1.weight')
            name1, idx1, name2, _, var = ori_k.split('.')
            crt_k = f'{name1}.{idx1}.{name2}.{var}'
            if name2 == 'skip':
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                # skip connections are averaged in the residual block, hence / 2**0.5
                checkpoint_clean[crt_k] = ori_v * scale / 2**0.5
            else:
                if var == 'weight':
                    c_out, c_in, k1, k2 = ori_v.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    checkpoint_clean[crt_k] = ori_v * scale
                else:
                    checkpoint_clean[crt_k] = ori_v
                if 'conv1' in ori_k:
                    checkpoint_clean[crt_k] *= 2**0.5
        elif 'toRGB' in ori_k:
            crt_k = ori_k
            if 'weight' in ori_k:
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale
            else:
                checkpoint_clean[crt_k] = ori_v
        elif 'final_linear' in ori_k:
            crt_k = ori_k
            if 'weight' in ori_k:
                _, c_in = ori_v.size()
                scale = 1 / math.sqrt(c_in)
                checkpoint_clean[crt_k] = ori_v * scale
            else:
                checkpoint_clean[crt_k] = ori_v
        elif 'condition' in ori_k:
            # SFT condition branches: Conv(0) -> LeakyReLU(1) -> Conv(2)
            crt_k = ori_k
            if '0.weight' in ori_k:
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
            elif '0.bias' in ori_k:
                checkpoint_clean[crt_k] = ori_v * 2**0.5
            elif '2.weight' in ori_k:
                c_out, c_in, k1, k2 = ori_v.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[crt_k] = ori_v * scale
            elif '2.bias' in ori_k:
                checkpoint_clean[crt_k] = ori_v

    return checkpoint_clean
if __name__ == '__main__':
    # CLI: convert an original GFPGAN checkpoint into the clean-arch layout.
    parser = argparse.ArgumentParser()
    parser.add_argument('--ori_path', type=str, help='Path to the original model')
    parser.add_argument('--narrow', type=float, default=1)
    parser.add_argument('--channel_multiplier', type=float, default=2)
    parser.add_argument('--save_path', type=str)
    args = parser.parse_args()

    # only the EMA weights of the original checkpoint are converted
    ori_ckpt = torch.load(args.ori_path)['params_ema']

    net = GFPGANv1Clean(
        512,
        num_style_feat=512,
        channel_multiplier=args.channel_multiplier,
        decoder_load_path=None,
        fix_decoder=False,
        # for stylegan decoder
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=args.narrow,
        sft_half=True)
    # start from the clean model's own state_dict so every expected key exists
    crt_ckpt = net.state_dict()
    crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)

    print(f'Save to {args.save_path}.')
    # legacy (non-zip) serialization keeps the file loadable by older torch versions
    torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)
+85
View File
@@ -0,0 +1,85 @@
import cv2
import json
import numpy as np
import os
import torch
from basicsr.utils import FileClient, imfrombytes
from collections import OrderedDict
# ---------------------------- This script is used to parse facial landmarks ------------------------------------- #
# For each FFHQ image, compute center + half-size of the left eye, right eye and
# mouth from the official landmark metadata, and save them as a dict to disk.
# Configurations
save_img = False  # also crop and dump component images under tmp/ for inspection
scale = 0.5  # 0.5 for official FFHQ (512x512), 1 for others
enlarge_ratio = 1.4  # only for eyes
json_path = 'ffhq-dataset-v2.json'
face_path = 'datasets/ffhq/ffhq_512.lmdb'
save_path = './FFHQ_eye_mouth_landmarks_512.pth'

print('Load JSON metadata...')
# use the official json file in FFHQ dataset
with open(json_path, 'rb') as f:
    json_data = json.load(f, object_pairs_hook=OrderedDict)

print('Open LMDB file...')
# read ffhq images
file_client = FileClient('lmdb', db_paths=face_path)
with open(os.path.join(face_path, 'meta_info.txt')) as fin:
    paths = [line.split('.')[0] for line in fin]

save_dict = {}

for item_idx, item in enumerate(json_data.values()):
    print(f'\r{item_idx} / {len(json_data)}, {item["image"]["file_path"]} ', end='', flush=True)

    # parse landmarks (scaled from the original 1024x1024 coordinates)
    lm = np.array(item['image']['face_landmarks'])
    lm = lm * scale

    item_dict = {}
    # get image
    if save_img:
        img_bytes = file_client.get(paths[item_idx])
        img = imfrombytes(img_bytes, float32=True)

    # get landmarks for each component
    # assumes the 68-point landmark convention (36-41 left eye, 42-47 right eye,
    # 48-67 mouth) -- TODO confirm against the FFHQ metadata format
    map_left_eye = list(range(36, 42))
    map_right_eye = list(range(42, 48))
    map_mouth = list(range(48, 68))

    # eye_left: center and half side-length of the bounding square (min 16 px)
    mean_left_eye = np.mean(lm[map_left_eye], 0)  # (x, y)
    half_len_left_eye = np.max((np.max(np.max(lm[map_left_eye], 0) - np.min(lm[map_left_eye], 0)) / 2, 16))
    item_dict['left_eye'] = [mean_left_eye[0], mean_left_eye[1], half_len_left_eye]
    # mean_left_eye[0] = 512 - mean_left_eye[0]  # for testing flip
    half_len_left_eye *= enlarge_ratio
    loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1, mean_left_eye + half_len_left_eye)).astype(int)
    if save_img:
        eye_left_img = img[loc_left_eye[1]:loc_left_eye[3], loc_left_eye[0]:loc_left_eye[2], :]
        cv2.imwrite(f'tmp/{item_idx:08d}_eye_left.png', eye_left_img * 255)

    # eye_right
    mean_right_eye = np.mean(lm[map_right_eye], 0)
    half_len_right_eye = np.max((np.max(np.max(lm[map_right_eye], 0) - np.min(lm[map_right_eye], 0)) / 2, 16))
    item_dict['right_eye'] = [mean_right_eye[0], mean_right_eye[1], half_len_right_eye]
    # mean_right_eye[0] = 512 - mean_right_eye[0]  # # for testing flip
    half_len_right_eye *= enlarge_ratio
    loc_right_eye = np.hstack(
        (mean_right_eye - half_len_right_eye + 1, mean_right_eye + half_len_right_eye)).astype(int)
    if save_img:
        eye_right_img = img[loc_right_eye[1]:loc_right_eye[3], loc_right_eye[0]:loc_right_eye[2], :]
        cv2.imwrite(f'tmp/{item_idx:08d}_eye_right.png', eye_right_img * 255)

    # mouth (note: not enlarged by enlarge_ratio, which applies only to eyes)
    mean_mouth = np.mean(lm[map_mouth], 0)
    half_len_mouth = np.max((np.max(np.max(lm[map_mouth], 0) - np.min(lm[map_mouth], 0)) / 2, 16))
    item_dict['mouth'] = [mean_mouth[0], mean_mouth[1], half_len_mouth]
    # mean_mouth[0] = 512 - mean_mouth[0]  # for testing flip
    loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1, mean_mouth + half_len_mouth)).astype(int)
    if save_img:
        mouth_img = img[loc_mouth[1]:loc_mouth[3], loc_mouth[0]:loc_mouth[2], :]
        cv2.imwrite(f'tmp/{item_idx:08d}_mouth.png', mouth_img * 255)

    save_dict[f'{item_idx:08d}'] = item_dict

print('Save...')
torch.save(save_dict, save_path)
@@ -0,0 +1,110 @@
check_points/
pretrain_models*
test_dir_enhance_results/
test_dir_align_results/
test_unalign_results/
tmp*
local*
*.pth
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
@@ -0,0 +1,445 @@
PSFR-GAN (c) by Chaofeng Chen
PSFR-GAN is licensed under a
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
You should have received a copy of the license along with this
work. If not, see <http://creativecommons.org/licenses/by-nc-sa/4.0/>.
Attribution-NonCommercial-ShareAlike 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More_considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. BY-NC-SA Compatible License means a license listed at
creativecommons.org/compatiblelicenses, approved by Creative
Commons as essentially the equivalent of this Public License.
d. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
e. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
f. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
g. License Elements means the license attributes listed in the name
of a Creative Commons Public License. The License Elements of this
Public License are Attribution, NonCommercial, and ShareAlike.
h. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
i. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
j. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
k. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
l. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
m. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
n. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. Additional offer from the Licensor -- Adapted Material.
Every recipient of Adapted Material from You
automatically receives an offer from the Licensor to
exercise the Licensed Rights in the Adapted Material
under the conditions of the Adapter's License You apply.
c. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
b. ShareAlike.
In addition to the conditions in Section 3(a), if You Share
Adapted Material You produce, the following conditions also apply.
1. The Adapter's License You apply must be a Creative Commons
license with the same License Elements, this version or
later, or a BY-NC-SA Compatible License.
2. You must include the text of, or the URI or hyperlink to, the
Adapter's License You apply. You may satisfy this condition
in any reasonable manner based on the medium, means, and
context in which You Share Adapted Material.
3. You may not offer or impose any additional or different terms
or conditions on, or apply any Effective Technological
Measures to, Adapted Material that restrict exercise of the
rights granted under the Adapter's License You apply.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material,
including for purposes of Section 3(b); and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.
@@ -0,0 +1,123 @@
# PSFR-GAN in PyTorch
[Progressive Semantic-Aware Style Transformation for Blind Face Restoration](https://arxiv.org/abs/2009.08709)
[Chaofeng Chen](https://chaofengc.github.io), [Xiaoming Li](https://csxmli2016.github.io/), [Lingbo Yang](https://lotayou.github.io), [Xianhui Lin](https://dblp.org/pid/147/7708.html), [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang/), [Kwan-Yee K. Wong](https://i.cs.hku.hk/~kykwong/)
![](test_dir/test_hzgg.jpg)
![](test_hzgg_results/hq_final.jpg)
### Changelog
- **2021.04.26**: Add pytorch vgg19 model to GoogleDrive and remove `--distributed` option which causes training error.
- **2021.03.22**: Update new model at 15 epoch (52.5k iterations).
- **2021.03.19**: Add train codes for PSFRGAN and FPN.
## Prerequisites and Installation
- Ubuntu 18.04
- CUDA 10.1
- Clone this repository
```
git clone https://github.com/chaofengc/PSFR-GAN.git
cd PSFR-GAN
```
- Python 3.7, install required packages by `pip3 install -r requirements.txt`
## Quick Test
### Download Pretrain Models and Dataset
Download the pretrained models from the following link and put them to `./pretrain_models`
- [GoogleDrive](https://drive.google.com/drive/folders/1Ubejhxd2xd4fxGc_M_LWl3Ux6CgQd9rP?usp=sharing)
- [BaiduNetDisk](https://pan.baidu.com/s/1cru3uUASEfGX6p6L0_7gWQ), extract code: `gj2r`
### Test single image
Run the following script to enhance face(s) in single input
```
python test_enhance_single_unalign.py --test_img_path ./test_dir/test_hzgg.jpg --results_dir test_hzgg_results --gpus 1
```
This script does the following things:
- Crop and align all the faces from input image, stored at `results_dir/LQ_faces`
- Parse these faces and then enhance them, results stored at `results_dir/ParseMaps` and `results_dir/HQ`
- Paste the enhanced faces back to the original image `results_dir/hq_final.jpg`
- You can use `--gpus` to specify how many GPUs to use, `<=0` means running on CPU. The program will use GPU with the most available memory. Set `CUDA_VISIBLE_DEVICE` to specify the GPU if you do not want automatic GPU selection.
### Test image folder
To test multiple images, we first crop out all the faces and align them using the following script.
```
python align_and_crop_dir.py --src_dir test_dir --results_dir test_dir_align_results
```
For images (*e.g.* `multiface_test.jpg`) that contain multiple faces, the aligned faces will be stored as `multiface_test_{face_index}.jpg`
And then parse the aligned faces and enhance them with
```
python test_enhance_dir_align.py --src_dir test_dir_align_results --results_dir test_dir_enhance_results
```
Results will be saved to three folders respectively: `results_dir/lq`, `results_dir/parse`, `results_dir/hq`.
### Additional test script
For your convenience, we also provide a script to test multiple unaligned images and paste the enhanced results back. **Note the paste back operation could be quite slow for large size images containing many faces (dlib takes time to detect faces in large image).**
```
python test_enhance_dir_unalign.py --src_dir test_dir --results_dir test_unalign_results
```
This script basically does the same thing as `test_enhance_single_unalign.py` for each image in `src_dir`
## Train the Model
### Data Preparation
- Download [FFHQ](https://github.com/NVlabs/ffhq-dataset) and put the images to `../datasets/FFHQ/imgs1024`
- Download parsing masks (`512x512`) [HERE](https://drive.google.com/file/d/1eQwO8hKcaluyCnxuZAp0eJVOdgMi30uA/view?usp=sharing) generated by the pretrained FPN and put them to `../datasets/FFHQ/masks512`.
*Note: you may change `../datasets/FFHQ` to your own path. But images and masks must be stored under `your_own_path/imgs1024` and `your_own_path/masks512` respectively.*
### Train Script for PSFRGAN
Here is an example train script for PSFRGAN:
```
python train.py --gpus 2 --model enhance --name PSFRGAN_v001 \
--g_lr 0.0001 --d_lr 0.0004 --beta1 0.5 \
--gan_mode 'hinge' --lambda_pix 10 --lambda_fm 10 --lambda_ss 1000 \
--Dinput_nc 22 --D_num 3 --n_layers_D 4 \
--batch_size 2 --dataset ffhq --dataroot ../datasets/FFHQ \
--visual_freq 100 --print_freq 10 #--continue_train
```
- Please change the `--name` option for different experiments. Tensorboard records with the same name will be moved to `check_points/log_archive`, and the weight directory will only store weight history of latest experiment with the same name.
- `--gpus` specify number of GPUs used to train. The script will use GPUs with more available memory first. To specify the GPU index, use `export CUDA_VISIBLE_DEVICES=your_gpu_ids` before the script.
- Uncomment `--continue_train` to resume training. *Current codes do not resume the optimizer state.*
- It needs at least **8GB** memory to train with **batch_size=1**.
### Scripts for FPN
You may also train your own FPN and generate masks for the HQ images by yourself with the following steps:
- Download [CelebAHQ-Mask](https://github.com/switchablenorms/CelebAMask-HQ) dataset. Generate `CelebAMask-HQ-mask` and `CelebAMask-HQ-mask-color` with the provided scripts in `CelebAMask-HQ/face_parsing/Data_preprocessing/`.
- Train FPN with the following command
```
python train.py --gpus 1 --model parse --name FPN_v001 \
--lr 0.0002 --batch_size 8 \
--dataset celebahqmask --dataroot ../datasets/CelebAMask-HQ \
--visual_freq 100 --print_freq 10 #--continue_train
```
- Generate parsing masks with your own FPN using the following command:
```
python generate_masks.py --save_masks_dir ../datasets/FFHQ/masks512 --batch_size 8 --parse_net_weight path/to/your/own/FPN
```
## Citation
```
@inproceedings{ChenPSFRGAN,
    author = {Chen, Chaofeng and Li, Xiaoming and Yang, Lingbo and Lin, Xianhui and Zhang, Lei and Wong, KKY},
    title = {Progressive Semantic-Aware Style Transformation for Blind Face Restoration},
    booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
    year = {2021}
}
```
## License
<a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.
## Acknowledgement
This work is inspired by [SPADE](https://github.com/NVlabs/SPADE), and closed related to [DFDNet](https://github.com/csxmli2016/DFDNet) and [HiFaceGAN](https://github.com/Lotayou/Face-Renovation). Our codes largely benefit from [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
@@ -0,0 +1,86 @@
import dlib
import os
import cv2
import numpy as np
from tqdm import tqdm
from skimage import transform as trans
from skimage import io
import argparse
def get_points(img, detector, shape_predictor, size_threshold=999):
    """Detect faces in *img* and return their 5-point landmarks.

    Parameters:
        img             -- RGB image array (dlib format)
        detector        -- a dlib HOG or CNN face detector
        shape_predictor -- a dlib 5-landmark shape predictor
        size_threshold  -- detection scan stops at the first face whose
                           box exceeds this size in either dimension

    Returns a list of (5, 2) numpy arrays, one per kept face, or None
    when no usable face is found.
    """
    detections = detector(img, 1)
    if len(detections) == 0:
        return None

    landmarks = []
    for detection in detections:
        # The CNN detector wraps its rectangle in an mmod_rectangle.
        if isinstance(detector, dlib.cnn_face_detection_model_v1):
            rect = detection.rect
        else:
            rect = detection
        if rect.width() > size_threshold or rect.height() > size_threshold:
            break
        shape = shape_predictor(img, rect)
        points = np.array([[shape.part(i).x, shape.part(i).y] for i in range(5)])
        landmarks.append(points)

    if len(landmarks) <= 0:
        return None
    return landmarks
def align_and_save(img, save_path, src_points, template_path, template_scale=1):
    """Warp every detected face in *img* onto the 512x512 landmark template and save it.

    One file is written per entry in *src_points*; when several faces are
    present, the face index is appended to the file stem.
    """
    output_size = (512, 512)
    template = np.load(template_path) / template_scale
    stem_and_suffix = os.path.splitext(save_path)
    for face_idx, points in enumerate(src_points):
        # Estimate a similarity transform from the face landmarks to the template.
        transform = trans.SimilarityTransform()
        transform.estimate(points, template)
        warp_matrix = transform.params[0:2, :]
        aligned = cv2.warpAffine(img, warp_matrix, output_size)
        if len(src_points) > 1:
            save_path = stem_and_suffix[0] + '_{}'.format(face_idx) + stem_and_suffix[1]
        dlib.save_image(aligned.astype(np.uint8), save_path)
        print('Saving image', save_path)
def align_and_save_dir(src_dir, save_dir, template_path='./pretrain_models/FFHQ_template.npy', template_scale=2, use_cnn_detector=True):
    """Detect, align, and save all faces found in the images under *src_dir*.

    Parameters:
        src_dir (str)           -- directory containing the input images
        save_dir (str)          -- directory where aligned faces are written
        template_path (str)     -- .npy file holding the 5-point FFHQ landmark template
        template_scale (float)  -- divisor applied to the template coordinates
        use_cnn_detector (bool) -- use dlib's CNN detector (slower, more robust)
                                   instead of the HOG frontal-face detector
    """
    # (fix) removed the unused `out_size` local; the output size is decided
    # inside align_and_save.
    if use_cnn_detector:
        detector = dlib.cnn_face_detection_model_v1('./pretrain_models/mmod_human_face_detector.dat')
    else:
        detector = dlib.get_frontal_face_detector()
    sp = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat')

    for name in os.listdir(src_dir):
        img_path = os.path.join(src_dir, name)
        img = dlib.load_rgb_image(img_path)
        points = get_points(img, detector, sp)
        if points is not None:
            save_path = os.path.join(save_dir, name)
            align_and_save(img, save_path, points, template_path, template_scale)
        else:
            print('No face detected in', img_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--src_dir', type=str, help='source directory containing images to crop and align.')
    parser.add_argument('--results_dir', type=str, help='results directory to save the aligned faces.')
    parser.add_argument('--not_use_cnn_detector', action='store_true', help='do not use cnn face detector in dlib.')
    args = parser.parse_args()

    src_dir = args.src_dir
    assert os.path.isdir(src_dir), 'Source path should be a directory containing images'
    save_dir = args.results_dir
    # (fix) makedirs(exist_ok=True) already tolerates an existing directory,
    # so the racy os.path.exists() pre-check was removed.
    os.makedirs(save_dir, exist_ok=True)
    align_and_save_dir(src_dir, save_dir, use_cnn_detector=not args.not_use_cnn_detector)
@@ -0,0 +1,94 @@
"""This package includes all the modules related to data loading and preprocessing
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset
def find_dataset_using_name(dataset_name):
    """Import "data/[dataset_name]_dataset.py" and return its dataset class.

    The module must define a BaseDataset subclass whose name is
    [DatasetName]Dataset (matched case-insensitively).

    Raises:
        NotImplementedError -- when no matching subclass is found.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    dataset = None
    for attr_name, attr in datasetlib.__dict__.items():
        # Short-circuit: issubclass is only evaluated on name matches,
        # so non-class attributes with other names are never inspected.
        if attr_name.lower() == target_dataset_name.lower() and issubclass(attr, BaseDataset):
            dataset = attr

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
    return dataset
def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    return find_dataset_using_name(dataset_name).modify_commandline_options
def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader and is the main
    interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    loader = CustomDatasetDataLoader(opt)
    return loader.load_data()
class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt):
        """Create the dataset named by opt.dataset_name and wrap it in a DataLoader.

        Parameters:
            opt -- experiment options; must provide dataset_name, batch_size,
                   serial_batches, num_threads, isTrain and max_dataset_size.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_name)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            # Incomplete final batches are only dropped during training.
            drop_last=bool(opt.isTrain))

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset, capped by max_dataset_size."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches, stopping once max_dataset_size items have been seen."""
        for batch_idx, batch in enumerate(self.dataloader):
            if batch_idx * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield batch
@@ -0,0 +1,162 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
import imgaug as ia
import imgaug.augmenters as iaa
class BaseDataset(data.Dataset, ABC):
    """Abstract base class (ABC) for all datasets in this package.

    Subclasses must implement:
    -- <__init__>: initialize the class; first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of the dataset.
    -- <__getitem__>: get a data point.
    -- <modify_commandline_options>: (optionally) add dataset-specific options
       and set default options.
    """

    def __init__(self, opt):
        """Remember the experiment options and the dataset root directory.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be
                a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for subclasses to add dataset-specific command-line options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase

        Returns:
            the (possibly modified) parser; this base implementation returns
            it untouched.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names; usually the data itself
            and its metadata information.
        """
        pass
def get_params(opt, size):
    """Draw random crop position and horizontal-flip flag for an image.

    Parameters:
        opt  -- options providing preprocess, load_size and crop_size
        size -- the (width, height) of the source image

    Returns:
        dict with 'crop_pos' (x, y) and a boolean 'flip'.
    """
    w, h = size
    if opt.preprocess == 'resize_and_crop':
        new_w = new_h = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    else:
        new_w, new_h = w, h

    # Keep the random-call order (x, y, flip) stable for reproducibility.
    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build a torchvision transform pipeline from the preprocessing options.

    Pipeline order: optional RGB->Y conversion, resize / scale-width, crop,
    multiple-of-4 adjustment (preprocess == 'none'), horizontal flip, then
    ToTensor + Normalize when *convert* is set. When *params* is given
    (from get_params), crop position and flip are deterministic.
    """
    pipeline = []
    if grayscale:
        from util import util
        pipeline.append(util.RGBtoY)

    if 'resize' in opt.preprocess:
        pipeline.append(transforms.Resize([opt.load_size, opt.load_size], method))
    elif 'scale_width' in opt.preprocess:
        pipeline.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    if 'crop' in opt.preprocess:
        if params is None:
            pipeline.append(transforms.RandomCrop(opt.crop_size))
        elif 'crop_size' in params:
            pipeline.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], params['crop_size'])))
        else:
            pipeline.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if opt.preprocess == 'none':
        pipeline.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            pipeline.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            pipeline.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        pipeline.append(transforms.ToTensor())
        if grayscale:
            pipeline.append(transforms.Normalize((0.5,), (0.5,)))
        else:
            pipeline.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(pipeline)
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize *img* so both sides are multiples of *base* (warns once)."""
    width, height = img.size
    new_h = int(round(height / base) * base)
    new_w = int(round(width / base) * base)
    if (new_h, new_w) == (height, width):
        return img
    __print_size_warning(width, height, new_w, new_h)
    return img.resize((new_w, new_h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize *img* to *target_width*, keeping the aspect ratio."""
    ow, oh = img.size
    if ow == target_width:
        return img
    new_h = int(target_width * oh / ow)
    return img.resize((target_width, new_h), method)
def __crop(img, pos, size):
    """Crop a size x size square at *pos*; no-op when the image already fits."""
    ow, oh = img.size
    x1, y1 = pos
    side = size
    if ow > side or oh > side:
        return img.crop((x1, y1, x1 + side, y1 + side))
    return img
def __flip(img, flip):
    """Horizontally mirror *img* when *flip* is truthy, otherwise return it unchanged."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
def __print_size_warning(ow, oh, w, h):
    """Print warning information about image size (only print once).

    Uses a function attribute as a one-shot latch so the warning is not
    repeated for every image in the dataset.
    """
    if getattr(__print_size_warning, 'has_printed', False):
        return
    print("The image size needs to be a multiple of 4. "
          "The loaded image size was (%d, %d), so it was adjusted to "
          "(%d, %d). This adjustment will be done to all images "
          "whose sizes are not multiples of 4" % (ow, oh, w, h))
    __print_size_warning.has_printed = True
@@ -0,0 +1,60 @@
import os
import random
import numpy as np
from PIL import Image
import imgaug as ia
import imgaug.augmenters as iaa
from data.image_folder import make_dataset
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from data.base_dataset import BaseDataset
from utils.utils import onehot_parse_map
from data.ffhq_dataset import complex_imgaug, random_gray
class CelebAHQMaskDataset(BaseDataset):
def __init__(self, opt):
BaseDataset.__init__(self, opt)
self.img_size = opt.Pimg_size
self.lr_size = opt.Gin_size
self.hr_size = opt.Gout_size
self.shuffle = True if opt.isTrain else False
self.img_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'CelebA-HQ-img')))
self.mask_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'CelebAMask-HQ-mask')))
self.to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def __len__(self,):
return len(self.img_dataset)
def __getitem__(self, idx):
sample = {}
img_path = self.img_dataset[idx]
mask_path = self.mask_dataset[idx]
hr_img = Image.open(img_path).convert('RGB')
mask_img = Image.open(mask_path)
hr_img = hr_img.resize((self.hr_size, self.hr_size))
hr_img = random_gray(hr_img, p=0.3)
scale_size = np.random.randint(32, 256)
lr_img = complex_imgaug(hr_img, self.img_size, scale_size)
mask_img = mask_img.resize((self.hr_size, self.hr_size))
mask_label = torch.tensor(np.array(mask_img)).long()
hr_tensor = self.to_tensor(hr_img)
lr_tensor = self.to_tensor(lr_img)
return {'HR': hr_tensor, 'LR': lr_tensor, 'HR_paths': img_path, 'Mask': mask_label}
@@ -0,0 +1,88 @@
import os
import random
import numpy as np
from PIL import Image
import imgaug as ia
import imgaug.augmenters as iaa
from data.image_folder import make_dataset
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from data.base_dataset import BaseDataset
from utils.utils import onehot_parse_map
class FFHQDataset(BaseDataset):
    """FFHQ training dataset producing (HR, LR, parsing-mask) triples.

    Expects <dataroot>/imgs1024 (HR images) and <dataroot>/masks512
    (parsing maps); files are paired by their sorted order.
    """
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.img_size = opt.Pimg_size
        self.lr_size = opt.Gin_size
        self.hr_size = opt.Gout_size
        # Shuffle only during training; keep deterministic order otherwise.
        self.shuffle = bool(opt.isTrain)
        self.img_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'imgs1024')))
        self.mask_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'masks512')))
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        self.random_crop = transforms.RandomCrop(self.hr_size)
    def __len__(self,):
        """Return the number of HR images."""
        return len(self.img_dataset)
    def __getitem__(self, idx):
        """Load, augment and return one training sample.

        Returns a dict:
            'HR' (tensor)       -- HR image normalized to [-1, 1]
            'LR' (tensor)       -- randomly degraded LR counterpart
            'HR_paths' (str)    -- source image path
            'Mask' (float tensor) -- one-hot parsing map
        """
        img_path = self.img_dataset[idx]
        mask_path = self.mask_dataset[idx]
        hr_img = Image.open(img_path).convert('RGB')
        mask_img = Image.open(mask_path).convert('RGB')
        hr_img = hr_img.resize((self.hr_size, self.hr_size))
        # Randomly drop color so the model tolerates grayscale inputs.
        hr_img = random_gray(hr_img, p=0.3)
        # Degrade via a random intermediate resolution to synthesize the LR input.
        scale_size = np.random.randint(32, 256)
        lr_img = complex_imgaug(hr_img, self.img_size, scale_size)
        # NOTE(review): resize uses PIL's default resampling; parsing maps
        # usually need NEAREST to avoid label mixing -- confirm.
        mask_img = mask_img.resize((self.hr_size, self.hr_size))
        mask_label = onehot_parse_map(mask_img)
        mask_label = torch.tensor(mask_label).float()
        hr_tensor = self.to_tensor(hr_img)
        lr_tensor = self.to_tensor(lr_img)
        return {'HR': hr_tensor, 'LR': lr_tensor, 'HR_paths': img_path, 'Mask': mask_label}
def complex_imgaug(x, org_size, scale_size):
    """Degrade a single RGB PIL image with a random blur/noise/JPEG pipeline.

    The image is (possibly) blurred, resized down to ``scale_size``, noised,
    JPEG-compressed, then resized back up to ``org_size``.
    """
    batch = np.array(x)[np.newaxis, :, :, :]
    degrade = iaa.Sequential([
        iaa.Sometimes(0.5, iaa.OneOf([
            iaa.GaussianBlur((3, 15)),
            iaa.AverageBlur(k=(3, 15)),
            iaa.MedianBlur(k=(3, 15)),
            iaa.MotionBlur((5, 25))
        ])),
        iaa.Resize(scale_size, interpolation=ia.ALL),
        iaa.Sometimes(0.2, iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.1*255), per_channel=0.5)),
        iaa.Sometimes(0.7, iaa.JpegCompression(compression=(10, 65))),
        iaa.Resize(org_size),
    ])
    return degrade(images=batch)[0]
def random_gray(x, p=0.5):
    """With probability ``p``, convert a single RGB PIL image to grayscale."""
    batch = np.array(x)[np.newaxis, :, :, :]
    maybe_gray = iaa.Sometimes(p, iaa.Grayscale(alpha=1.0))
    return maybe_gray(images=batch)[0]
@@ -0,0 +1,67 @@
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import torch.utils.data as data
from PIL import Image
import os
import os.path
# Recognized image file extensions (case variants listed explicitly).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    """Return True if *filename* ends with a recognized image extension."""
    # str.endswith accepts a tuple of suffixes: one C-level call instead of
    # a Python-level scan over every extension.
    return filename.endswith(tuple(IMG_EXTENSIONS))
def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively collect image file paths under *dir*.

    At most *max_dataset_size* paths are returned (default: unlimited).
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = [os.path.join(root, fname)
              for root, _, fnames in sorted(os.walk(dir))
              for fname in fnames
              if is_image_file(fname)]
    return images[:min(max_dataset_size, len(images))]
def default_loader(path):
    """Open the image at *path* with PIL and force 3-channel RGB."""
    return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
    """Dataset over every image found (recursively) under a root directory."""

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load image *index*; optionally transform it and/or return its path."""
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        """Return the number of images found."""
        return len(self.imgs)
@@ -0,0 +1,43 @@
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import numpy as np
class SingleDataset(BaseDataset):
    """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
    It can be used for generating CycleGAN results only for one side with the model option '-model test'.
    """
    def __init__(self, opt):
        """Initialize this dataset class.
        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.A_paths = sorted(make_dataset(opt.src_dir, opt.max_dataset_size))
        # The loaded images must match the generator's *output* channel count.
        input_nc = self.opt.output_nc
        self.transform = get_transform(opt, grayscale=(input_nc == 1))
        self.opt = opt

    def __getitem__(self, index):
        """Return one LR image and its metadata.
        Parameters:
            index -- integer index for data access
        Returns a dictionary with:
            LR (tensor)    -- the transformed input image
            LR_paths (str) -- the path of the image
        """
        path = self.A_paths[index]
        image = Image.open(path).convert('RGB')
        image = image.resize((512, 512), Image.BICUBIC)
        tensor = self.transform(image)
        return {'LR': tensor, 'LR_paths': path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)
@@ -0,0 +1,92 @@
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from utils import utils
from PIL import Image
from tqdm import tqdm
import torch
import time
import numpy as np
import cv2
import glob
from torchvision.transforms import transforms
if __name__ == '__main__':
    # Generate grayscale parsing-mask PNGs for every aligned face image.
    opt = TestOptions()
    opt = opt.parse()  # get test options
    # Test-time settings: single worker, fixed order, no augmentation flips.
    opt.num_threads = 0
    opt.serial_batches = True
    opt.no_flip = True
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.load_pretrain_models()
    netP = model.netP  # pretrained face-parsing network
    model.eval()
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    image_dir = "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan/"
    output_dir = "G:/VGGFace2-HQ/VGGface2_HQ_original_aligned_mask"
    # Collect the *.jpg files of every identity sub-directory.
    dataset = []
    for dir_item in glob.glob(os.path.join(image_dir, '*/')):
        join_path = glob.glob(os.path.join(dir_item, '*.jpg'))
        print("processing %s" % dir_item, end='\r')
        dataset.append(join_path)
    # ------------------------ restore ------------------------
    for i_dir in dataset:
        if not i_dir:
            continue  # skip identity folders that contain no .jpg images
        path = os.path.dirname(i_dir[0])
        dir_name = os.path.join(output_dir, os.path.basename(path))
        os.makedirs(dir_name, exist_ok=True)
        for img_path in i_dir:
            hr_img = Image.open(img_path).convert('RGB')
            inp = to_tensor(hr_img).unsqueeze(0)
            with torch.no_grad():
                parse_map, _ = netP(inp)
                # One-hot: keep only the arg-max class per pixel.
                parse_map_sm = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float()
            ref_parse_img = utils.color_parse_map(parse_map_sm)
            img_name = os.path.basename(img_path)
            basename, ext = os.path.splitext(img_name)
            save_face_name = f'{basename}.png'
            save_path = os.path.join(dir_name, save_face_name)
            # Store the colorized parse map as a single-channel grayscale PNG.
            img = cv2.cvtColor(ref_parse_img[0], cv2.COLOR_RGB2GRAY)
            cv2.imwrite(save_path, img)
@@ -0,0 +1,67 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
import importlib
from models.base_model import BaseModel
def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py" and return its model class.

    The module must define a class named <model_name>Model (case-insensitive,
    underscores ignored) that subclasses BaseModel.

    Raises:
        ImportError: if the module defines no matching BaseModel subclass.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls
    if model is None:
        # Previously this printed a message and called exit(0) -- reporting
        # *success* to the shell despite the failure. Raise instead.
        raise ImportError("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
    return model
def get_option_setter(model_name):
    """Return the static <modify_commandline_options> method of the named model class."""
    return find_model_using_name(model_name).modify_commandline_options
def create_model(opt):
    """Create a model given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model_cls = find_model_using_name(opt.model)
    instance = model_cls(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
@@ -0,0 +1,248 @@
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks
class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """
    def __init__(self, opt):
        """Initialize the BaseModel class.
        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.
        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
        Returns:
            the modified parser.
        """
        return parser
    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass
    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass
    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass
    def setup(self, opt):
        """Load and print networks; create schedulers.
        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        # Load weights for testing, or when resuming an interrupted training run.
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)
    def eval(self):
        """Switch every registered network (net<name>) to eval mode for testing."""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()
    def test(self):
        """Forward function used in test time.
        This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()
    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization (no-op by default)."""
        pass
    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths
    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch.
        The 'plateau' policy steps on self.metric; all other policies step unconditionally.
        """
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)
    def get_lr(self,):
        """Return a dict mapping 'LR<i>' to the current learning rate of optimizer i."""
        lrs = {}
        for idx, p in enumerate(self.optimizers):
            lrs['LR{}'.format(idx)] = p.param_groups[0]['lr']
        return lrs
    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret
    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret['Loss_' + name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret
    def save_networks(self, epoch, info=None):
        """Save all the networks to the disk.
        Parameters:
            epoch (int)  -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
            info (dict)  -- optional extra state, saved alongside as '%s.info' and restored by load_networks
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # Save the unwrapped DataParallel module from CPU, then move back.
                    # NOTE(review): assumes net is wrapped (has .module) whenever
                    # gpu_ids is non-empty -- confirm all nets are DataParallel here.
                    torch.save(net.module.cpu().state_dict(), save_path)
                    print('Model saved in:', save_filename)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)
        if info is not None:
            torch.save(info, os.path.join(self.save_dir, '%s.info' % epoch))
    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
    def load_networks(self, epoch):
        """Load all the networks from the disk.
        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        NOTE(review): iterates self.load_model_names, which BaseModel never
        defines -- subclasses must set it; confirm every subclass does.
        """
        for name in self.load_model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel) or isinstance(net, torch.nn.parallel.DistributedDataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                map_location = str(self.device)
                state_dict = torch.load(load_path, map_location=map_location)
                # patch InstanceNorm checkpoints prior to 0.4
                # for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                #     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                # net.load_state_dict(state_dict)
                if not self.opt.no_strict_load:
                    net.load_state_dict(state_dict)
                # Load partial weights
                else:
                    # Keep only checkpoint entries whose keys exist in the current model.
                    model_dict = net.state_dict()
                    pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
                    model_dict.update(pretrained_dict)
                    net.load_state_dict(model_dict, strict=False)
        # Restore any extra training state saved by save_networks(info=...).
        info_path = os.path.join(self.save_dir, '%s.info' % epoch)
        if os.path.exists(info_path):
            info_dict = torch.load(info_path)
            for k, v in info_dict.items():
                setattr(self.opt, k, v)
    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture
        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')
    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import numpy as np
class NormLayer(nn.Module):
    """Normalization Layers.
    ------------
    # Arguments
        - channels: input channels, for batch norm and instance norm.
        - normalize_shape: input shape without batch size, for layer norm.
        - norm_type: one of 'bn' | 'in' | 'gn' | 'pixel' | 'layer' | 'none'.

    Raises ValueError for an unknown norm_type.
    """
    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
        super(NormLayer, self).__init__()
        norm_type = norm_type.lower()
        self.norm_type = norm_type
        self.channels = channels
        if norm_type == 'bn':
            self.norm = nn.BatchNorm2d(channels, affine=True)
        elif norm_type == 'in':
            self.norm = nn.InstanceNorm2d(channels, affine=False)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(32, channels, affine=True)
        elif norm_type == 'pixel':
            # L2-normalize across the channel dimension ("pixel norm").
            self.norm = lambda x: F.normalize(x, p=2, dim=1)
        elif norm_type == 'layer':
            self.norm = nn.LayerNorm(normalize_shape)
        elif norm_type == 'none':
            self.norm = lambda x: x*1.0
        else:
            # Raise instead of `assert 1==0`: asserts are stripped under `python -O`.
            raise ValueError('Norm type {} not support.'.format(norm_type))
    def forward(self, x, ref=None):
        # `ref` is accepted for call-site compatibility but unused here.
        return self.norm(x)
class ReluLayer(nn.Module):
    """Relu Layer.
    ------------
    # Arguments
        - relu type: type of relu layer, candidates are
            - ReLU
            - LeakyReLU: default relu slope 0.2
            - PRelu
            - SELU
            - none: direct pass

    Raises ValueError for an unknown relu_type.
    """
    def __init__(self, channels, relu_type='relu'):
        super(ReluLayer, self).__init__()
        relu_type = relu_type.lower()
        if relu_type == 'relu':
            self.func = nn.ReLU(True)
        elif relu_type == 'leakyrelu':
            self.func = nn.LeakyReLU(0.2, inplace=True)
        elif relu_type == 'prelu':
            # PReLU learns one slope per channel.
            self.func = nn.PReLU(channels)
        elif relu_type == 'selu':
            self.func = nn.SELU(True)
        elif relu_type == 'none':
            self.func = lambda x: x*1.0
        else:
            # Raise instead of `assert 1==0`: asserts are stripped under `python -O`.
            raise ValueError('Relu type {} not support.'.format(relu_type))
    def forward(self, x):
        return self.func(x)
class ConvLayer(nn.Module):
    """Conv block: optional 2x up/down scaling + reflection pad + conv + norm + activation."""
    def __init__(self, in_channels, out_channels, kernel_size=3, scale='none', norm_type='none', relu_type='none', use_pad=True, bias=True):
        super(ConvLayer, self).__init__()
        self.use_pad = use_pad
        self.norm_type = norm_type
        self.in_channels = in_channels
        # BatchNorm carries its own bias term, so the conv bias is redundant.
        if norm_type in ['bn']:
            bias = False
        self.scale = scale
        self.scale_func = lambda x: x
        if scale == 'up':
            # Nearest-neighbour 2x upsampling applied before the convolution.
            self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        pad_width = int(np.ceil((kernel_size - 1.)/2))
        self.reflection_pad = nn.ReflectionPad2d(pad_width)
        conv_stride = 2 if scale == 'down' else 1
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, conv_stride, bias=bias)
        self.avgpool = nn.AvgPool2d(2, 2)
        self.relu = ReluLayer(out_channels, relu_type)
        self.norm = NormLayer(out_channels, norm_type=norm_type)
    def forward(self, x):
        out = self.scale_func(x)
        if self.use_pad:
            out = self.reflection_pad(out)
        out = self.conv2d(out)
        if self.scale == 'down_avg':
            # 'down_avg' downsamples via average pooling instead of stride.
            out = self.avgpool(out)
        return self.relu(self.norm(out))
class ResidualBlock(nn.Module):
    """
    Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
    """
    def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
        super(ResidualBlock, self).__init__()
        # Identity shortcut when the shape is preserved; 3x3 conv projection otherwise.
        if scale == 'none' and c_in == c_out:
            self.shortcut_func = lambda x: x
        else:
            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
        # Which of the two convs performs the spatial scaling.
        scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
        first_scale, second_scale = scale_config_dict[scale]
        self.conv1 = ConvLayer(c_in, c_out, 3, first_scale, norm_type=norm_type, relu_type=relu_type)
        self.conv2 = ConvLayer(c_out, c_out, 3, second_scale, norm_type=norm_type, relu_type='none')
    def forward(self, x):
        identity = self.shortcut_func(x)
        return identity + self.conv2(self.conv1(x))
@@ -0,0 +1,157 @@
import os
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.optim as optim
from models import loss
from models import networks
from .base_model import BaseModel
from utils import utils
class EnhanceModel(BaseModel):
    """PSFR-GAN enhancement model: restores HR faces from LR inputs, guided by
    parsing maps predicted by a frozen parsing network (netP)."""
    # NOTE(review): defined without @staticmethod (cf. BaseModel) -- works in
    # Python 3 when looked up on the class, but confirm this is intentional.
    def modify_commandline_options(parser, is_train):
        """Add training-only loss-weight options; returns the modified parser.
        NOTE(review): --parse_net_weight is only added when is_train, yet
        __init__ reads opt.parse_net_weight unconditionally -- confirm test
        options define it elsewhere.
        """
        if is_train:
            parser.add_argument('--parse_net_weight', type=str, default='./pretrain_models/parse_multi_iter_90000.pth', help='parse model path')
            parser.add_argument('--lambda_pix', type=float, default=10.0, help='weight for parsing map')
            parser.add_argument('--lambda_pcp', type=float, default=0.0, help='weight for vgg perceptual loss')
            parser.add_argument('--lambda_fm', type=float, default=10.0, help='weight for sr')
            parser.add_argument('--lambda_g', type=float, default=1.0, help='weight for sr')
            parser.add_argument('--lambda_ss', type=float, default=1000., help='weight for global style')
        return parser
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        # Parsing network and generator; spectral norm stabilizes GAN training.
        self.netP = networks.define_P(opt, weight_path=opt.parse_net_weight)
        self.netG = networks.define_G(opt, use_norm='spectral_norm')
        if self.isTrain:
            self.netD = networks.define_D(opt, opt.Dinput_nc, use_norm='spectral_norm')
            # Frozen VGG19 feature extractor for perceptual / style losses.
            self.vgg_model = loss.PCPFeat(weight_path='./pretrain_models/vgg19-dcbb9e9d.pth').to(opt.device)
            if len(opt.gpu_ids) > 0:
                self.vgg_model = torch.nn.DataParallel(self.vgg_model, opt.gpu_ids, output_device=opt.device)
        self.model_names = ['G']
        self.loss_names = ['Pix', 'PCP', 'G', 'FM', 'D', 'SS'] # Generator loss, fm loss, parsing loss, discriminator loss
        self.visual_names = ['img_LR', 'img_HR', 'img_SR', 'ref_Parse', 'hr_mask']
        # Per-discriminator-scale feature-matching weights (1**x == 1 for all x).
        self.fm_weights = [1**x for x in range(opt.D_num)]
        if self.isTrain:
            self.model_names = ['G', 'D']
            self.load_model_names = ['G', 'D']
            self.criterionParse = torch.nn.CrossEntropyLoss().to(opt.device)
            self.criterionFM = loss.FMLoss().to(opt.device)
            self.criterionGAN = loss.GANLoss(opt.gan_mode).to(opt.device)
            self.criterionPCP = loss.PCPLoss(opt)
            self.criterionPix= nn.L1Loss()
            self.criterionRS = loss.RegionStyleLoss()
            self.optimizer_G = optim.Adam([p for p in self.netG.parameters() if p.requires_grad], lr=opt.g_lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = optim.Adam([p for p in self.netD.parameters() if p.requires_grad], lr=opt.d_lr, betas=(opt.beta1, 0.999))
            self.optimizers = [self.optimizer_G, self.optimizer_D]
    def eval(self):
        """Switch generator and parsing network to eval mode."""
        self.netG.eval()
        self.netP.eval()
    def load_pretrain_models(self,):
        """Load frozen parsing-network and PSFR-GAN generator weights for inference."""
        self.netP.eval()
        print('Loading pretrained LQ face parsing network from', self.opt.parse_net_weight)
        if len(self.opt.gpu_ids) > 0:
            self.netP.module.load_state_dict(torch.load(self.opt.parse_net_weight))
        else:
            self.netP.load_state_dict(torch.load(self.opt.parse_net_weight))
        self.netG.eval()
        print('Loading pretrained PSFRGAN from', self.opt.psfr_net_weight)
        if len(self.opt.gpu_ids) > 0:
            self.netG.module.load_state_dict(torch.load(self.opt.psfr_net_weight), strict=False)
        else:
            self.netG.load_state_dict(torch.load(self.opt.psfr_net_weight), strict=False)
    def set_input(self, input, cur_iters=None):
        """Unpack one batch: LR input, HR target, HR parsing mask."""
        self.cur_iters = cur_iters
        self.img_LR = input['LR'].to(self.opt.device)
        self.img_HR = input['HR'].to(self.opt.device)
        self.hr_mask = input['Mask'].to(self.opt.device)
        if self.opt.debug:
            print('SRNet input shape:', self.img_LR.shape, self.img_HR.shape)
    def forward(self):
        # Parsing map predicted from the LR input; one-hot (arg-max) and
        # detached so no gradients flow into the frozen netP.
        with torch.no_grad():
            ref_mask, _ = self.netP(self.img_LR)
            self.ref_mask_onehot = (ref_mask == ref_mask.max(dim=1, keepdim=True)[0]).float().detach()
        if self.opt.debug:
            print('SRNet reference mask shape:', self.ref_mask_onehot.shape)
        self.img_SR = self.netG(self.img_LR, self.ref_mask_onehot)
        # Discriminator runs on (image, mask) pairs; SR is detached for the
        # D update (fake_D_results) but not for the G update (fake_G_results).
        self.real_D_results = self.netD(torch.cat((self.img_HR, self.hr_mask), dim=1), return_feat=True)
        self.fake_D_results = self.netD(torch.cat((self.img_SR.detach(), self.hr_mask), dim=1), return_feat=False)
        self.fake_G_results = self.netD(torch.cat((self.img_SR, self.hr_mask), dim=1), return_feat=True)
        self.img_SR_feats = self.vgg_model(self.img_SR)
        self.img_HR_feats = self.vgg_model(self.img_HR)
    def backward_G(self):
        # Pix Loss
        self.loss_Pix = self.criterionPix(self.img_SR, self.img_HR) * self.opt.lambda_pix
        # semantic style loss
        self.loss_SS = self.criterionRS(self.img_SR_feats, self.img_HR_feats, self.hr_mask) * self.opt.lambda_ss
        # perceptual loss
        self.loss_PCP = self.criterionPCP(self.img_SR_feats, self.img_HR_feats) * self.opt.lambda_pcp
        # Feature matching loss
        tmp_loss =  0
        for i, w in zip(range(self.opt.D_num), self.fm_weights):
            tmp_loss = tmp_loss + self.criterionFM(self.fake_G_results[i][1], self.real_D_results[i][1]) * w
        self.loss_FM = tmp_loss * self.opt.lambda_fm / self.opt.D_num
        # Generator loss
        tmp_loss = 0
        for i in range(self.opt.D_num):
            tmp_loss = tmp_loss + self.criterionGAN(self.fake_G_results[i][0], True, for_discriminator=False)
        self.loss_G = tmp_loss * self.opt.lambda_g / self.opt.D_num
        total_loss = self.loss_Pix + self.loss_PCP + self.loss_FM + self.loss_G + self.loss_SS
        total_loss.backward()
    def backward_D(self, ):
        # Average standard real/fake GAN loss over all discriminator scales.
        self.loss_D = 0
        for i in range(self.opt.D_num):
            self.loss_D += 0.5 * (self.criterionGAN(self.fake_D_results[i], False) + self.criterionGAN(self.real_D_results[i][0], True))
        self.loss_D /= self.opt.D_num
        self.loss_D.backward()
    def optimize_parameters(self, ):
        # ---- Update G ------------
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # ---- Update D ------------
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
    def get_current_visuals(self, size=512):
        """Return [LR, SR, HR, reference parse map, GT parse map] as images of *size*."""
        out = []
        visual_imgs = []
        out.append(utils.tensor_to_numpy(self.img_LR))
        out.append(utils.tensor_to_numpy(self.img_SR))
        out.append(utils.tensor_to_numpy(self.img_HR))
        out_imgs = [utils.batch_numpy_to_image(x, size) for x in out]
        visual_imgs += out_imgs
        visual_imgs.append(utils.color_parse_map(self.ref_mask_onehot, size))
        visual_imgs.append(utils.color_parse_map(self.hr_mask, size))
        return visual_imgs
@@ -0,0 +1,224 @@
import torch
from torchvision import models
from utils import utils
from torch import nn
def tv_loss(x):
    """
    Total Variation Loss.

    Sum of absolute differences between horizontally and vertically adjacent
    pixels of a (B, C, H, W) tensor.
    """
    horizontal = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]).sum()
    vertical = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]).sum()
    return horizontal + vertical
class PCPFeat(torch.nn.Module):
    """
    Feature extractor used for perceptual losses.
    Input: (B, C, H, W) RGB images normalized to [-1, 1]. The backbone is a
    VGG19 (default) or ResNet50 whose weights are loaded from `weight_path`
    and frozen.
    """
    def __init__(self, weight_path, model='vgg'):
        super(PCPFeat, self).__init__()
        if model == 'vgg':
            self.model = models.vgg19(pretrained=False)
            self.build_vgg_layers()
        elif model == 'resnet':
            self.model = models.resnet50(pretrained=False)
            self.build_resnet_layers()
        # Load the local checkpoint and freeze every parameter.
        self.model.load_state_dict(torch.load(weight_path))
        self.model.eval()
        for weight in self.model.parameters():
            weight.requires_grad = False
    def build_resnet_layers(self):
        """Group the ResNet backbone into four feature stages."""
        self.layer1 = torch.nn.Sequential(
                self.model.conv1,
                self.model.bn1,
                self.model.relu,
                self.model.maxpool,
                self.model.layer1
                )
        self.layer2 = self.model.layer2
        self.layer3 = self.model.layer3
        self.layer4 = self.model.layer4
        self.features = torch.nn.ModuleList(
            [self.layer1, self.layer2, self.layer3, self.layer4]
        )
    def build_vgg_layers(self):
        """Split VGG19 `.features` into five slices, one per feature stage."""
        backbone = self.model.features
        cut_points = [0, 3, 8, 17, 26, 35]
        slices = []
        for start, stop in zip(cut_points[:-1], cut_points[1:]):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), backbone[layer_idx])
            slices.append(stage)
        self.features = torch.nn.ModuleList(slices)
    def preprocess(self, x):
        """Map [-1, 1] input into ImageNet-normalized space; upsample tiny inputs to 224."""
        x = (x + 1) / 2
        mean = torch.Tensor([0.485, 0.456, 0.406]).to(x).view(1, 3, 1, 1)
        std = torch.Tensor([0.229, 0.224, 0.225]).to(x).view(1, 3, 1, 1)
        x = (x - mean) / std
        if x.shape[3] < 224:
            x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        return x
    def forward(self, x):
        """Return the list of intermediate feature maps produced for *x*."""
        x = self.preprocess(x)
        collected = []
        for stage in self.features:
            x = stage(x)
            collected.append(x)
        return collected
class PCPLoss(torch.nn.Module):
    """Perceptual Loss.

    Weighted MSE between corresponding feature maps of two images; deeper
    layers receive larger weights.
    """
    def __init__(self,
            opt,
            layer=5,
            model='vgg',
            ):
        super(PCPLoss, self).__init__()
        self.mse = torch.nn.MSELoss()
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
    def forward(self, x_feats, y_feats):
        # Targets are detached: gradients flow through x_feats only.
        return sum(self.mse(xf, yf.detach()) * w
                   for xf, yf, w in zip(x_feats, y_feats, self.weights))
class FMLoss(nn.Module):
    """Feature-matching loss: unweighted MSE summed over feature pairs.

    Target features are detached so gradients flow through x_feats only.
    """
    def __init__(self):
        super().__init__()
        self.mse = torch.nn.MSELoss()
    def forward(self, x_feats, y_feats):
        return sum(self.mse(xf, yf.detach()) for xf, yf in zip(x_feats, y_feats))
class GANLoss(nn.Module):
    """GAN objective wrapper that builds real/fake target tensors automatically.

    Supports 'lsgan' (MSE), 'vanilla' (BCE-with-logits), 'hinge' and 'wgangp'.
    Do not put a sigmoid on the discriminator output; the losses take logits.
    """
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.
        Parameters:
            gan_mode (str)           -- one of 'lsgan' | 'vanilla' | 'hinge' | 'wgangp'
            target_real_label (bool) -- label for a real image
            target_fake_label (bool) -- label of a fake image
        """
        super(GANLoss, self).__init__()
        # Buffers move with the module across devices but are not trained.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            pass
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
    def get_target_tensor(self, prediction, target_is_real):
        """Return the real/fake label broadcast to the prediction's shape."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(prediction)
    def __call__(self, prediction, target_is_real, for_discriminator=True):
        """Compute the GAN loss for a discriminator output.
        Parameters:
            prediction (tensor)    -- typically the raw output of a discriminator
            target_is_real (bool)  -- whether the ground-truth label is real
            for_discriminator (bool) -- selects D vs G formulation (hinge only)
        Returns:
            the calculated loss.
        """
        mode = self.gan_mode
        if mode in ['lsgan', 'vanilla']:
            return self.loss(prediction, self.get_target_tensor(prediction, target_is_real))
        if mode == 'hinge':
            if not for_discriminator:
                assert target_is_real, "The generator's hinge loss must be aiming for real"
                return -prediction.mean()
            if target_is_real:
                return nn.ReLU()(1 - prediction).mean()
            return nn.ReLU()(1 + prediction).mean()
        if mode == 'wgangp':
            return -prediction.mean() if target_is_real else prediction.mean()
class RegionStyleLoss(nn.Module):
    """Region-wise Gram-matrix style loss guided by a parsing mask.

    For each of the `reg_num` mask channels, a masked Gram matrix is
    computed per feature level; the loss is the MSE between the x and y
    Gram stacks, summed over the deeper feature levels (index 2 onward).
    """

    def __init__(self, reg_num=19, eps=1e-8):
        super().__init__()
        self.reg_num = reg_num
        self.eps = eps  # guards division when a region is empty
        self.mse = nn.MSELoss()

    def _masked_gram(self, feat, region_mask):
        """Gram matrix of `feat` restricted to one region mask."""
        b, c, h, w = feat.shape
        region_mask = region_mask.view(b, -1, h * w)
        flat = feat.view(b, -1, h * w) * region_mask
        denom = region_mask.sum(2) + self.eps
        gram = torch.bmm(flat, flat.transpose(1, 2))
        # Normalize by channel count and region area.
        return gram / (c * denom.view(b, 1, 1))

    def _all_region_grams(self, feat, mask):
        """Stack the per-region Gram matrices along dim 1."""
        grams = [
            self._masked_gram(feat, mask[:, r].unsqueeze(1))
            for r in range(self.reg_num)
        ]
        return torch.stack(grams, dim=1)

    def forward(self, x_feats, y_feats, mask):
        total = 0
        # Only the deeper levels carry style information here.
        for xf, yf in zip(x_feats[2:], y_feats[2:]):
            level_mask = torch.nn.functional.interpolate(mask, xf.shape[2:])
            xg = self._all_region_grams(xf, level_mask)
            yg = self._all_region_grams(yf, level_mask)
            total = total + self.mse(xg, yg.detach())
        return total
if __name__ == '__main__':
    # Smoke test: five pyramid feature maps (as a VGG extractor would
    # produce for a 512x512 input) and a random 19-channel mask.
    x = [
        torch.randn(2, 64, 512, 512),
        torch.randn(2, 128, 256, 256),
        torch.randn(2, 256, 128, 128),
        torch.randn(2, 512, 64, 64),
        torch.randn(2, 512, 32, 32),
    ]
    # Mask values are arbitrary floats here; real masks are one-hot parses.
    y = torch.randint(10, (2, 19, 512, 512)).float()
    loss = RegionStyleLoss()
    # Identical inputs -> the printed loss should be exactly zero.
    print(loss(x, x, y))
@@ -0,0 +1,254 @@
from models.blocks import *
import torch
from torch import nn
from torch.nn import init
from torch.optim import lr_scheduler
from utils import utils
import numpy as np
from models import psfrnet
import torch.nn.utils as tutils
from models.loss import PCPFeat
def apply_norm(net, weight_norm_type):
    """Wrap every Conv2d in `net` with the requested weight normalization.

    Supported values (case-insensitive): 'spectral_norm', 'weight_norm';
    anything else leaves the network untouched.
    """
    requested = weight_norm_type.lower()
    for module in net.modules():
        if not isinstance(module, nn.Conv2d):
            continue
        if requested == 'spectral_norm':
            tutils.spectral_norm(module)
        elif requested == 'weight_norm':
            tutils.weight_norm(module)
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.
    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        # Matches Conv*/Linear layers by class-name substring.
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)
    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
    Return an initialized network.
    """
    # NOTE(review): mutable default `gpu_ids=[]` is shared across calls; it is
    # only read here so it is harmless, but callers should pass their own list.
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler.

    Parameters:
        optimizer         -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of the learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError -- if opt.lr_policy is not a supported policy.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # Factor 1.0 for the first n_epochs, then linear decay to 0.
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        # BUG FIX: the original *returned* the exception object instead of
        # raising it, so callers received a NotImplementedError instance
        # where they expected a scheduler.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
def define_P(opt, in_size=512, out_size=512, min_feat_size=32, relu_type='LeakyReLU', isTrain=True, weight_path=None):
    """Build the face parsing network (ParseNet).

    Parameters:
        opt                 -- options object; reads opt.Pnorm, opt.gpu_ids, opt.device.
        in_size/out_size    -- input/output resolution of ParseNet.
        min_feat_size (int) -- smallest feature-map size inside ParseNet.
        relu_type (str)     -- activation type.
        isTrain (bool)      -- when False, the network is switched to eval mode.
        weight_path (str)   -- optional checkpoint to load before wrapping.

    Returns:
        ParseNet, wrapped in DataParallel when GPUs are configured.
    """
    net = ParseNet(in_size, out_size, min_feat_size, 64, 19, norm_type=opt.Pnorm, relu_type=relu_type, ch_range=[32, 256])
    if not isTrain:
        net.eval()
    if weight_path is not None:
        net.load_state_dict(torch.load(weight_path))
    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(opt.device)
        net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.device)
    return net
def define_G(opt, isTrain=True, use_norm='none', relu_type='LeakyReLU'):
    """Build the PSFR generator.

    Parameters:
        opt             -- options object; reads opt.Gin_size, opt.Gout_size,
                           opt.Gnorm, opt.gpu_ids, opt.device.
        isTrain (bool)  -- when False, the network is switched to eval mode.
        use_norm (str)  -- weight normalization applied via apply_norm.
        relu_type (str) -- activation type.

    Returns:
        PSFRGenerator, wrapped in DataParallel when GPUs are configured.
    """
    net = psfrnet.PSFRGenerator(3, 3, in_size=opt.Gin_size, out_size=opt.Gout_size, relu_type=relu_type, parse_ch=19, norm_type=opt.Gnorm)
    apply_norm(net, use_norm)
    if not isTrain:
        net.eval()
    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(opt.device)
        net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.device)
    # init_weights(net, init_type='normal', init_gain=0.02)
    return net
def define_D(opt, in_channel=3, isTrain=True, use_norm='none'):
    """Build the multi-scale discriminator.

    Parameters:
        opt              -- options object; reads opt.ndf, opt.n_layers_D,
                            opt.Dnorm, opt.D_num, opt.gpu_ids, opt.device.
        in_channel (int) -- number of input channels.
        isTrain (bool)   -- when False, the network is switched to eval mode.
        use_norm (str)   -- weight normalization applied via apply_norm.

    Returns:
        MultiScaleDiscriminator with 'normal' weight init, wrapped in
        DataParallel when GPUs are configured.
    """
    net = MultiScaleDiscriminator(in_channel, opt.ndf, opt.n_layers_D, opt.Dnorm, num_D=opt.D_num)
    apply_norm(net, use_norm)
    if not isTrain:
        net.eval()
    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(opt.device)
        net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.device)
    init_weights(net, init_type='normal', init_gain=0.02)
    return net
class ParseNet(nn.Module):
    """Encoder-body-decoder network that predicts a face parsing map and a
    coarse restored image from a face image.

    The encoder downsamples to `min_feat_size`, a residual body of
    `res_depth` blocks processes the bottleneck, and the decoder upsamples
    to `out_size`. Two heads produce the RGB image and the parsing logits.
    """
    def __init__(self,
            in_size=128,
            out_size=128,
            min_feat_size=32,
            base_ch=64,
            parsing_ch=19,
            res_depth=10,
            relu_type='prelu',
            norm_type='bn',
            ch_range=[32, 512],
            ):
        super().__init__()
        self.res_depth = res_depth
        act_args = {'norm_type': norm_type, 'relu_type': relu_type}
        min_ch, max_ch = ch_range
        # Clamp channel counts to [min_ch, max_ch].
        ch_clip = lambda x: max(min_ch, min(x, max_ch))
        min_feat_size = min(in_size, min_feat_size)
        # Number of 2x down/up steps; assumes sizes are powers of two apart.
        down_steps = int(np.log2(in_size//min_feat_size))
        up_steps = int(np.log2(out_size//min_feat_size))
        # =============== define encoder-body-decoder ====================
        self.encoder = []
        self.encoder.append(ConvLayer(3, base_ch, 3, 1))
        head_ch = base_ch
        for i in range(down_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
            self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
            head_ch = head_ch * 2
        self.body = []
        for i in range(res_depth):
            self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
        self.decoder = []
        for i in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
            head_ch = head_ch // 2
        self.encoder = nn.Sequential(*self.encoder)
        self.body = nn.Sequential(*self.body)
        self.decoder = nn.Sequential(*self.decoder)
        # Output heads: coarse RGB image and parsing logits.
        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)

    def forward(self, x):
        """Return (parsing logits, coarse image) for input batch `x`."""
        feat = self.encoder(x)
        # Long skip connection around the residual body.
        x = feat + self.body(feat)
        x = self.decoder(x)
        out_img = self.out_img_conv(x)
        out_mask = self.out_mask_conv(x)
        return out_mask, out_img
class MultiScaleDiscriminator(nn.Module):
    """Pyramid of NLayerDiscriminators, each applied to a progressively
    2x-downsampled copy of the input."""
    def __init__(self, input_ch, base_ch=64, n_layers=3, norm_type='none', relu_type='LeakyReLU', num_D=4):
        super().__init__()
        self.D_pool = nn.ModuleList()
        for i in range(num_D):
            netD = NLayerDiscriminator(input_ch, base_ch, depth=n_layers, norm_type=norm_type, relu_type=relu_type)
            self.D_pool.append(netD)
        # Smoothed 2x downsampling between scales.
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, input, return_feat=False):
        """Return one output per scale; each element is a score tensor, or
        (score, features) when return_feat is True."""
        results = []
        for netd in self.D_pool:
            output = netd(input, return_feat)
            results.append(output)
            # Downsample input
            input = self.downsample(input)
        return results
class NLayerDiscriminator(nn.Module):
    """PatchGAN-style discriminator: a stack of down-sampling conv layers
    followed by a 1-channel score head."""
    def __init__(self,
            input_ch = 3,
            base_ch = 64,
            max_ch = 1024,
            depth = 4,
            norm_type = 'none',
            relu_type = 'LeakyReLU',
            ):
        super().__init__()
        nargs = {'norm_type': norm_type, 'relu_type': relu_type}
        self.norm_type = norm_type
        self.input_ch = input_ch
        self.model = []
        self.model.append(ConvLayer(input_ch, base_ch, norm_type='none', relu_type=relu_type))
        # Channels double each step, capped at max_ch.
        for i in range(depth):
            cin = min(base_ch * 2**(i), max_ch)
            cout = min(base_ch * 2**(i+1), max_ch)
            self.model.append(ConvLayer(cin, cout, scale='down_avg', **nargs))
        self.model = nn.Sequential(*self.model)
        # NOTE(review): `cout` leaks from the loop above — depth must be >= 1
        # or this line raises NameError; confirm callers never pass depth=0.
        self.score_out = ConvLayer(cout, 1, use_pad=False)

    def forward(self, x, return_feat=False):
        """Return the patch score map; with return_feat, also the list of
        intermediate activations (for feature-matching losses)."""
        ret_feats = []
        for idx, m in enumerate(self.model):
            x = m(x)
            ret_feats.append(x)
        x = self.score_out(x)
        if return_feat:
            return x, ret_feats
        else:
            return x
@@ -0,0 +1,80 @@
import torch
from .base_model import BaseModel
from . import networks
from utils import utils
class ParseModel(BaseModel):
    """Trains ParseNet to jointly predict a parsing map and a coarse SR
    image from a low-resolution face."""

    def modify_commandline_options(parser, is_train):
        """Add model-specific CLI options (loss weights) during training."""
        if is_train:
            parser.add_argument('--parse_map', type=float, default=1.0, help='weight for parsing map')
            parser.add_argument('--parse_sr', type=float, default=1.0, help='weight for sr')
        return parser

    def __init__(self, opt):
        """Initialize this model class.
        Parameters:
            opt -- training/test options
        A few things can be done here.
        - (required) call the initialization function of BaseModel
        - define loss function, visualization images, model names, and optimizers
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
        self.loss_names = ['P', 'SR']
        self.visual_names = ['img_LR', 'img_HR', 'gt_Parse', 'img_SR', 'pred_Parse']
        self.model_names = ['P']
        self.netP = networks.define_P(opt)
        if self.isTrain:  # only defined during training time
            self.criterionParse = torch.nn.CrossEntropyLoss()
            self.criterionSR = torch.nn.L1Loss()
            self.optimizer = torch.optim.Adam(self.netP.parameters(), lr=opt.lr, betas=(0.9, 0.999))
            self.optimizers = [self.optimizer]

    def set_input(self, input, cur_iters=None):
        """Unpack a data-loader batch onto the configured device."""
        self.img_LR = input['LR'].to(self.opt.device)
        self.img_HR = input['HR'].to(self.opt.device)
        self.gt_Parse = input['Mask'].to(self.opt.device)
        if self.opt.debug:
            print('ParseNet input shape:', self.img_LR.shape, self.img_HR.shape, self.gt_Parse.shape)

    def load_pretrain_models(self,):
        """Load pretrained parsing weights and switch to eval mode."""
        self.netP.eval()
        print('Loading pretrained LQ face parsing network from', self.opt.parse_net_weight)
        self.netP.load_state_dict(torch.load(self.opt.parse_net_weight))

    def forward(self):
        # ParseNet returns (parsing logits, coarse SR image).
        self.pred_Parse, self.img_SR = self.netP(self.img_LR)
        if self.opt.debug:
            print('ParseNet output shape', self.pred_Parse.shape, self.img_SR.shape)

    def backward(self):
        """Compute weighted parsing + SR losses and backpropagate."""
        self.loss_P = self.criterionParse(self.pred_Parse, self.gt_Parse) * self.opt.parse_map
        self.loss_SR = self.criterionSR(self.img_SR, self.img_HR) * self.opt.parse_sr
        loss = self.loss_P + self.loss_SR
        loss.backward()

    def optimize_parameters(self):
        """One optimization step: forward is assumed done by the caller."""
        self.optimizer.zero_grad()  # clear network G's existing gradients
        self.backward()             # calculate gradients for network G
        self.optimizer.step()

    def get_current_visuals(self, size=512):
        """Return a list of visualization images: LR, SR, predicted parse,
        ground-truth parse, and HR."""
        out = []
        visual_imgs = []
        out.append(utils.tensor_to_numpy(self.img_LR))
        out.append(utils.tensor_to_numpy(self.img_SR))
        out.append(utils.tensor_to_numpy(self.img_HR))
        out_imgs = [utils.batch_numpy_to_image(x, size) for x in out]
        visual_imgs.append(out_imgs[0])
        visual_imgs.append(out_imgs[1])
        visual_imgs.append(utils.color_parse_map(self.pred_Parse))
        visual_imgs.append(utils.color_parse_map(self.gt_Parse.unsqueeze(1)))
        visual_imgs.append(out_imgs[2])
        return visual_imgs
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
from torch.nn import init
import numpy as np
from models.blocks import *
class SPADENorm(nn.Module):
    """Spatially-adaptive normalization.

    Instance-normalizes the input; in 'spade' mode it then modulates the
    result with per-pixel gamma/beta predicted from a reference map
    (e.g. a parsing map), resized to the input's spatial size if needed.
    In 'in' mode only the parameter-free instance norm is applied.
    """

    def __init__(self, norm_nc, ref_nc, norm_type='spade', ksz=3):
        super().__init__()
        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        self.norm_type = norm_type
        hidden = 64  # width of the shared modulation branch
        if norm_type == 'spade':
            self.conv1 = nn.Sequential(
                nn.Conv2d(ref_nc, hidden, ksz, 1, ksz // 2),
                nn.LeakyReLU(0.2, True),
            )
            self.gamma_conv = nn.Conv2d(hidden, norm_nc, ksz, 1, ksz // 2)
            self.beta_conv = nn.Conv2d(hidden, norm_nc, ksz, 1, ksz // 2)

    def get_gamma_beta(self, x, conv, gamma_conv, beta_conv):
        """Predict (gamma, beta) modulation maps from reference `x`."""
        shared = conv(x)
        return gamma_conv(shared), beta_conv(shared)

    def forward(self, x, ref):
        normalized_input = self.param_free_norm(x)
        # Resize the reference map to match x spatially when necessary.
        if x.shape[-1] != ref.shape[-1]:
            ref = nn.functional.interpolate(ref, x.shape[2:], mode='bicubic', align_corners=False)
        if self.norm_type == 'spade':
            gamma, beta = self.get_gamma_beta(ref, self.conv1, self.gamma_conv, self.beta_conv)
            return normalized_input * gamma + beta
        elif self.norm_type == 'in':
            return normalized_input
class SPADEResBlock(nn.Module):
    """Residual block whose normalizations are SPADE layers conditioned on
    the reference map `ref`."""
    def __init__(self, fin, fout, ref_nc, relu_type, norm_type='spade'):
        super().__init__()
        fmiddle = min(fin, fout)
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        # define normalization layers
        self.norm_0 = SPADENorm(fmiddle, ref_nc, norm_type)
        self.norm_1 = SPADENorm(fmiddle, ref_nc, norm_type)
        self.relu = ReluLayer(fmiddle, relu_type)

    def forward(self, x, ref):
        # NOTE(review): norm_0 is built for fmiddle channels but applied to x
        # (fin channels), and the skip `x + res` requires fin == fout. Every
        # call site in this file passes fin == fout, so both hold here —
        # confirm before reusing with asymmetric channel counts.
        res = self.conv_0(self.relu(self.norm_0(x, ref)))
        res = self.conv_1(self.relu(self.norm_1(res, ref)))
        out = x + res
        return out
class PSFRGenerator(nn.Module):
    """Progressive SPADE-conditioned face restoration generator.

    Starts from a learned constant feature map and upsamples it 2x per
    step, modulating every stage with SPADE blocks conditioned on the
    concatenation of the degraded image and its parsing map.
    """
    def __init__(self, input_nc, output_nc, in_size=512, out_size=512, min_feat_size=16, ngf=64, n_blocks=9, parse_ch=19, relu_type='relu',
            ch_range=[32, 1024], norm_type='spade'):
        super().__init__()
        min_ch, max_ch = ch_range
        # Clamp channel counts to [min_ch, max_ch]; channels shrink as
        # spatial size grows (1024*16 // size).
        ch_clip = lambda x: max(min_ch, min(x, max_ch))
        get_ch = lambda size: ch_clip(1024*16//size)
        # Learned constant input, broadcast over the batch in forward().
        self.const_input = nn.Parameter(torch.randn(1, get_ch(min_feat_size), min_feat_size, min_feat_size))
        up_steps = int(np.log2(out_size//min_feat_size))
        self.up_steps = up_steps
        # Reference map = parsing channels + RGB image channels.
        ref_ch = 19+3
        head_ch = get_ch(min_feat_size)
        head = [
            nn.Conv2d(head_ch, head_ch, kernel_size=3, padding=1),
            SPADEResBlock(head_ch, head_ch, ref_ch, relu_type, norm_type),
        ]
        body = []
        for i in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            body += [
                nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(cin, cout, kernel_size=3, padding=1),
                    SPADEResBlock(cout, cout, ref_ch, relu_type, norm_type)
                )
            ]
            head_ch = head_ch // 2
        self.img_out = nn.Conv2d(ch_clip(head_ch), output_nc, kernel_size=3, padding=1)
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)
        self.upsample = nn.Upsample(scale_factor=2)

    def forward_spade(self, net, x, ref):
        """Run `x` through `net`, passing `ref` to SPADE-aware modules."""
        for m in net:
            x = self.forward_spade_m(m, x, ref)
        return x

    def forward_spade_m(self, m, x, ref):
        """Dispatch: SPADE modules take (x, ref); plain modules take x."""
        if isinstance(m, SPADENorm) or isinstance(m, SPADEResBlock):
            x = m(x, ref)
        else:
            x = m(x)
        return x

    def forward(self, x, ref):
        """Generate a restored image conditioned on degraded image `x` and
        parsing map `ref` (concatenated into the SPADE reference)."""
        b, c, h, w = x.shape
        const_input = self.const_input.repeat(b, 1, 1, 1)
        ref_input = torch.cat((x, ref), dim=1)
        feat = self.forward_spade(self.head, const_input, ref_input)
        for idx, m in enumerate(self.body):
            feat = self.forward_spade(m, feat, ref_input)
        out_img = self.img_out(feat)
        return out_img
if __name__ == '__main__':
    # NOTE(review): `nearest_interpolate` is not defined in this module or
    # its imports — running this smoke test raises NameError. Confirm the
    # intended helper (possibly removed or renamed) before relying on it.
    x = torch.randn(2, 16, 567, 234)
    nearest_interpolate(x)
@@ -0,0 +1 @@
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
@@ -0,0 +1,165 @@
import argparse
import os
import numpy as np
import random
from utils import utils
import torch
import models
import data
from utils import utils
class BaseOptions():
    """This class defines options used during both training and test time.
    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initailized"""
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataroot', required=False, help='path to images')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpus', type=int, default=1, help='how many gpus to use')
        parser.add_argument('--seed', type=int, default=123, help='Random seed for training')
        parser.add_argument('--checkpoints_dir', type=str, default='./check_points', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='enhance', help='chooses which model to train [parse|enhance]')
        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--Dinput_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--n_layers_D', type=int, default=4, help='downsampling layers in discriminator')
        parser.add_argument('--D_num', type=int, default=3, help='numbers of discriminators')
        parser.add_argument('--Pnorm', type=str, default='bn', help='parsing net norm [in | bn| none]')
        parser.add_argument('--Gnorm', type=str, default='spade', help='generator norm [in | bn | none]')
        parser.add_argument('--Dnorm', type=str, default='in', help='discriminator norm [in | bn | none]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        # dataset parameters
        parser.add_argument('--dataset_name', type=str, default='single', help='dataset name')
        parser.add_argument('--Pimg_size', type=int, default='512', help='image size for face parse net')
        parser.add_argument('--Gin_size', type=int, default='512', help='image size for face parse net')
        parser.add_argument('--Gout_size', type=int, default='512', help='image size for face parse net')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=8, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
        parser.add_argument('--load_size', type=int, default=512, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        parser.add_argument('--debug', action='store_true', help='if specified, set to debug mode')
        self.initialized = True
        return parser

    def gather_options(self):
        """Initialize our parser with basic options(only once).
        Add additional model-specific and dataset-specific options.
        These options are defined in the <modify_commandline_options> function
        in model and dataset classes.
        """
        # NOTE(review): if this is called when self.initialized is already
        # True, `parser` below is unbound (NameError); in practice parse()
        # is called exactly once per run — confirm before reusing.
        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)
        # get the basic options
        opt, _ = parser.parse_known_args()
        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults
        # modify dataset-related parser options
        dataset_name = opt.dataset_name
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.isTrain)
        # save and return the parser
        self.parser = parser
        return parser.parse_args()

    def print_options(self, opt):
        """Print and save options
        It will print both current options and default values(if different).
        It will save options into a text file / [checkpoints_dir] / opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)
        # save to the disk
        opt.expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        utils.mkdirs(opt.expr_dir)
        file_name = os.path.join(opt.expr_dir, '{}_opt.txt'.format(opt.phase))
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')
        # Log directories are created eagerly so later writers need no checks.
        opt.log_dir = os.path.join(opt.checkpoints_dir, 'log_dir')
        utils.mkdirs(opt.log_dir)
        opt.log_archive = os.path.join(opt.checkpoints_dir, 'log_archive')
        utils.mkdirs(opt.log_archive)

    def parse(self):
        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test
        # Debug mode shrinks every frequency to 1 for fast iteration.
        if opt.debug:
            opt.name = 'debug'
            opt.save_iter_freq = 1
            opt.save_latest_freq = 1
            opt.visual_freq = 1
            opt.print_freq = 1
        # Find avaliable GPUs automatically
        if opt.gpus > 0:
            opt.gpu_ids = utils.get_gpu_memory_map()[1][:opt.gpus]
            if not isinstance(opt.gpu_ids, list):
                opt.gpu_ids = [opt.gpu_ids]
            torch.cuda.set_device(opt.gpu_ids[0])
            opt.device = torch.device('cuda:{}'.format(opt.gpu_ids[0 % opt.gpus]))
            opt.data_device = torch.device('cuda:{}'.format(opt.gpu_ids[1 % opt.gpus]))
        else:
            opt.gpu_ids = []
            opt.device = torch.device('cpu')
        # set random seeds to ensure reproducibility
        np.random.seed(opt.seed)
        random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed_all(opt.seed)
        # process opt.suffix
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix
        self.print_options(opt)
        self.opt = opt
        return self.opt
@@ -0,0 +1,30 @@
from .base_options import BaseOptions
class TestOptions(BaseOptions):
"""This class includes test options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser) # define shared options
parser.add_argument('--src_dir', type=str, default='G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan/n000002', help='source directory containing test images')
parser.add_argument('--save_masks_dir', type=str, default='../datasets/FFHQ/masks512', help='path to save parsing masks for FFHQ')
parser.add_argument('--test_img_path', type=str, default='', help='path for single image test')
parser.add_argument('--test_upscale', type=float, default=1, help='upsample scale for single image test')
parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
# Dropout and Batchnorm has different behavioir during training and test.
parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
parser.add_argument('--parse_net_weight', type=str, default='./pretrain_models/parse_multi_iter_90000.pth', help='parse model path')
parser.add_argument('--psfr_net_weight', type=str, default='./pretrain_models/psfrgan_epoch15_net_G.pth', help='parse model path')
# rewrite devalue values
# To avoid cropping, the load_size should be the same as crop_size
parser.set_defaults(load_size=parser.get_default('crop_size'))
self.isTrain = False
return parser
@@ -0,0 +1,41 @@
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
"""This class includes training options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser)
# visdom and HTML visualization parameters
parser.add_argument('--visual_freq', type=int, default=400, help='frequency of show training images in tensorboard')
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
# network saving and loading parameters
parser.add_argument('--save_iter_freq', type=int, default=5000, help='frequency of saving the models')
parser.add_argument('--save_latest_freq', type=int, default=500, help='save latest freq')
parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
parser.add_argument('--no_strict_load', action='store_true', help='set strict load to false')
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
# training parameters
parser.add_argument('--resume_epoch', type=int, default=0, help='training resume epoch')
parser.add_argument('--resume_iter', type=int, default=0, help='training resume iter')
parser.add_argument('--total_epochs', type=int, default=50, help='# of epochs to train')
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--g_lr', type=float, default=0.0001, help='generator learning rate')
parser.add_argument('--d_lr', type=float, default=0.0004, help='discriminator learning rate')
parser.add_argument('--gan_mode', type=str, default='hinge', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
parser.add_argument('--lr_decay_gamma', type=float, default=1, help='multiply by a gamma every lr_decay_iters iterations')
self.isTrain = True
return parser
@@ -0,0 +1,11 @@
torch==1.5.1
torchvision==0.6.1
tensorflow>=1.15.4
tensorboard==1.15.0
tensorboardX==2.1
opencv-python
dlib
scikit-image==0.17.2
scipy==1.4.1
tqdm
imgaug
@@ -0,0 +1,62 @@
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from utils import utils
from PIL import Image
from tqdm import tqdm
import torch
import time
import numpy as np
if __name__ == '__main__':
    # Batch test driver: parse each LR image, run the generator, and save
    # LQ / HQ / parse-map triplets under the results directory.
    opt = TestOptions().parse()  # get test options
    opt.num_threads = 0   # test code only supports num_threads = 0
    opt.batch_size = 4
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.load_pretrain_models()
    save_dir = opt.results_dir
    os.makedirs(save_dir, exist_ok=True)
    print('creating result directory', save_dir)
    netP = model.netP
    netG = model.netG
    model.eval()
    max_size = 9999  # safety cap on the number of processed batches
    os.makedirs(os.path.join(save_dir, 'sr'), exist_ok=True)
    for batch_idx, data in tqdm(enumerate(dataset), total=len(dataset) // opt.batch_size):
        inp = data['LR']
        with torch.no_grad():
            parse_map, _ = netP(inp)
            # One-hot parse map: keep only the arg-max channel per pixel.
            parse_map_sm = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float()
            output_SR = netG(inp, parse_map_sm)
        img_path = data['LR_paths']  # get image paths
        # Hoisted out of the per-image loop: these conversions cover the
        # whole batch and do not depend on the image index.
        inp_img = utils.batch_tensor_to_img(inp)
        output_sr_img = utils.batch_tensor_to_img(output_SR)
        ref_parse_img = utils.color_parse_map(parse_map_sm)
        # BUG FIX: the inner loop previously reused `i`, shadowing the outer
        # batch index, so the `max_size` check below compared against the
        # wrong counter.
        for j in tqdm(range(len(img_path))):
            save_path = os.path.join(save_dir, 'lq', os.path.basename(img_path[j]))
            os.makedirs(os.path.join(save_dir, 'lq'), exist_ok=True)
            Image.fromarray(inp_img[j]).save(save_path)
            save_path = os.path.join(save_dir, 'hq', os.path.basename(img_path[j]))
            os.makedirs(os.path.join(save_dir, 'hq'), exist_ok=True)
            Image.fromarray(output_sr_img[j]).save(save_path)
            save_path = os.path.join(save_dir, 'parse', os.path.basename(img_path[j]))
            os.makedirs(os.path.join(save_dir, 'parse'), exist_ok=True)
            Image.fromarray(ref_parse_img[j]).save(save_path)
        if batch_idx > max_size: break
@@ -0,0 +1,57 @@
'''
This script enhance images with unaligned faces in a folder and paste it back to the original place.
'''
import dlib
import os
import cv2
import numpy as np
from tqdm import tqdm
from skimage import transform as trans
from skimage import io
import torch
from utils import utils
from options.test_options import TestOptions
from models import create_model
from test_enhance_single_unalign import *
if __name__ == '__main__':
    opt = TestOptions().parse()
    # face_detector = dlib.get_frontal_face_detector()
    face_detector = dlib.cnn_face_detection_model_v1('./pretrain_models/mmod_human_face_detector.dat')
    lmk_predictor = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat')
    template_path = './pretrain_models/FFHQ_template.npy'
    enhance_model = def_models(opt)
    # Process every image found in the source directory.
    for img_name in os.listdir(opt.src_dir):
        img_path = os.path.join(opt.src_dir, img_name)
        stem = os.path.splitext(img_name)[0]
        out_dir = os.path.join(opt.results_dir, stem)
        os.makedirs(out_dir, exist_ok=True)
        print('======> Loading image', img_path)
        img = dlib.load_rgb_image(img_path)
        aligned_faces, tform_params = detect_and_align_faces(img, face_detector, lmk_predictor, template_path)
        # Keep the aligned low-quality crops for inspection.
        lq_dir = os.path.join(out_dir, 'LQ_faces')
        os.makedirs(lq_dir, exist_ok=True)
        print('======> Saving aligned LQ faces to', lq_dir)
        save_imgs(aligned_faces, lq_dir)
        hq_faces, lq_parse_maps = enhance_faces(aligned_faces, enhance_model)
        # Persist both the parsing maps and the enhanced crops.
        parse_dir = os.path.join(out_dir, 'ParseMaps')
        hq_dir = os.path.join(out_dir, 'HQ')
        os.makedirs(parse_dir, exist_ok=True)
        os.makedirs(hq_dir, exist_ok=True)
        print('======> Save parsing map and the enhanced faces.')
        save_imgs(lq_parse_maps, parse_dir)
        save_imgs(hq_faces, hq_dir)
        print('======> Paste the enhanced faces back to the original image.')
        hq_img = past_faces_back(img, hq_faces, tform_params, upscale=opt.test_upscale)
        final_save_path = os.path.join(out_dir, 'hq_final.jpg')
        print('======> Save final result to', final_save_path)
        io.imsave(final_save_path, hq_img)
@@ -0,0 +1,126 @@
'''
This script enhance all faces in one image with PSFR-GAN and paste it back to the original place.
'''
import dlib
import os
import cv2
import numpy as np
from tqdm import tqdm
from skimage import transform as trans
from skimage import io
import torch
from utils import utils
from options.test_options import TestOptions
from models import create_model
def detect_and_align_faces(img, face_detector, lmk_predictor, template_path, template_scale=2, size_threshold=999):
    """Detect faces in `img` and warp each one to a 512x512 template-aligned crop.

    Args:
        img: RGB image array.
        face_detector: dlib frontal or CNN face detector.
        lmk_predictor: dlib 5-point landmark predictor.
        template_path: .npy file holding the 5 reference landmark positions.
        template_scale: divisor applied to the template coordinates.
        size_threshold: faces wider or taller than this are rejected.

    Returns:
        [aligned_faces, tform_params]: float crops in [0, 255] and the
        corresponding skimage SimilarityTransform for each face.

    Raises:
        AssertionError: if no face is detected.
    """
    align_out_size = (512, 512)
    ref_points = np.load(template_path) / template_scale
    # Detect landmark points
    face_dets = face_detector(img, 1)
    assert len(face_dets) > 0, 'No faces detected'
    aligned_faces = []
    tform_params = []
    for det in face_dets:
        if isinstance(face_detector, dlib.cnn_face_detection_model_v1):
            rec = det.rect # for cnn detector
        else:
            rec = det
        if rec.width() > size_threshold or rec.height() > size_threshold:
            # NOTE(review): `break` abandons ALL remaining detections once one
            # oversized face is seen; `continue` would skip only that face —
            # confirm which is intended.
            print('Face is too large')
            break
        landmark_points = lmk_predictor(img, rec)
        single_points = []
        for i in range(5):
            single_points.append([landmark_points.part(i).x, landmark_points.part(i).y])
        single_points = np.array(single_points)
        # Similarity transform mapping the detected landmarks onto the template.
        tform = trans.SimilarityTransform()
        tform.estimate(single_points, ref_points)
        # skimage warp() returns floats in [0, 1]; rescale to [0, 255].
        tmp_face = trans.warp(img, tform.inverse, output_shape=align_out_size, order=3)
        aligned_faces.append(tmp_face*255)
        tform_params.append(tform)
    return [aligned_faces, tform_params]
def def_models(opt):
    """Build the enhancement model, load its pretrained weights, and move the
    parsing and generator sub-networks onto opt.device."""
    enhance_model = create_model(opt)
    enhance_model.load_pretrain_models()
    for net in (enhance_model.netP, enhance_model.netG):
        net.to(opt.device)
    return enhance_model
def enhance_faces(LQ_faces, model):
    """Run the parsing network and generator on each aligned LQ face.

    Returns (hq_faces, lq_parse_maps): enhanced RGB images plus the
    colorized one-hot parsing maps predicted from the LQ inputs.
    """
    hq_faces = []
    lq_parse_maps = []
    for face in tqdm(LQ_faces):
        with torch.no_grad():
            # HWC array in [0, 255] -> 1xCxHxW float tensor in [-1, 1]
            face_tensor = torch.tensor(face.transpose(2, 0, 1)) / 255. * 2 - 1
            face_tensor = face_tensor.unsqueeze(0).float().to(model.device)
            parse_map, _ = model.netP(face_tensor)
            onehot = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float()
            sr_output = model.netG(face_tensor, onehot)
        hq_faces.append(utils.tensor_to_img(sr_output))
        lq_parse_maps.append(utils.color_parse_map(onehot)[0])
    return hq_faces, lq_parse_maps
def past_faces_back(img, hq_faces, tform_params, upscale=1):
    """Paste enhanced face crops back into the (optionally upscaled) image.

    Args:
        img: original full image (H, W, C).
        hq_faces: enhanced face crops in [0, 255].
        tform_params: SimilarityTransform per face, from detect_and_align_faces.
        upscale: factor applied to the full image before pasting.

    Returns:
        uint8 composited image of size (H*upscale, W*upscale).
    """
    h, w = img.shape[:2]
    img = cv2.resize(img, (int(w*upscale), int(h*upscale)), interpolation=cv2.INTER_CUBIC)
    for hq_img, tform in tqdm(zip(hq_faces, tform_params), total=len(hq_faces)):
        # NOTE(review): this rescales the transform IN PLACE, mutating the
        # caller's tform objects — calling this twice with the same list would
        # compound the division. Confirm single-use is intended.
        tform.params[0:2,0:2] /= upscale
        back_img = trans.warp(hq_img/255., tform, output_shape=[int(h*upscale), int(w*upscale)], order=3) * 255
        # blur mask to avoid border artifacts
        # NOTE(review): blurring then re-thresholding (> 0) dilates the
        # outside-of-face mask rather than feathering it, so the blend below
        # is still a hard cut — verify this is the desired effect.
        mask = (back_img == 0)
        mask = cv2.blur(mask.astype(np.float32), (5,5))
        mask = (mask > 0)
        # Keep original pixels wherever the warped face has no coverage.
        img = img * mask + (1 - mask) * back_img
    return img.astype(np.uint8)
def save_imgs(img_list, save_dir):
    """Write each image in `img_list` to save_dir as a zero-padded-index JPEG."""
    for idx, img in enumerate(img_list):
        target = os.path.join(save_dir, '{:03d}.jpg'.format(idx))
        io.imsave(target, img.astype(np.uint8))
if __name__ == '__main__':
    opt = TestOptions().parse()
    # face_detector = dlib.get_frontal_face_detector()
    face_detector = dlib.cnn_face_detection_model_v1('./pretrain_models/mmod_human_face_detector.dat')
    lmk_predictor = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat')
    template_path = './pretrain_models/FFHQ_template.npy'
    print('======> Loading images, crop and align faces.')
    img_path = opt.test_img_path
    img = dlib.load_rgb_image(img_path)
    aligned_faces, tform_params = detect_and_align_faces(img, face_detector, lmk_predictor, template_path)
    # Keep the aligned low-quality crops for inspection.
    lq_dir = os.path.join(opt.results_dir, 'LQ_faces')
    os.makedirs(lq_dir, exist_ok=True)
    print('======> Saving aligned LQ faces to', lq_dir)
    save_imgs(aligned_faces, lq_dir)
    enhance_model = def_models(opt)
    hq_faces, lq_parse_maps = enhance_faces(aligned_faces, enhance_model)
    # Persist both the parsing maps and the enhanced crops.
    parse_dir = os.path.join(opt.results_dir, 'ParseMaps')
    hq_dir = os.path.join(opt.results_dir, 'HQ')
    os.makedirs(parse_dir, exist_ok=True)
    os.makedirs(hq_dir, exist_ok=True)
    print('======> Save parsing map and the enhanced faces.')
    save_imgs(lq_parse_maps, parse_dir)
    save_imgs(hq_faces, hq_dir)
    print('======> Paste the enhanced faces back to the original image.')
    hq_img = past_faces_back(img, hq_faces, tform_params, upscale=opt.test_upscale)
    final_save_path = os.path.join(opt.results_dir, 'hq_final.jpg')
    print('======> Save final result to', final_save_path)
    io.imsave(final_save_path, hq_img)
@@ -0,0 +1,79 @@
from utils.timer import Timer
from utils.logger import Logger
from utils import utils
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
import torch
import os
import torch.multiprocessing as mp
def train(opt):
    """Run the full training loop described by `opt`.

    Builds the dataset and model, then iterates for opt.total_epochs epochs,
    resuming from (opt.resume_epoch, opt.resume_iter) when set. Periodically
    prints iteration summaries, logs images, and checkpoints the model.
    """
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)    # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)
    model = create_model(opt)
    model.setup(opt)
    logger = Logger(opt)
    timer = Timer()
    single_epoch_iters = (dataset_size // opt.batch_size)
    total_iters = opt.total_epochs * single_epoch_iters
    # Global iteration counter, offset by any resumed progress.
    cur_iters = opt.resume_iter + opt.resume_epoch * single_epoch_iters
    start_iter = opt.resume_iter
    print('Start training from epoch: {:05d}; iter: {:07d}'.format(opt.resume_epoch, opt.resume_iter))
    for epoch in range(opt.resume_epoch, opt.total_epochs + 1):
        for i, data in enumerate(dataset, start=start_iter):
            cur_iters += 1
            logger.set_current_iter(cur_iters)
            # =================== load data ===============
            model.set_input(data, cur_iters)
            timer.update_time('DataTime')
            # =================== model train ===============
            model.forward(), timer.update_time('Forward')
            model.optimize_parameters()
            loss = model.get_current_losses()
            loss.update(model.get_lr())
            logger.record_losses(loss)
            timer.update_time('Backward')
            # =================== save model and visualize ===============
            if cur_iters % opt.print_freq == 0:
                print('Model log directory: {}'.format(opt.expr_dir))
                epoch_progress = '{:03d}|{:05d}/{:05d}'.format(epoch, i, single_epoch_iters)
                logger.printIterSummary(epoch_progress, cur_iters, total_iters, timer)
            if cur_iters % opt.visual_freq == 0:
                visual_imgs = model.get_current_visuals()
                logger.record_images(visual_imgs)
            if cur_iters % opt.save_iter_freq == 0:
                print('saving current model (epoch %d, iters %d)' % (epoch, cur_iters))
                save_suffix = 'iter_%d' % cur_iters
                # Persist resume bookmarks alongside the weights.
                info = {'resume_epoch': epoch, 'resume_iter': i+1}
                model.save_networks(save_suffix, info)
            if cur_iters % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, iters %d)' % (epoch, cur_iters))
                info = {'resume_epoch': epoch, 'resume_iter': i+1}
                model.save_networks('latest', info)
            # End the epoch after a full pass worth of iterations; reset
            # start_iter so subsequent epochs enumerate from 0 again.
            if i >= single_epoch_iters - 1:
                start_iter = 0
                break
        # model.update_learning_rate()
        if opt.debug: break
        if opt.debug and epoch >= 0: break
    logger.close()
if __name__ == '__main__':
    # Parse command-line options and launch the training loop.
    train_opts = TrainOptions().parse()
    train(train_opts)
@@ -0,0 +1,91 @@
import os
from collections import OrderedDict
import numpy as np
from .utils import mkdirs
from tensorboardX import SummaryWriter
from datetime import datetime
import socket
import shutil
class Logger():
    """Training logger: mirrors losses/images to TensorBoard and text files.

    A log directory named `<opts.name>_<timestamp>` is created per run; any
    previous directory whose name contains the experiment name is moved into
    `opts.log_archive` first.
    """
    def __init__(self, opts):
        time_stamp = '_{}'.format(datetime.now().strftime('%Y-%m-%d_%H:%M'))
        self.opts = opts
        self.log_dir = os.path.join(opts.log_dir, opts.name+time_stamp)
        self.phase_keys = ['train', 'val', 'test']
        self.iter_log = []
        self.epoch_log = OrderedDict()
        self.set_mode(opts.phase)
        # check if exist previous log belong to the same experiment name
        # NOTE(review): substring match; if several old dirs match, only the
        # last one found is archived — confirm that is intended.
        exist_log = None
        for log_name in os.listdir(opts.log_dir):
            if opts.name in log_name:
                exist_log = log_name
        if exist_log is not None:
            old_dir = os.path.join(opts.log_dir, exist_log)
            archive_dir = os.path.join(opts.log_archive, exist_log)
            shutil.move(old_dir, archive_dir)
        self.mk_log_file()
        self.writer = SummaryWriter(self.log_dir)
    def mk_log_file(self):
        # Create the run directory and one text-log path per phase.
        mkdirs(self.log_dir)
        self.txt_files = OrderedDict()
        for i in self.phase_keys:
            self.txt_files[i] = os.path.join(self.log_dir, 'log_{}'.format(i))
    def set_mode(self, mode):
        # Switch the active phase ('train'/'val'/'test') and open its epoch log.
        self.mode = mode
        self.epoch_log[mode] = []
    def set_current_iter(self, cur_iter):
        # Global iteration used as the x-axis for all subsequent records.
        self.cur_iter = cur_iter
    def record_losses(self, items):
        """
        iteration log: [iter][{key: value}]
        Only keys containing 'loss' (case-insensitive) are mirrored to
        TensorBoard scalars; everything is kept in iter_log.
        """
        self.iter_log.append(items)
        for k, v in items.items():
            if 'loss' in k.lower():
                self.writer.add_scalar('loss/{}'.format(k), v, self.cur_iter)
    def record_scalar(self, items):
        """
        Add scalar records. item, {key: value}
        """
        for i in items.keys():
            self.writer.add_scalar('{}'.format(i), items[i], self.cur_iter)
    def record_images(self, visuals, nrow=6, tag='ckpt_image'):
        # Tile up to `nrow` samples: each row stacks one sample across all
        # visuals horizontally; rows are stacked vertically.
        imgs = []
        max_len = visuals[0].shape[0]
        for i in range(nrow):
            if i >= max_len: continue
            tmp_imgs = [x[i] for x in visuals]
            imgs.append(np.hstack(tmp_imgs))
        imgs = np.vstack(imgs).astype(np.uint8)
        self.writer.add_image(tag, imgs, self.cur_iter, dataformats='HWC')
    def record_text(self, tag, text):
        # Free-form text entry in TensorBoard.
        self.writer.add_text(tag, text)
    def printIterSummary(self, epoch, cur_iters, total_it, timer):
        # Compose a human-readable line from the latest recorded losses,
        # print it, and append it to the current phase's text log.
        msg = '{}\nIter: [{}]{:03d}/{:03d}\t\t'.format(
            timer.to_string(total_it - cur_iters), epoch, cur_iters, total_it)
        for k, v in self.iter_log[-1].items():
            msg += '{}: {:.6f}\t'.format(k, v)
        print(msg + '\n')
        with open(self.txt_files[self.mode], 'a+') as f:
            f.write(msg + '\n')
    def close(self):
        # Flush all scalars to JSON and release the TensorBoard writer.
        self.writer.export_scalars_to_json(os.path.join(self.log_dir, 'all_scalars.json'))
        self.writer.close()
@@ -0,0 +1,34 @@
import time
import datetime
from collections import OrderedDict
class Timer():
    """Wall-clock stopwatch tracking named per-stage durations.

    `update_time(key)` records the time since the previous mark under `key`;
    `to_string(iters_left)` renders elapsed time, an ETA extrapolated from
    the recorded stages, and the per-stage breakdown.
    """
    def __init__(self):
        self.reset_timer()
        self.start = time.time()  # overall start; drives Elapse/TimeLeft
    def reset_timer(self):
        # Drop all recorded stages and restart the lap clock.
        self.before = time.time()
        self.timer = OrderedDict()
    def restart(self):
        # Restart only the lap clock; keep recorded stages.
        self.before = time.time()
    def update_time(self, key):
        # Duration since the last mark, stored (or overwritten) under `key`.
        self.timer[key] = time.time() - self.before
        self.before = time.time()
    def to_string(self, iters_left, short=False):
        iter_total = sum(self.timer.values())
        msg = "{:%Y-%m-%d %H:%M:%S}\tElapse: {}\tTimeLeft: {}\t".format(
            datetime.datetime.now(),
            datetime.timedelta(seconds=round(time.time() - self.start)),
            datetime.timedelta(seconds=round(iter_total*iters_left))
            )
        stage_names = '|'.join(self.timer.keys())
        if short:
            msg += '{}: {:.2f}s'.format(stage_names, iter_total)
        else:
            stage_values = ' '.join('{:.2f}s'.format(x) for x in self.timer.values())
            msg += '\tIterTotal: {:.2f}s\t{}: {} '.format(iter_total, stage_names, stage_values)
        return msg
@@ -0,0 +1,169 @@
import torch
import numpy as np
import cv2 as cv
from skimage import io
from PIL import Image
import os
import subprocess
# Full 19-entry color palette (one RGB triple per parsing label), kept for reference:
# MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
# Binary variant actually used: white for the kept labels, black for the rest.
MASK_COLORMAP = [[0, 0, 0], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [0,0, 0], [0, 0, 0], [255, 255, 255], [255, 255, 255], [255, 255, 255], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]
# Parsing label names. NOTE(review): 18 names vs 19 colormap entries —
# presumably index 0 is an implicit background label; confirm the offset.
label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
def array_to_heatmap(x):
    """Render an array as a rainbow heatmap (BGR): min-max scale to [0, 255],
    quantize to uint8, then apply OpenCV's rainbow colormap."""
    scaled = (x - x.min()) / (x.max() - x.min()) * 255
    scaled = scaled.astype(np.uint8)
    return cv.applyColorMap(scaled, cv.COLORMAP_RAINBOW)
def img_to_tensor(img_path, device, size=None, mode='rgb'):
    """
    Read image from img_path, and convert to (C, H, W) tensor in range [-1, 1]
    """
    pil_img = Image.open(img_path).convert('RGB')
    arr = np.array(pil_img)
    if mode == 'bgr':
        # Reverse the channel order (RGB -> BGR).
        arr = arr[..., ::-1]
    if size:
        arr = cv.resize(arr, size)
    # [0, 255] -> [-1, 1]
    arr = arr / 255 * 2 - 1
    tensor = torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)
    return tensor.float()
def tensor_to_img(tensor, save_path=None, size=None, mode='RGB', normal=[-1, 1]):
    """
    mode: RGB or L (gray image)
    Input: tensor with shape (C, H, W); `normal` gives the tensor's value
    range, which is mapped onto [0, 255]. Optionally resizes and/or saves
    the result to disk. Returns the uint8 numpy array.
    """
    if isinstance(size, int):
        size = (size, size)
    arr = tensor.squeeze().data.cpu().numpy()
    if mode == 'RGB':
        # channel-first -> channel-last for image libraries
        arr = arr.transpose(1, 2, 0)
    if size is not None:
        arr = cv.resize(arr, size, interpolation=cv.INTER_LINEAR)
    if len(normal):
        lo, hi = normal[0], normal[1]
        arr = (arr - lo) / (hi - lo) * 255
    arr = arr.clip(0, 255).astype(np.uint8)
    if save_path:
        Image.fromarray(arr, mode).save(save_path)
    return arr
def tensor_to_numpy(tensor):
    """Move a tensor to host memory and return it as a numpy array."""
    host_tensor = tensor.data.cpu()
    return host_tensor.numpy()
def batch_numpy_to_image(array, size=None):
    """
    Input: numpy array (B, C, H, W) in [-1, 1]
    Output: uint8 array of HWC images in [0, 255], optionally resized.
    """
    if isinstance(size, int):
        size = (size, size)
    # [-1, 1] -> [0, 255], then channel-first -> channel-last.
    scaled = np.clip((array + 1)/2 * 255, 0, 255)
    scaled = np.transpose(scaled, (0, 2, 3, 1))
    images = []
    for frame in scaled:
        if size is not None:
            frame = cv.resize(frame, size)
        images.append(frame)
    return np.array(images).astype(np.uint8)
def batch_tensor_to_img(tensor, size=None):
    """
    Input: (B, C, H, W)
    Return: RGB image, [0, 255]
    """
    return batch_numpy_to_image(tensor_to_numpy(tensor), size)
def color_parse_map(tensor, size=None):
    """
    input: tensor or batch tensor of parsing scores or label indices
    return: list of colorized parsing maps (HxWx3 uint8 arrays)
    """
    if len(tensor.shape) < 4:
        tensor = tensor.unsqueeze(0)
    if tensor.shape[1] > 1:
        # multi-channel scores -> per-pixel label index
        tensor = tensor.argmax(dim=1)
    label_maps = tensor.squeeze(1).data.cpu().numpy()
    color_maps = []
    for labels in label_maps:
        canvas = np.zeros(label_maps.shape[1:] + (3,))
        # Paint every pixel with the color assigned to its label.
        for idx, color in enumerate(MASK_COLORMAP):
            canvas[labels == idx] = color
        if size is not None:
            canvas = cv.resize(canvas, (size, size))
        color_maps.append(canvas.astype(np.uint8))
    return color_maps
def onehot_parse_map(img):
    """
    input: RGB color parse map
    output: one hot encoding (n_label, H, W) of the parse map
    """
    n_label = len(MASK_COLORMAP)
    arr = np.array(img, dtype=np.uint8)
    h, w = arr.shape[:2]
    onehot_label = np.zeros((n_label, h, w))
    # Broadcast each palette color over the full image for comparison.
    palette = np.array(MASK_COLORMAP).reshape(n_label, 1, 1, 3)
    palette = np.tile(palette, (1, h, w, 1))
    for idx in range(n_label):
        # A pixel matches a label only when all three channels agree.
        match = palette[idx] == arr
        onehot_label[idx] = match[..., 0] * match[..., 1] * match[..., 2]
    return onehot_label
def mkdirs(paths):
    """Create the given directory path(s), including missing parents.

    Accepts a single path string or a list/tuple of paths. Existing
    directories are left untouched.

    BUGFIX: the original exists-then-create pattern raced with concurrent
    creators (TOCTOU), and its `not isinstance(paths, str)` guard was dead
    code (a list is never a str). `exist_ok=True` is idempotent and
    race-free; tuple inputs are now accepted as well.
    """
    if isinstance(paths, (list, tuple)):
        for path in paths:
            os.makedirs(path, exist_ok=True)
    else:
        os.makedirs(paths, exist_ok=True)
def get_gpu_memory_map():
    """Get the current gpu usage within visible cuda devices.

    Queries `nvidia-smi` once and parses its per-GPU used-memory column.
    Raises FileNotFoundError/CalledProcessError when nvidia-smi is absent.

    Returns
    -------
    Memory Map: dict
        Keys are device ids as integers.
        Values are memory usage as integers in MB.
    Device Ids: gpu ids sorted in descending order according to the available memory.
    """
    result = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.used',
            '--format=csv,nounits,noheader'
        ]).decode('utf-8')
    # Convert lines into a dictionary
    gpu_memory = np.array([int(x) for x in result.strip().split('\n')])
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        # Restrict to (and reindex by) the devices this process may use.
        visible_devices = sorted([int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',')])
    else:
        visible_devices = range(len(gpu_memory))
    # Map logical ids 0..n-1 onto the used-memory readings of the visible GPUs.
    gpu_memory_map = dict(zip(range(len(visible_devices)), gpu_memory[visible_devices]))
    return gpu_memory_map, sorted(gpu_memory_map, key=gpu_memory_map.get)
if __name__ == '__main__':
    # BUGFIX: this self-test called an undefined `flip(hm, 2)` (NameError at
    # runtime) and allocated CUDA tensors that were never used; keep only the
    # working GPU-memory probe.
    print(get_gpu_memory_map())
+52
View File
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: filter.py
# Created Date: Wednesday April 13th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 13th April 2022 3:49:23 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import cv2
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
class HighPass(nn.Module):
    """Depthwise 3x3 Laplacian high-pass filter, scaled by 1/w_hpf.

    The same kernel is applied independently to every input channel
    (grouped convolution), with padding=1 so spatial size is preserved.
    """
    def __init__(self, w_hpf, device):
        super(HighPass, self).__init__()
        base_kernel = torch.tensor([[-1.0, -1.0, -1.0],
                                    [-1.0,  8.0, -1.0],
                                    [-1.0, -1.0, -1.0]])
        self.filter = base_kernel.to(device) / w_hpf

    def forward(self, x):
        channels = x.size(1)
        # One (1, 1, 3, 3) kernel replicated per channel for depthwise conv.
        kernel = self.filter.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        return F.conv2d(x, kernel, padding=1, groups=channels)
if __name__ == "__main__":
transformer_Arcface = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
img = "G:/swap_data/ID/2.jpg"
attr = cv2.imread(img)
attr = Image.fromarray(cv2.cvtColor(attr,cv2.COLOR_BGR2RGB))
attr = transformer_Arcface(attr).unsqueeze(0)
results = HighPass(0.5,torch.device("cpu"))(attr)
results = results * imagenet_std + imagenet_mean
results = results.cpu().permute(0,2,3,1)[0,...]
results = results.numpy()
results = np.clip(results,0.0,1.0) * 255
results = cv2.cvtColor(results,cv2.COLOR_RGB2BGR)
cv2.imwrite("filter_results2.png",results)
+6 -5
View File
@@ -5,7 +5,7 @@
# Created Date: Sunday February 13th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 4th March 2022 1:53:53 am
# Last Modified: Monday, 18th April 2022 10:52:57 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -22,7 +22,7 @@ from thop import clever_format
if __name__ == '__main__':
#
# script = "Generator_modulation_up"
script = "Generator_Invobn_config3"
script = "Generator_2mask"
# script = "Generator_ori_modulation_config"
# script = "Generator_ori_config"
class_name = "Generator"
@@ -30,12 +30,13 @@ if __name__ == '__main__':
model_config={
"id_dim": 512,
"g_kernel_size": 3,
"in_channel":16,
"res_num": 9,
"in_channel":64,
"res_num": 3,
# "up_mode": "nearest",
"up_mode": "bilinear",
"aggregator": "eca_invo",
"res_mode": "conv"
"res_mode": "conv",
"norm": "bn"
}
+76
View File
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: id_cos.py
# Created Date: Friday March 25th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 29th March 2022 11:58:30 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from insightface_func.face_detect_crop_single import Face_detect_crop
from arcface_torch.backbones.iresnet import iresnet100
if __name__ == "__main__":
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)
transformer_Arcface = transforms.Compose([
transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
arcface_ckpt = "./arcface_ckpt/arcface_checkpoint.tar"
arcface1 = torch.load(arcface_ckpt, map_location=torch.device("cpu"))
arcface = arcface1['model'].module
arcface.eval()
root1 = "G:/VGGFace2-HQ/VGGface2_ffhq_align_256_9_28_512_bygfpgan/n000002/"
root2 = "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan/n000002/"
# arcface_ckpt = "./arcface_torch/checkpoints/backbone.pth" # backbone.pth glint360k_cosface_r100_fp16_backbone.pth
# arcface = iresnet100(pretrained=False, fp16=False)
# arcface.load_state_dict(torch.load(arcface_ckpt, map_location='cpu'))
# arcface.eval()
# id1 = "G:/swap_data/ID/hinton.jpg"
# id2 = "G:/hififace-master/hififace-master/assets/inference_samples/hififace/img-172.jpg"
id1 = root2 + "0003_01.jpg"
id2 = root2 + "0036_01.jpg"
mode = "none"
cos_loss = torch.nn.CosineSimilarity()
# detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
# detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
id_img = cv2.imread(id1)
# id_img_align_crop, _ = detect.get(id_img,256)
# cv2.imwrite("id1_crop.png",id_img_align_crop[0])
# id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img,cv2.COLOR_BGR2RGB))
id_img = transformer_Arcface(id_img_align_crop_pil)
id_img = id_img.unsqueeze(0)
id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
# id_img = (id_img-0.5)*2.0
latend_id = arcface(id_img)
latend_id = F.normalize(latend_id, p=2, dim=1)
id_img2 = cv2.imread(id2)
# id_img_align_crop2, _ = detect.get(id_img2,256)
# cv2.imwrite("id2_crop.png",id_img_align_crop2[0])
# id_img_align_crop_pil2 = Image.fromarray(cv2.cvtColor(id_img_align_crop2[0],cv2.COLOR_BGR2RGB))
id_img_align_crop_pil2 = Image.fromarray(cv2.cvtColor(id_img2,cv2.COLOR_BGR2RGB))
id_img2 = transformer_Arcface(id_img_align_crop_pil2)
id_img2 = id_img2.unsqueeze(0)
id_img2 = F.interpolate(id_img2,size=(112,112), mode='bicubic')
# id_img2 = (id_img2-0.5)*2.0
latend_id2 = arcface(id_img2)
latend_id2 = F.normalize(latend_id2, p=2, dim=1)
cos_dis = 1 - cos_loss(latend_id, latend_id2)
print("cosine similarity:", cos_dis.item())
+267
View File
@@ -0,0 +1,267 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
"Alias-Free Generative Adversarial Networks"."""
import copy
import numpy as np
import torch
import torch.fft
from torch_utils.ops import upfirdn2d
from . import metric_utils
#----------------------------------------------------------------------------
# Utilities.
def sinc(x):
    """Normalized sinc: sin(pi*x)/(pi*x), with sinc(0) = 1.

    The clamp keeps the division finite; torch.where then substitutes the
    exact limit value 1 wherever |pi*x| underflows to ~0.
    """
    theta = (x * np.pi).abs()
    ratio = torch.sin(theta) / theta.clamp(1e-30, float('inf'))
    return torch.where(theta < 1e-30, torch.ones_like(x), ratio)
def lanczos_window(x, a):
    """Lanczos window of size `a`: sinc(x/a) for |x| < a, zero outside."""
    u = x.abs() / a
    return torch.where(u < 1, sinc(u), torch.zeros_like(u))
def rotation_matrix(angle):
    """Return the 3x3 homogeneous 2D rotation matrix for `angle` radians."""
    theta = torch.as_tensor(angle).to(torch.float32)
    c, s = theta.cos(), theta.sin()
    mat = torch.eye(3, device=theta.device)
    mat[0, 0] = c
    mat[0, 1] = s
    mat[1, 0] = -s
    mat[1, 1] = c
    return mat
#----------------------------------------------------------------------------
# Apply integer translation to a batch of 2D images. Corresponds to the
# operator T_x in Appendix E.1.
def apply_integer_translation(x, tx, ty):
    """Translate a batch of images by whole pixels (operator T_x, App. E.1).

    tx/ty are fractions of width/height and are rounded to integer pixel
    offsets. Returns the shifted images plus a mask marking which output
    pixels originate from inside the source image.
    """
    _N, _C, H, W = x.shape
    tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    ix = tx.round().to(torch.int64)
    iy = ty.round().to(torch.int64)

    z = torch.zeros_like(x)
    m = torch.zeros_like(x)
    # Only copy when some overlap remains after the shift.
    if abs(ix) < W and abs(iy) < H:
        src = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
        z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = src
        m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
    return z, m
#----------------------------------------------------------------------------
# Apply integer translation to a batch of 2D images. Corresponds to the
# operator T_x in Appendix E.2.
def apply_fractional_translation(x, tx, ty, a=3):
    """Translate a batch of images by a sub-pixel offset (operator T_x, App. E.2).

    tx/ty are fractions of width/height. The integer part is applied by
    shifting; the fractional part by separable Lanczos-windowed sinc
    interpolation with support `a`. Returns the shifted images plus a mask
    marking output pixels whose full interpolation support lay inside the
    source image.
    """
    _N, _C, H, W = x.shape
    tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    # Split translation into integer and fractional parts.
    ix = tx.floor().to(torch.int64)
    iy = ty.floor().to(torch.int64)
    fx = tx - ix
    fy = ty - iy
    b = a - 1
    z = torch.zeros_like(x)
    # Output region that receives valid (even partially supported) samples.
    zx0 = max(ix - b, 0)
    zy0 = max(iy - b, 0)
    zx1 = min(ix + a, 0) + W
    zy1 = min(iy + a, 0) + H
    if zx0 < zx1 and zy0 < zy1:
        # Separable fractional-shift filters: sinc interpolation windowed by
        # a second sinc of support `a` (Lanczos).
        taps = torch.arange(a * 2, device=x.device) - b
        filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
        filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
        y = x
        y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
        y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
        # Crop the padded result to the valid output window, then place it.
        y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
        z[:, :, zy0:zy1, zx0:zx1] = y
    m = torch.zeros_like(x)
    # Stricter region: pixels whose entire filter support was in-bounds.
    mx0 = max(ix + a, 0)
    my0 = max(iy + a, 0)
    mx1 = min(ix - b, 0) + W
    my1 = min(iy - b, 0) + H
    if mx0 < mx1 and my0 < my1:
        m[:, :, my0:my1, mx0:mx1] = 1
    return z, m
#----------------------------------------------------------------------------
# Construct an oriented low-pass filter that applies the appropriate
# bandlimit with respect to the input and output of the given affine 2D
# image transformation.
def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    """Build an oriented low-pass FIR filter that bandlimits with respect to
    both the input and output of the affine transform `mat`.

    The filter is the product of two Lanczos-windowed 2D sincs — one in input
    coordinates, one in transform-mapped output coordinates — convolved via
    FFT, then cropped to support `amax` and laid out for `up`x upsampling.
    Returns a 2D tensor of shape (amax*2*up - 1, amax*2*up - 1).
    """
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)
    # Construct 2D filter taps in input & output coordinate spaces.
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    yi, xi = torch.meshgrid(taps, taps)
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
    # Convolution of two oriented 2D sinc filters.
    fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
    fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
    # Convolution of two oriented 2D Lanczos windows.
    wi = lanczos_window(xi, a) * lanczos_window(yi, a)
    wo = lanczos_window(xo, a) * lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
    # Construct windowed FIR filter.
    f = f * w
    # Finalize.
    # Crop to the target support, normalize each up-sampling phase to unit
    # gain, and flatten the phase layout back into a single 2D kernel.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f
#----------------------------------------------------------------------------
# Apply the given affine transformation to a batch of 2D images.
def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Apply the affine transform `mat` to a batch of 2D images with proper
    bandlimiting: upsample with an oriented low-pass filter, resample on the
    transformed grid, and return (transformed images, validity mask).
    """
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
    # Construct filter.
    f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2
    # Construct sampling grid.
    # The inverse transform is adjusted for the up-factor and for the padding
    # added by the filtering step so grid_sample's [-1, 1] coords line up.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
    # Form mask.
    # Valid region excludes a border of filter-support width, resampled with
    # nearest interpolation so the mask stays binary.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m
#----------------------------------------------------------------------------
# Apply fractional rotation to a batch of 2D images. Corresponds to the
# operator R_\alpha in Appendix E.3.
def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
    """Rotate a batch of 2D images by `angle` radians with bandlimiting
    (operator R_alpha, Appendix E.3)."""
    theta = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    return apply_affine_transformation(x, rotation_matrix(theta), a=a, amax=a*2, **filter_kwargs)
#----------------------------------------------------------------------------
# Modify the frequency content of a batch of 2D images as if they had undergo
# fractional rotation -- but without actually rotating them. Corresponds to
# the operator R^*_\alpha in Appendix E.3.
def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
    """Bandlimit `x` as if it had been rotated by `angle` — without actually
    rotating it (operator R*_alpha, Appendix E.3)."""
    theta = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    # Filter oriented for the inverse rotation, applied in place (up=1).
    bandlimit = construct_affine_bandlimit_filter(rotation_matrix(-theta), a=a, amax=a*2, up=1, **filter_kwargs)
    y = upfirdn2d.filter2d(x=x, f=bandlimit)
    # Mark the interior region unaffected by filter-support border effects.
    m = torch.zeros_like(y)
    border = bandlimit.shape[0] // 2
    m[:, :, border:-border, border:-border] = 1
    return y, m
#----------------------------------------------------------------------------
# Compute the selected equivariance metrics for the given generator.
def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
    """Compute the selected equivariance PSNRs (EQ-T, EQ-T_frac, EQ-R) for
    the generator in `opts`.

    For each batch, a reference image is generated with the identity input
    transform; each enabled metric then generates with a shifted/rotated
    input transform and compares against the correspondingly shifted/rotated
    reference, accumulating masked squared error. Returns one PSNR per
    enabled metric (a single float if only one is enabled).
    """
    assert compute_eqt_int or compute_eqt_frac or compute_eqr

    # Setup generator and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    I = torch.eye(3, device=opts.device)
    # M aliases the generator's user-specified input transform buffer.
    M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
    if M is None:
        raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop.
    sums = None
    progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        s = []

        # Randomize noise buffers, if any.
        for name, buf in G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Run mapping network.
        z = torch.randn([batch_size, G.z_dim], device=opts.device)
        c = next(c_iter)
        ws = G.mapping(z=z, c=c)

        # Generate reference image.
        M[:] = I
        orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)

        # Integer translation (EQ-T).
        if compute_eqt_int:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            # Snap the offset to whole output pixels.
            t = (t * G.img_resolution).round() / G.img_resolution
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_integer_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Fractional translation (EQ-T_frac).
        if compute_eqt_frac:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_fractional_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Rotation (EQ-R).
        if compute_eqr:
            angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
            M[:] = rotation_matrix(-angle)
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, ref_mask = apply_fractional_rotation(orig, angle)
            pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
            mask = ref_mask * pseudo_mask
            s += [(ref - pseudo).square() * mask, mask]

        # Accumulate results.
        # Interleaved [error, mask, error, mask, ...] scalar sums per metric.
        s = torch.stack([x.to(torch.float64).sum() for x in s])
        sums = sums + s if sums is not None else s
    progress.update(num_samples)

    # Compute PSNRs.
    if opts.num_gpus > 1:
        torch.distributed.all_reduce(sums)
    sums = sums.cpu()
    # Mean squared error per metric = error sum / mask sum.
    mses = sums[0::2] / sums[1::2]
    # PSNR with peak signal 2 (images span [-1, 1]).
    psnrs = np.log10(2) * 20 - mses.log10() * 10
    psnrs = tuple(psnrs.numpy())
    return psnrs[0] if len(psnrs) == 1 else psnrs
#----------------------------------------------------------------------------
+41
View File
@@ -0,0 +1,41 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Frechet Inception Distance (FID) from the paper
"GANs trained by a two time-scale update rule converge to a local Nash
equilibrium". Matches the original implementation by Heusel et al. at
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
import numpy as np
import scipy.linalg
from . import metric_utils
#----------------------------------------------------------------------------
def compute_fid(opts, max_real, num_gen, swav=False, sfid=False):
    """Frechet Inception Distance between real-dataset and generator features.

    Only rank 0 returns the actual FID; every other rank returns NaN.
    `swav`/`sfid` are forwarded to the feature-stat helpers unchanged.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Raw features before the softmax layer.
    stats_real = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, swav=swav, sfid=sfid)
    stats_gen = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, swav=swav, sfid=sfid)
    mu_real, sigma_real = stats_real.get_mean_cov()
    mu_gen, sigma_gen = stats_gen.get_mean_cov()
    if opts.rank != 0:
        return float('nan')
    # FID = ||mu_g - mu_r||^2 + Tr(S_g + S_r - 2 * sqrt(S_g S_r)).
    mean_term = np.square(mu_gen - mu_real).sum()
    sqrt_cov, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
    return float(np.real(mean_term + np.trace(sigma_gen + sigma_real - 2 * sqrt_cov)))
#----------------------------------------------------------------------------
+38
View File
@@ -0,0 +1,38 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Inception Score (IS) from the paper "Improved techniques for training
GANs". Matches the original implementation by Salimans et al. at
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
import numpy as np
from . import metric_utils
#----------------------------------------------------------------------------
def compute_is(opts, num_gen, num_splits):
    """Inception Score of the generator, averaged over `num_splits` splits.

    Returns (mean, std); non-zero ranks return (NaN, NaN).
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
    gen_probs = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        capture_all=True, max_items=num_gen).get_all()
    if opts.rank != 0:
        return float('nan'), float('nan')
    split_scores = []
    for split_idx in range(num_splits):
        lo = split_idx * num_gen // num_splits
        hi = (split_idx + 1) * num_gen // num_splits
        part = gen_probs[lo:hi]
        # KL divergence of each sample's class distribution from the marginal.
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        split_scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return float(np.mean(split_scores)), float(np.std(split_scores))
#----------------------------------------------------------------------------
+46
View File
@@ -0,0 +1,46 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
GANs". Matches the original implementation by Binkowski et al. at
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
import numpy as np
from . import metric_utils
#----------------------------------------------------------------------------
def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
    """Kernel Inception Distance via an unbiased MMD estimate with a cubic
    polynomial kernel, averaged over `num_subsets` random subsets.

    Only rank 0 returns the actual KID; every other rank returns NaN.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Raw features before the softmax layer.
    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
    if opts.rank != 0:
        return float('nan')
    n = real_features.shape[1]          # feature dimensionality
    m = min(real_features.shape[0], gen_features.shape[0], max_subset_size)
    total = 0
    for _subset_idx in range(num_subsets):
        x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
        y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
        # Cubic polynomial kernel k(a, b) = (a.b / n + 1)^3.
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        total += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    return float(total / num_subsets / m)
#----------------------------------------------------------------------------
+151
View File
@@ -0,0 +1,151 @@
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Main API for computing and reporting quality metrics."""
import os
import time
import json
import torch
import dnnlib
from . import metric_utils
from . import frechet_inception_distance
from . import kernel_inception_distance
from . import precision_recall
from . import perceptual_path_length
from . import inception_score
from . import equivariance
#----------------------------------------------------------------------------
# Registry of metric functions, keyed by function name.
_metric_dict = dict() # name => fn

def register_metric(fn):
    """Decorator: register *fn* as a metric under its own __name__."""
    assert callable(fn)
    _metric_dict[fn.__name__] = fn
    return fn

def is_valid_metric(metric):
    """Return True if *metric* names a registered metric function."""
    return metric in _metric_dict

def list_valid_metrics():
    """Return all registered metric names, in registration order."""
    return list(_metric_dict)
#----------------------------------------------------------------------------
def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
    """Compute one registered metric and wrap the result with timing metadata.

    All ranks must call this; values computed on rank 0 are broadcast so every
    process returns identical numbers.
    """
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate.
    t0 = time.time()
    results = _metric_dict[metric](opts)
    elapsed = time.time() - t0

    # Broadcast results from rank 0 so all processes agree on the values.
    for key in list(results.keys()):
        value = results[key]
        if opts.num_gpus > 1:
            tensor = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=tensor, src=0)
            value = float(tensor.cpu())
        results[key] = value

    # Decorate with metadata.
    return dnnlib.EasyDict(
        results         = dnnlib.EasyDict(results),
        metric          = metric,
        total_time      = elapsed,
        total_time_str  = dnnlib.util.format_time(elapsed),
        num_gpus        = opts.num_gpus,
    )
#----------------------------------------------------------------------------
def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
    """Print a metric result as one JSONL line and append it to the run log."""
    metric = result_dict['metric']
    assert is_valid_metric(metric)
    if run_dir is not None and snapshot_pkl is not None:
        # Store the snapshot path relative to the run directory.
        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
    record = dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())
    jsonl_line = json.dumps(record)
    print(jsonl_line)
    if run_dir is not None and os.path.isdir(run_dir):
        log_path = os.path.join(run_dir, f'metric-{metric}.jsonl')
        with open(log_path, 'at') as f:
            f.write(jsonl_line + '\n')
#----------------------------------------------------------------------------
# Recommended metrics.
@register_metric
def fid50k_full(opts):
    """FID, 50k generated images vs. the full (un-flipped) dataset."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    value = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
    return {'fid50k_full': value}

@register_metric
def fid10k_full(opts):
    """FID, 10k generated images vs. the full (un-flipped) dataset."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    value = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000)
    return {'fid10k_full': value}

@register_metric
def kid50k_full(opts):
    """KID, 50k generated images vs. up to 1M real images."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    value = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return {'kid50k_full': value}

@register_metric
def pr50k3_full(opts):
    """Precision/recall with 3-NN manifolds over the full dataset."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return {'pr50k3_full_precision': precision, 'pr50k3_full_recall': recall}

@register_metric
def ppl2_wend(opts):
    """Perceptual Path Length in W space, endpoint sampling, no crop."""
    value = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
    return {'ppl2_wend': value}

@register_metric
def eqt50k_int(opts):
    """Integer-translation equivariance PSNR (EQ-T), 50k samples."""
    opts.G_kwargs.update(force_fp32=True)
    value = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
    return {'eqt50k_int': value}

@register_metric
def eqt50k_frac(opts):
    """Fractional-translation equivariance PSNR (EQ-T_frac), 50k samples."""
    opts.G_kwargs.update(force_fp32=True)
    value = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
    return {'eqt50k_frac': value}

@register_metric
def eqr50k(opts):
    """Rotation equivariance PSNR (EQ-R), 50k samples."""
    opts.G_kwargs.update(force_fp32=True)
    value = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
    return {'eqr50k': value}
# Legacy metrics.
@register_metric
def fid50k(opts):
    """Legacy FID: 50k generated images vs. 50k real images."""
    opts.dataset_kwargs.update(max_size=None)
    value = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
    return {'fid50k': value}

@register_metric
def kid50k(opts):
    """Legacy KID: 50k generated images vs. 50k real images."""
    opts.dataset_kwargs.update(max_size=None)
    value = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return {'kid50k': value}

@register_metric
def pr50k3(opts):
    """Legacy precision/recall with 3-NN manifolds, 50k vs. 50k."""
    opts.dataset_kwargs.update(max_size=None)
    precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return {'pr50k3_precision': precision, 'pr50k3_recall': recall}

@register_metric
def is50k(opts):
    """Legacy Inception Score over 50k generated images, 10 splits."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
    return {'is50k_mean': mean, 'is50k_std': std}
+298
View File
@@ -0,0 +1,298 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Miscellaneous utilities used internally by the quality metrics."""
import os
import time
import hashlib
import pickle
import copy
import uuid
import numpy as np
import torch
import dnnlib
from tqdm import tqdm
#----------------------------------------------------------------------------
class MetricOptions:
    """Shared option bundle consumed by all quality metrics.

    Args:
        G:              Generator network (may be None when only dataset stats are needed).
        G_kwargs:       Extra keyword arguments for running G; copied into an EasyDict.
        dataset_kwargs: Keyword arguments for constructing the dataset; copied into an EasyDict.
        num_gpus:       Number of distributed processes.
        rank:           Rank of the calling process; must satisfy 0 <= rank < num_gpus.
        device:         Torch device to use; defaults to cuda:<rank>.
        progress:       Optional ProgressMonitor; only rank 0 reports progress.
        cache:          Enable on-disk caching of dataset feature statistics.
        run_dir:        Output directory of the current run, if any.
        cur_nimg:       Opaque training-progress value stored for callers; unused here.
        snapshot_pkl:   Path of the network snapshot being evaluated, if any.
    """
    def __init__(self, G=None, G_kwargs=None, dataset_kwargs=None, num_gpus=1, rank=0, device=None, progress=None, cache=True, run_dir=None, cur_nimg=None, snapshot_pkl=None):
        assert 0 <= rank < num_gpus
        self.G = G
        # None defaults instead of mutable `{}` defaults (shared-default pitfall);
        # both are copied into fresh EasyDicts so callers' dicts are never aliased.
        self.G_kwargs = dnnlib.EasyDict(G_kwargs if G_kwargs is not None else {})
        self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs if dataset_kwargs is not None else {})
        self.num_gpus = num_gpus
        self.rank = rank
        self.device = device if device is not None else torch.device('cuda', rank)
        # Only rank 0 drives the caller's progress monitor; other ranks get a silent one.
        self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
        self.cache = cache
        self.run_dir = run_dir
        self.cur_nimg = cur_nimg
        self.snapshot_pkl = snapshot_pkl
#----------------------------------------------------------------------------
# Cache of loaded feature-detector networks, keyed by (url, device).
_feature_detector_cache = dict()

def get_feature_detector_name(url):
    """Short detector name derived from its URL: last path component, minus extension."""
    basename = url.rsplit('/', 1)[-1]
    return os.path.splitext(basename)[0]
def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    """Load the pickled feature-detector network at *url*, memoizing per (url, device).

    In multi-GPU runs, rank 0 downloads/loads first while the other ranks wait
    at a barrier, then the others load (typically hitting the download cache).

    NOTE(review): the cache key is (url, device); passing the same device as a
    string vs. a torch.device would create duplicate entries — confirm callers
    are consistent.
    """
    assert 0 <= rank < num_gpus
    key = (url, device)
    if key not in _feature_detector_cache:
        is_leader = (rank == 0)
        if not is_leader and num_gpus > 1:
            torch.distributed.barrier() # leader goes first
        # Only the leader prints download progress (verbose and is_leader).
        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
            _feature_detector_cache[key] = pickle.load(f).to(device)
        if is_leader and num_gpus > 1:
            torch.distributed.barrier() # others follow
    return _feature_detector_cache[key]
#----------------------------------------------------------------------------
def iterate_random_labels(opts, batch_size):
    """Infinite generator of conditioning-label batches for opts.G.

    For an unconditional generator (c_dim == 0) a single zero-width tensor is
    yielded forever; otherwise random labels are drawn from the dataset.
    """
    if opts.G.c_dim == 0:
        labels = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
        while True:
            yield labels
    else:
        dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
        while True:
            picks = [dataset.get_label(np.random.randint(len(dataset))) for _ in range(batch_size)]
            yield torch.from_numpy(np.stack(picks)).pin_memory().to(opts.device)
#----------------------------------------------------------------------------
class FeatureStats:
    """Accumulates 2-D feature batches, optionally keeping all rows and/or
    running mean/covariance sums, capped at `max_items` rows."""

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        self.capture_all = capture_all          # keep every appended row
        self.capture_mean_cov = capture_mean_cov  # keep running sums for mean/cov
        self.max_items = max_items              # row cap; None = unlimited
        self.num_items = 0
        self.num_features = None                # set lazily on first append
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        """Lazily allocate accumulators; later calls must match the first width."""
        if self.num_features is None:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
        else:
            assert num_features == self.num_features

    def is_full(self):
        """True once `max_items` rows have been accumulated."""
        return self.max_items is not None and self.num_items >= self.max_items

    def append(self, x):
        """Append a [rows, features] array, truncating at the `max_items` cap."""
        arr = np.asarray(x, dtype=np.float32)
        assert arr.ndim == 2
        if self.max_items is not None and self.num_items + arr.shape[0] > self.max_items:
            if self.num_items >= self.max_items:
                return
            arr = arr[:self.max_items - self.num_items]
        self.set_num_features(arr.shape[1])
        self.num_items += arr.shape[0]
        if self.capture_all:
            self.all_features.append(arr)
        if self.capture_mean_cov:
            # Accumulate sum and sum of outer products in float64.
            x64 = arr.astype(np.float64)
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64

    def append_torch(self, x, num_gpus=1, rank=0):
        """Append a torch tensor; in multi-GPU runs, interleave rows from all ranks."""
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            gathered = []
            for src in range(num_gpus):
                shard = x.clone()
                torch.distributed.broadcast(shard, src=src)
                gathered.append(shard)
            x = torch.stack(gathered, dim=1).flatten(0, 1) # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        """All captured rows as one [num_items, num_features] array."""
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        """Same as get_all(), as a torch tensor."""
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        """Mean vector and covariance matrix of everything appended so far."""
        assert self.capture_mean_cov
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        """Pickle the full internal state to *pkl_file*."""
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        """Restore a FeatureStats previously written by save()."""
        with open(pkl_file, 'rb') as f:
            state = dnnlib.EasyDict(pickle.load(f))
        obj = FeatureStats(capture_all=state.capture_all, max_items=state.max_items)
        obj.__dict__.update(state)
        return obj
#----------------------------------------------------------------------------
class ProgressMonitor:
    """Throttled progress reporter with optional console output and a callback
    mapped onto a [pfn_lo, pfn_hi] sub-range of a pfn_total-unit bar."""

    def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
        self.tag = tag                      # label for console output
        self.num_items = num_items          # expected total, None = unknown
        self.verbose = verbose
        self.flush_interval = flush_interval  # min items between reports
        self.progress_fn = progress_fn      # callback(progress_units, total_units)
        self.pfn_lo = pfn_lo
        self.pfn_hi = pfn_hi
        self.pfn_total = pfn_total
        self.start_time = time.time()
        self.batch_time = self.start_time
        self.batch_items = 0
        if self.progress_fn is not None:
            self.progress_fn(self.pfn_lo, self.pfn_total)

    def update(self, cur_items):
        """Record progress; reports only on flush intervals or at completion."""
        assert (self.num_items is None) or (cur_items <= self.num_items)
        reached_flush = cur_items >= self.batch_items + self.flush_interval
        reached_end = self.num_items is not None and cur_items >= self.num_items
        if not (reached_flush or reached_end):
            return
        now = time.time()
        elapsed = now - self.start_time
        per_item = (now - self.batch_time) / max(cur_items - self.batch_items, 1)
        if self.verbose and self.tag is not None:
            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(elapsed):<12s} ms/item {per_item*1e3:.2f}')
        self.batch_time = now
        self.batch_items = cur_items
        if self.progress_fn is not None and self.num_items is not None:
            frac = cur_items / self.num_items
            self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * frac, self.pfn_total)

    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
        """Create a child monitor covering the [rel_lo, rel_hi] slice of this one's range."""
        span = self.pfn_hi - self.pfn_lo
        return ProgressMonitor(
            tag             = tag,
            num_items       = num_items,
            flush_interval  = flush_interval,
            verbose         = self.verbose,
            progress_fn     = self.progress_fn,
            pfn_lo          = self.pfn_lo + span * rel_lo,
            pfn_hi          = self.pfn_lo + span * rel_hi,
            pfn_total       = self.pfn_total,
        )
#----------------------------------------------------------------------------
def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, swav=False, sfid=False, **stats_kwargs):
    """Run the feature detector at *detector_url* over the real dataset and
    return the accumulated FeatureStats.

    Results are cached on disk under ./dnnlib/gan-metrics/, keyed by an MD5 of
    the dataset/detector/stats options, so repeated runs skip recomputation.

    NOTE(review): `swav` and `sfid` are accepted but never used in this body —
    confirm whether they were meant to affect the detector or the cache key.
    """
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    if data_loader_kwargs is None:
        # Defaults tuned for multi-worker prefetching of image batches.
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
    # Try to lookup from cache.
    cache_file = None
    if opts.cache:
        det_name = get_feature_detector_name(detector_url)
        # Choose cache file name.
        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
        cache_tag = f'{dataset.name}-{det_name}-{md5.hexdigest()}'
        cache_file = os.path.join('.', 'dnnlib', 'gan-metrics', cache_tag + '.pkl')
        # cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
        # Check if the file exists (all processes must agree).
        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
        if opts.num_gpus > 1:
            # Rank 0 decides; the decision is broadcast as a float tensor.
            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
            torch.distributed.broadcast(tensor=flag, src=0)
            flag = (float(flag.cpu()) != 0)
        # Load.
        if flag:
            return FeatureStats.load(cache_file)
        print('Calculating the stats for this dataset the first time\n')
        print(f'Saving them to {cache_file}')
    # Initialize.
    num_items = len(dataset)
    if max_items is not None:
        num_items = min(num_items, max_items)
    stats = FeatureStats(max_items=num_items, **stats_kwargs)
    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
    # get detector
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
    # Main loop.
    # Round-robin sharding: each rank processes every num_gpus-th dataset item.
    item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
    for images, _labels in tqdm(torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs)):
        if images.shape[1] == 1:
            # Replicate grayscale to 3 channels for the detector.
            images = images.repeat([1, 3, 1, 1])
        with torch.no_grad():
            features = detector(images.to(opts.device), **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    # Save to cache.
    if cache_file is not None and opts.rank == 0:
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        # Write to a unique temp file first so the final rename is atomic.
        temp_file = cache_file + '.' + uuid.uuid4().hex
        stats.save(temp_file)
        os.replace(temp_file, cache_file) # atomic
    return stats
#----------------------------------------------------------------------------
def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, swav=False, sfid=False, **stats_kwargs):
    """Sample images from opts.G and accumulate detector features into a
    FeatureStats, until stats_kwargs['max_items'] rows have been collected.

    `batch_gen` is the per-forward-pass generator batch; `batch_size` images
    are assembled before each detector call.

    NOTE(review): `swav` and `sfid` are accepted but never used in this body —
    confirm intent.
    """
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0
    # Setup generator and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)
    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    # get detector
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
    # Main loop.
    while not stats.is_full():
        images = []
        for _i in range(batch_size // batch_gen):
            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
            # img = G(z=z, c=next(c_iter), truncation_psi=0.1, **opts.G_kwargs)
            img = G(z=z, c=next(c_iter), **opts.G_kwargs)
            # Map float output in [-1, 1] to uint8 in [0, 255].
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            images.append(img)
        images = torch.cat(images)
        if images.shape[1] == 1:
            # Replicate grayscale to 3 channels for the detector.
            images = images.repeat([1, 3, 1, 1])
        with torch.no_grad():
            features = detector(images.to(opts.device), **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
+125
View File
@@ -0,0 +1,125 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
Architecture for Generative Adversarial Networks". Matches the original
implementation by Karras et al. at
https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
import copy
import numpy as np
import torch
from . import metric_utils
#----------------------------------------------------------------------------
# Spherical interpolation of a batch of vectors.
def slerp(a, b, t):
    """Spherical interpolation between batches of vectors *a* and *b* at
    fraction *t*, returned as unit vectors (last dim is the vector dim)."""
    # Project both endpoints onto the unit sphere.
    a_unit = a / a.norm(dim=-1, keepdim=True)
    b_unit = b / b.norm(dim=-1, keepdim=True)
    # Angle swept so far: t times the angle between the endpoints.
    cos_omega = (a_unit * b_unit).sum(dim=-1, keepdim=True)
    angle = t * torch.acos(cos_omega)
    # Unit vector orthogonal to a_unit in the plane spanned by both endpoints.
    ortho = b_unit - cos_omega * a_unit
    ortho = ortho / ortho.norm(dim=-1, keepdim=True)
    out = a_unit * torch.cos(angle) + ortho * torch.sin(angle)
    return out / out.norm(dim=-1, keepdim=True)
#----------------------------------------------------------------------------
class PPLSampler(torch.nn.Module):
    """Computes one batch of Perceptual Path Length samples per forward() call:
    draws latent pairs epsilon apart, synthesizes both images, and returns the
    squared LPIPS distance divided by epsilon^2 (one value per row of c)."""
    def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
        # space: 'z' interpolates latents with slerp; 'w' lerps mapped latents.
        # sampling: 'full' draws t ~ U[0,1); 'end' pins t = 0 (path endpoint).
        assert space in ['z', 'w']
        assert sampling in ['full', 'end']
        super().__init__()
        self.G = copy.deepcopy(G)
        self.G_kwargs = G_kwargs
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.crop = crop
        self.vgg16 = copy.deepcopy(vgg16)
    def forward(self, c):
        # Generate random latents and interpolation t-values.
        # 'end' sampling multiplies t by 0, fixing it at the path start.
        t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
        z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
        # Interpolate in W or Z.
        if self.space == 'w':
            w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
            wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
            wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
        else: # space == 'z'
            zt0 = slerp(z0, z1, t.unsqueeze(1))
            zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
            wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
        # Randomize noise buffers.
        for name, buf in self.G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))
        # Generate images.
        img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
        # Center crop.
        if self.crop:
            assert img.shape[2] == img.shape[3]
            # NOTE(review): `c` here shadows the label argument; it is the crop
            # unit (1/8 of the image side), keeping rows [3/8, 7/8) x cols [2/8, 6/8).
            c = img.shape[2] // 8
            img = img[:, :, c*3 : c*7, c*2 : c*6]
        # Downsample to 256x256.
        factor = self.G.img_resolution // 256
        if factor > 1:
            # Average-pool by reshaping into factor-sized tiles.
            img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
        # Scale dynamic range from [-1,1] to [0,255].
        img = (img + 1) * (255 / 2)
        if self.G.img_channels == 1:
            img = img.repeat([1, 3, 1, 1])
        # Evaluate differential LPIPS.
        lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
        dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
        return dist
#----------------------------------------------------------------------------
def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
    """Perceptual Path Length: mean LPIPS-based path-length sample over the
    central 98% of `num_samples` draws (1st/99th percentile outliers removed).

    Only rank 0 returns the actual PPL; every other rank returns NaN.
    """
    vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)

    # Setup sampler and labels.
    sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
    sampler.eval().requires_grad_(False).to(opts.device)
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop.
    dist = []
    progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        x = sampler(next(c_iter))
        # Gather this batch from every rank so all processes hold all samples.
        for src in range(opts.num_gpus):
            y = x.clone()
            if opts.num_gpus > 1:
                torch.distributed.broadcast(y, src=src)
            dist.append(y)
    progress.update(num_samples)

    # Compute PPL.
    if opts.rank != 0:
        return float('nan')
    dist = torch.cat(dist)[:num_samples].cpu().numpy()
    # NumPy 1.22 renamed `interpolation=` to `method=` and the old keyword was
    # removed in NumPy 2.0; prefer the new spelling, fall back for old NumPy.
    try:
        lo = np.percentile(dist, 1, method='lower')
        hi = np.percentile(dist, 99, method='higher')
    except TypeError:  # NumPy < 1.22
        lo = np.percentile(dist, 1, interpolation='lower')
        hi = np.percentile(dist, 99, interpolation='higher')
    ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
    return float(ppl)
#----------------------------------------------------------------------------
+62
View File
@@ -0,0 +1,62 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Precision/Recall (PR) from the paper "Improved Precision and Recall
Metric for Assessing Generative Models". Matches the original implementation
by Kynkaanniemi et al. at
https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
import torch
from . import metric_utils
#----------------------------------------------------------------------------
def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
    """Pairwise distances between row_features and col_features, computed in
    column batches round-robined across GPUs.

    Only rank 0 receives the assembled [num_rows, num_cols] tensor (on CPU);
    every other rank gets None.
    """
    assert 0 <= rank < num_gpus
    num_cols = col_features.shape[0]
    # Pad the columns with zero rows so they split evenly into per-GPU batches.
    num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
    padded = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches])
    pieces = []
    for col_batch in padded.chunk(num_batches)[rank :: num_gpus]:
        batch_dist = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
        for src in range(num_gpus):
            gathered = batch_dist.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(gathered, src=src)
            pieces.append(gathered.cpu() if rank == 0 else None)
    if rank != 0:
        return None
    # Drop the padding columns when reassembling.
    return torch.cat(pieces, dim=1)[:, :num_cols]
#----------------------------------------------------------------------------
def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
    """Improved Precision and Recall between real and generated VGG-16 features.

    Precision is the fraction of generated samples lying inside the real-feature
    manifold (within the nhood_size-th nearest-neighbour radius of some real
    point); recall is the symmetric quantity with the roles swapped. Only rank 0
    computes the final numbers; other ranks return NaN for both.
    """
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    detector_kwargs = dict(return_features=True)
    # Features are held in float16 on-device to bound memory.
    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
    results = dict()
    for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
        kth = []
        for manifold_batch in manifold.split(row_batch_size):
            # Distances from each manifold point to all manifold points.
            dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            # kthvalue(nhood_size + 1) skips the zero self-distance; only rank 0
            # holds real tensors, other ranks carry None placeholders.
            kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
        kth = torch.cat(kth) if opts.rank == 0 else None
        pred = []
        for probes_batch in probes.split(row_batch_size):
            dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            # A probe is "inside" the manifold if it is within any manifold
            # point's k-NN radius.
            pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
        results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
    return results['precision'], results['recall']
#----------------------------------------------------------------------------
+6 -5
View File
@@ -5,7 +5,7 @@
# Created Date: Thursday February 10th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 3rd March 2022 6:44:57 pm
# Last Modified: Sunday, 3rd April 2022 6:39:26 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
@@ -23,7 +23,7 @@ if __name__ == '__main__':
# script = "Generator_modulation_up"
# script = "Generator_modulation_up"
# script = "Generator_Invobn_config3"
script = "Generator_ori_config"
script = "Generator_maskhead_config"
# script = "Generator_ori_config"
class_name = "Generator"
arcface_ckpt= "arcface_ckpt/arcface_checkpoint.tar"
@@ -31,11 +31,12 @@ if __name__ == '__main__':
"id_dim": 512,
"g_kernel_size": 3,
"in_channel":64,
"res_num": 9,
"res_num": 0,
# "up_mode": "nearest",
"up_mode": "bilinear",
"aggregator": "eca_invo",
"res_mode": "eca_invo"
"res_mode": "eca_invo",
"norm": "bn"
}
os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
print("GPU used : ", os.environ['CUDA_VISIBLE_DEVICES'])
@@ -57,7 +58,7 @@ if __name__ == '__main__':
id_latent = torch.rand((4,512)).cuda()
# cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0]
attr = torch.rand((4,3,224,224)).cuda()
attr = torch.rand((4,3,512,512)).cuda()
import datetime
start_time = time.time()

Some files were not shown because too many files have changed in this diff Show More