update
This commit is contained in:
@@ -5,14 +5,17 @@
|
||||
# Created Date: Wednesday December 22nd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 17th February 2022 10:47:35 pm
|
||||
# Last Modified: Friday, 22nd April 2022 11:23:17 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
|
||||
from glob import glob
|
||||
from ipaddress import ip_address
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
@@ -37,6 +40,7 @@ import tkinter.ttk as ttk
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from tkinter.filedialog import askopenfilename
|
||||
|
||||
|
||||
|
||||
@@ -163,6 +167,27 @@ class fileUploaderClass(object):
|
||||
self.__ssh__.close()
|
||||
return roots
|
||||
|
||||
def sshScpGetRNamesBySuffix(self, remoteDir, suffix):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
sftp = self.__ssh__.open_sftp()
|
||||
wocao = sftp.listdir(remoteDir)
|
||||
# print(wocao.st_mtime)
|
||||
roots = {}
|
||||
for item in wocao:
|
||||
wocao = sftp.stat(remoteDir+"/"+item)
|
||||
roots[item] = {
|
||||
"t":wocao.st_mtime,
|
||||
"p":remoteDir+"/"+item
|
||||
}
|
||||
# temp= remoteDir+ "/"+item
|
||||
# child_dirs = sftp.listdir(temp)
|
||||
# child_dirs = ["save\\" +item + "\\" + i for i in child_dirs]
|
||||
# list_name += child_dirs
|
||||
sftp.close()
|
||||
self.__ssh__.close()
|
||||
return roots
|
||||
|
||||
def sshScpGet(self, remoteFile, localFile, showProgress=False):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
@@ -308,16 +333,16 @@ class Application(tk.Frame):
|
||||
"config_json_name":"model_config.json"
|
||||
}
|
||||
machine_text = {
|
||||
"ip": "0.0.0.0",
|
||||
"ip": "localhost",
|
||||
"user": "username",
|
||||
"port": 22,
|
||||
"passwd": "12345678",
|
||||
"path": "/path/to/remote_host",
|
||||
"path": ".",
|
||||
"ckp_path":"save",
|
||||
"logfilename": "filestate_machine0.json"
|
||||
"logfilename": "filestate_machine_localhost.json"
|
||||
}
|
||||
current_log = {}
|
||||
|
||||
current_ckpt = {}
|
||||
|
||||
def __init__(self, master=None):
|
||||
tk.Frame.__init__(self, master,bg='black')
|
||||
@@ -435,7 +460,7 @@ class Application(tk.Frame):
|
||||
config_frame.pack(fill="both", padx=5,pady=5)
|
||||
config_frame.columnconfigure(0, weight=1)
|
||||
config_frame.columnconfigure(1, weight=1)
|
||||
# config_frame.columnconfigure(2, weight=1)
|
||||
config_frame.columnconfigure(2, weight=1)
|
||||
|
||||
machine_btn = tk.Button(config_frame, text = "Ignore Conf",
|
||||
font=font_list, command = self.IgnoreConfig, bg='#660099', fg='#F5F5F5')
|
||||
@@ -445,12 +470,17 @@ class Application(tk.Frame):
|
||||
font=font_list, command = self.EnvConfig, bg='#660099', fg='#F5F5F5')
|
||||
machine_btn2.grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
machine_btn2 = tk.Button(config_frame, text = "Test Conf",
|
||||
font=font_list, command = self.TestConfig, bg='#660099', fg='#F5F5F5')
|
||||
machine_btn2.grid(row=0,column=2,sticky=tk.EW)
|
||||
|
||||
#################################################################################################
|
||||
log_frame = tk.Frame(self.master)
|
||||
log_frame.pack(fill="both", padx=5,pady=5)
|
||||
log_frame.columnconfigure(0, weight=1)
|
||||
log_frame.columnconfigure(1, weight=1)
|
||||
log_frame.columnconfigure(2, weight=1)
|
||||
log_frame.columnconfigure(3, weight=1)
|
||||
|
||||
self.log_var = tkinter.StringVar()
|
||||
|
||||
@@ -460,36 +490,104 @@ class Application(tk.Frame):
|
||||
self.update_ckpt_task()
|
||||
self.log_com.bind("<<ComboboxSelected>>",select_log)
|
||||
|
||||
self.test_var = tkinter.StringVar()
|
||||
|
||||
self.test_com = ttk.Combobox(log_frame, textvariable=self.test_var)
|
||||
self.test_com.grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
log_update_button = tk.Button(log_frame, text = "Fresh",
|
||||
font=font_list, command = self.UpdateLog, bg='#F4A460', fg='#F5F5F5')
|
||||
log_update_button.grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
log_update_button = tk.Button(log_frame, text = "Pull Log",
|
||||
font=font_list, command = self.PullLog, bg='#F4A460', fg='#F5F5F5')
|
||||
log_update_button.grid(row=0,column=2,sticky=tk.EW)
|
||||
|
||||
# log_update_button = tk.Button(log_frame, text = "Pull Log",
|
||||
# font=font_list, command = self.PullLog, bg='#F4A460', fg='#F5F5F5')
|
||||
# log_update_button.grid(row=0,column=2,sticky=tk.EW)
|
||||
|
||||
log_update_button = tk.Button(log_frame, text = "Fresh CKPT",
|
||||
font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5')
|
||||
log_update_button.grid(row=0,column=3,sticky=tk.EW)
|
||||
|
||||
#################################################################################################
|
||||
test_frame = tk.Frame(self.master)
|
||||
test_frame.pack(fill="both", padx=5,pady=5)
|
||||
test_frame.columnconfigure(0, weight=1)
|
||||
test_frame.columnconfigure(1, weight=1)
|
||||
test_frame.columnconfigure(2, weight=1)
|
||||
# test_frame.columnconfigure(3, weight=1)
|
||||
|
||||
self.test_var = tkinter.StringVar()
|
||||
self.testscript_var = tkinter.StringVar()
|
||||
|
||||
self.test_com = ttk.Combobox(test_frame, textvariable=self.test_var)
|
||||
self.test_com.grid(row=0,column=0,sticky=tk.EW)
|
||||
|
||||
self.testscript_com = ttk.Combobox(test_frame, textvariable=self.testscript_var)
|
||||
self.testscript_com.grid(row=0,column=0,sticky=tk.EW)
|
||||
|
||||
test_update_button = tk.Button(test_frame, text = "Fresh CKPT",
|
||||
font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5')
|
||||
test_update_button.grid(row=0,column=1,sticky=tk.EW)
|
||||
testscript_files = Path("./test_scripts").glob("*.py")
|
||||
testscript_list = []
|
||||
for item in testscript_files:
|
||||
basename = item.name
|
||||
basename = os.path.splitext(basename)[0]
|
||||
testscript_list.append(basename)
|
||||
self.testscript_com["value"] = testscript_list
|
||||
|
||||
# test_update_button = tk.Button(test_frame, text = "Fresh CKPT",
|
||||
# font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5')
|
||||
# test_update_button.grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
test_update_button = tk.Button(test_frame, text = "Test",
|
||||
font=font_list, command = self.Test, bg='#F4A460', fg='#F5F5F5')
|
||||
test_update_button.grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
# test_update_button = tk.Button(test_frame, text = "Test Config",
|
||||
# font=font_list, command = self.TestConfig, bg='#660099', fg='#F5F5F5')
|
||||
# test_update_button.grid(row=0,column=2,sticky=tk.EW)
|
||||
|
||||
test_update_button = tk.Button(test_frame, text = "Sample",
|
||||
font=font_list, command = self.OpenSample, bg='#0033FF', fg='#F5F5F5')
|
||||
test_update_button.grid(row=0,column=2,sticky=tk.EW)
|
||||
|
||||
# #################################################################################################
|
||||
|
||||
select_frame = tk.Frame(self.master)
|
||||
select_frame.pack(fill="both", padx=5,pady=5)
|
||||
|
||||
select_frame.columnconfigure(0, weight=2)
|
||||
select_frame.columnconfigure(1, weight=5)
|
||||
select_frame.columnconfigure(2, weight=1)
|
||||
select_frame.columnconfigure(3, weight=2)
|
||||
select_frame.columnconfigure(4, weight=5)
|
||||
select_frame.columnconfigure(5, weight=1)
|
||||
select_frame.columnconfigure(6, weight=3)
|
||||
|
||||
self.preprocess_var = tkinter.StringVar()
|
||||
|
||||
self.preprocess_com = ttk.Combobox(select_frame, textvariable=self.preprocess_var)
|
||||
self.preprocess_com.grid(row=0,column=6,sticky=tk.EW)
|
||||
self.preprocess_com["value"] = ["True", "False"]
|
||||
self.preprocess_com.current(1)
|
||||
self.ID_path = tkinter.StringVar()
|
||||
self.ID_path.set("...")
|
||||
tk.Label(select_frame, text="ID:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=0,sticky=tk.EW)
|
||||
|
||||
tk.Entry(select_frame, textvariable= self.ID_path, font=font_list)\
|
||||
.grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
tk.Button(select_frame, text = "...", font=font_list,
|
||||
command = self.Select_ID_path, bg='#F4A460', fg='#F5F5F5')\
|
||||
.grid(row=0,column=2,sticky=tk.EW)
|
||||
|
||||
|
||||
self.Attr_path = tkinter.StringVar()
|
||||
self.Attr_path.set("...")
|
||||
tk.Label(select_frame, text="Attr:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=3,sticky=tk.EW)
|
||||
|
||||
tk.Entry(select_frame, textvariable= self.Attr_path, font=font_list)\
|
||||
.grid(row=0,column=4,sticky=tk.EW)
|
||||
|
||||
tk.Button(select_frame, text = "...", font=font_list,
|
||||
command = self.Select_Attr_path, bg='#F4A460', fg='#F5F5F5')\
|
||||
.grid(row=0,column=5,sticky=tk.EW)
|
||||
|
||||
|
||||
# #################################################################################################
|
||||
# tbtext_frame = tk.Frame(self.master)
|
||||
@@ -527,23 +625,61 @@ class Application(tk.Frame):
|
||||
|
||||
self.master.protocol("WM_DELETE_WINDOW", self.on_closing)
|
||||
|
||||
# def __scaning_logs__(self):
|
||||
|
||||
def Select_ID_path(self):
|
||||
thread_update = threading.Thread(target=self.select_ID_task)
|
||||
thread_update.start()
|
||||
|
||||
def select_ID_task(self):
|
||||
path = askopenfilename()
|
||||
print("Selected ID: %s"%path)
|
||||
self.ID_path.set(path)
|
||||
|
||||
def Select_Attr_path(self):
|
||||
thread_update = threading.Thread(target=self.select_Attr_task)
|
||||
thread_update.start()
|
||||
|
||||
def select_Attr_task(self):
|
||||
path = askopenfilename()
|
||||
print("Selected Attibutes: %s"%path)
|
||||
self.Attr_path.set(path)
|
||||
|
||||
def UpdateCKPT(self):
|
||||
thread_update = threading.Thread(target=self.update_ckpt_task)
|
||||
thread_update.start()
|
||||
|
||||
|
||||
def update_ckpt_task(self):
|
||||
ip = self.list_com.get()
|
||||
log = self.log_com.get()
|
||||
cur_mac = self.machine_dict[ip]
|
||||
files = Path('.',cur_mac["ckp_path"], log)
|
||||
files = files.glob('*.pth')
|
||||
all_files = []
|
||||
for one_file in files:
|
||||
all_files.append(one_file.name)
|
||||
self.test_com["value"] =all_files
|
||||
if len(all_files):
|
||||
print("Loading checkpoints..........................")
|
||||
remotemachine, mac = self.connection()
|
||||
log = self.log_com.get()
|
||||
remote_path = os.path.join(mac["path"],mac["ckp_path"], log, "checkpoints").replace("\\", "/")
|
||||
|
||||
if remotemachine == []:
|
||||
files = Path(remote_path).glob("*/")
|
||||
first_level = {}
|
||||
for one_file in files:
|
||||
first_level[one_file.name] = {
|
||||
"t":"",
|
||||
"p":""
|
||||
}
|
||||
else:
|
||||
first_level = remotemachine.sshScpGetNames(remote_path)
|
||||
|
||||
if len(first_level) == 0:
|
||||
self.test_com["value"] = [""]
|
||||
self.test_com.current(0)
|
||||
print("No checkpoint found!")
|
||||
return
|
||||
logs = []
|
||||
for k,v in first_level.items():
|
||||
logs.append([k,v["t"]])
|
||||
# logs = sorted(logs)
|
||||
logs = sorted(logs, key= lambda logs : logs[1],reverse=True)
|
||||
self.test_com["value"] =[item[0] for item in logs]
|
||||
|
||||
self.test_com.current(0)
|
||||
self.current_ckpt = first_level
|
||||
print("Checkpoints list update success!")
|
||||
|
||||
def CopyPasswd(self):
|
||||
def copy():
|
||||
@@ -557,13 +693,25 @@ class Application(tk.Frame):
|
||||
|
||||
def Test(self):
|
||||
def test_task():
|
||||
ip = self.list_com.get()
|
||||
log = self.log_com.get()
|
||||
ckpt = self.test_com.get()
|
||||
script_name = self.testscript_com.get()
|
||||
id_path = self.ID_path.get()
|
||||
attr_path = self.Attr_path.get()
|
||||
preprocess = self.preprocess_com.get()
|
||||
# if preprocess == "Preprocess-Off":
|
||||
# preprocess = "off"
|
||||
# else:
|
||||
# preprocess = "on"
|
||||
ckpt = re.sub("\D", "", ckpt)
|
||||
cwd = os.getcwd()
|
||||
files = str(Path(log, ckpt))
|
||||
print(files)
|
||||
# files = str(Path(log, ckpt))
|
||||
# print(files)
|
||||
print("start cmd /k \"cd /d %s && conda activate base \
|
||||
&& python test.py -v %s -s %s -t %s -n %s -i %s -a %s --preprocess %s \""%(cwd, log, ckpt, script_name, ip, id_path, attr_path, preprocess))
|
||||
subprocess.check_call("start cmd /k \"cd /d %s && conda activate base \
|
||||
&& python test.py --model %s\""%(cwd, files), shell=True)
|
||||
&& python test.py -v %s -s %s -t %s -n %s --preprocess %s -i %s -a %s\""%(cwd, log, ckpt, script_name, ip, preprocess , id_path, attr_path), shell=True)
|
||||
thread_update = threading.Thread(target=test_task)
|
||||
thread_update.start()
|
||||
|
||||
@@ -595,7 +743,7 @@ class Application(tk.Frame):
|
||||
ssh_username = cur_mac["user"]
|
||||
ssh_passwd = cur_mac["passwd"]
|
||||
ssh_port = int(cur_mac["port"])
|
||||
print(ssh_ip)
|
||||
print("Processing IP: %s."%ssh_ip)
|
||||
if ip.lower() == "local" or ip.lower() == "localhost":
|
||||
print("localhost no need to connect!")
|
||||
return [], cur_mac
|
||||
@@ -607,6 +755,7 @@ class Application(tk.Frame):
|
||||
print(cells)
|
||||
|
||||
def update_log_task(self):
|
||||
print("Processing! Do not touch!")
|
||||
remotemachine,mac = self.connection()
|
||||
remote_path = os.path.join(mac["path"],mac["ckp_path"]).replace("\\", "/")
|
||||
if remotemachine == []:
|
||||
@@ -617,16 +766,23 @@ class Application(tk.Frame):
|
||||
"t":"",
|
||||
"p":""
|
||||
}
|
||||
# elif remotemachine == "localhost":
|
||||
|
||||
else:
|
||||
first_level = remotemachine.sshScpGetNames(remote_path)
|
||||
if len(first_level) == 0:
|
||||
print("No training log found!")
|
||||
return
|
||||
logs = []
|
||||
for k,v in first_level.items():
|
||||
logs.append(k)
|
||||
logs = sorted(logs)
|
||||
self.log_com["value"] =logs
|
||||
logs.append([k,v["t"]])
|
||||
# logs = sorted(logs)
|
||||
logs = sorted(logs, key= lambda logs : logs[1],reverse=True)
|
||||
self.log_com["value"] = [item[0] for item in logs]
|
||||
self.log_com.current(0)
|
||||
self.current_log = first_level
|
||||
self.update_ckpt_task()
|
||||
print("Done!")
|
||||
|
||||
def UpdateLog(self):
|
||||
thread_update = threading.Thread(target=self.update_log_task)
|
||||
@@ -666,6 +822,21 @@ class Application(tk.Frame):
|
||||
|
||||
thread_update = threading.Thread(target=pull_log_task)
|
||||
thread_update.start()
|
||||
|
||||
def TestConfig(self):
|
||||
def test_config_task():
|
||||
subprocess.call("start %s"%"test.py", shell=True)
|
||||
thread_update = threading.Thread(target=test_config_task)
|
||||
thread_update.start()
|
||||
|
||||
def OpenSample(self):
|
||||
def open_cmd_task():
|
||||
log = self.log_com.get()
|
||||
cwd = os.getcwd()
|
||||
sample = os.path.join(cwd,"test_logs",log,"samples")
|
||||
subprocess.call("explorer "+sample, shell=False)
|
||||
thread_update = threading.Thread(target=open_cmd_task)
|
||||
thread_update.start()
|
||||
|
||||
def OpenCMD(self):
|
||||
def open_cmd_task():
|
||||
@@ -685,6 +856,7 @@ class Application(tk.Frame):
|
||||
ip = self.list_com.get()
|
||||
if ip.lower() == "local" or ip.lower() == "localhost":
|
||||
print("localhost no need to connect!")
|
||||
return
|
||||
cur_mac = self.machine_dict[ip]
|
||||
ssh_ip = cur_mac["ip"]
|
||||
ssh_username = cur_mac["user"]
|
||||
@@ -707,6 +879,10 @@ class Application(tk.Frame):
|
||||
def GPUUsage(self):
|
||||
def gpu_usage_task():
|
||||
remotemachine,_ = self.connection()
|
||||
if remotemachine == "local" or remotemachine == "localhost":
|
||||
print("localhost no need to connect!")
|
||||
return
|
||||
|
||||
results = remotemachine.sshExec("nvidia-smi")
|
||||
print(results)
|
||||
|
||||
|
||||
@@ -1,91 +1,91 @@
|
||||
{
|
||||
"GUI.py": 1647657822.9152665,
|
||||
"test.py": 1647879709.2723496,
|
||||
"train.py": 1647657822.9562755,
|
||||
"components\\Generator.py": 1647657822.93127,
|
||||
"components\\projected_discriminator.py": 1647657822.938272,
|
||||
"components\\pg_modules\\blocks.py": 1647657822.9362714,
|
||||
"components\\pg_modules\\diffaug.py": 1647657822.9362714,
|
||||
"components\\pg_modules\\discriminator.py": 1647657822.937271,
|
||||
"components\\pg_modules\\networks_fastgan.py": 1647657822.937271,
|
||||
"components\\pg_modules\\networks_stylegan2.py": 1647657822.937271,
|
||||
"components\\pg_modules\\projector.py": 1647657822.938272,
|
||||
"data_tools\\data_loader.py": 1647657822.9392715,
|
||||
"data_tools\\data_loader_condition.py": 1647657822.9402719,
|
||||
"data_tools\\data_loader_VGGFace2HQ.py": 1647657822.9392715,
|
||||
"data_tools\\StyleResize.py": 1647657822.9392715,
|
||||
"data_tools\\test_dataloader_dir.py": 1647657822.941272,
|
||||
"losses\\PerceptualLoss.py": 1647657822.9432724,
|
||||
"losses\\SliceWassersteinDistance.py": 1647657822.9432724,
|
||||
"models\\arcface_models.py": 1647657822.9442725,
|
||||
"models\\config.py": 1647657822.9442725,
|
||||
"models\\__init__.py": 1647657822.9442725,
|
||||
"test_scripts\\tester_common.py": 1647657822.9472733,
|
||||
"GUI.py": 1649868392.5891902,
|
||||
"test.py": 1649641910.555175,
|
||||
"train.py": 1643397924.974299,
|
||||
"components\\Generator.py": 1644689001.9005148,
|
||||
"components\\projected_discriminator.py": 1642348101.4661522,
|
||||
"components\\pg_modules\\blocks.py": 1640773190.0,
|
||||
"components\\pg_modules\\diffaug.py": 1640773190.0,
|
||||
"components\\pg_modules\\discriminator.py": 1642349784.9407308,
|
||||
"components\\pg_modules\\networks_fastgan.py": 1640773190.0,
|
||||
"components\\pg_modules\\networks_stylegan2.py": 1640773190.0,
|
||||
"components\\pg_modules\\projector.py": 1642349764.3896568,
|
||||
"data_tools\\data_loader.py": 1611123530.660446,
|
||||
"data_tools\\data_loader_condition.py": 1625411562.8217106,
|
||||
"data_tools\\data_loader_VGGFace2HQ.py": 1644234949.3769877,
|
||||
"data_tools\\StyleResize.py": 1624954084.7176485,
|
||||
"data_tools\\test_dataloader_dir.py": 1634041792.6743984,
|
||||
"losses\\PerceptualLoss.py": 1615020169.668723,
|
||||
"losses\\SliceWassersteinDistance.py": 1634022704.6082795,
|
||||
"models\\arcface_models.py": 1642390690.623,
|
||||
"models\\config.py": 1632643596.2908099,
|
||||
"models\\__init__.py": 1642390864.8828168,
|
||||
"test_scripts\\tester_common.py": 1625369535.199175,
|
||||
"test_scripts\\tester_FastNST.py": 1634041357.607633,
|
||||
"train_scripts\\trainer_base.py": 1647657822.9582758,
|
||||
"train_scripts\\trainer_FM.py": 1647657822.957276,
|
||||
"train_scripts\\trainer_naiv512.py": 1647657822.9602764,
|
||||
"utilities\\checkpoint_manager.py": 1647657822.9652774,
|
||||
"utilities\\figure.py": 1647657822.9652774,
|
||||
"utilities\\json_config.py": 1647657822.9652774,
|
||||
"utilities\\learningrate_scheduler.py": 1647657822.9652774,
|
||||
"utilities\\logo_class.py": 1647657822.9662776,
|
||||
"utilities\\plot.py": 1647657822.9662776,
|
||||
"utilities\\reporter.py": 1647657822.9662776,
|
||||
"utilities\\save_heatmap.py": 1647657822.967278,
|
||||
"utilities\\sshupload.py": 1647657822.967278,
|
||||
"utilities\\transfer_checkpoint.py": 1647657822.967278,
|
||||
"utilities\\utilities.py": 1647657822.9682784,
|
||||
"utilities\\yaml_config.py": 1647657822.9682784,
|
||||
"train_yamls\\train_512FM.yaml": 1647657822.961277,
|
||||
"train_scripts\\trainer_2layer_FM.py": 1647657822.957276,
|
||||
"train_yamls\\train_2layer_FM.yaml": 1647657822.961277,
|
||||
"components\\Generator_reduce.py": 1647657822.934271,
|
||||
"train_scripts\\trainer_base.py": 1642396105.3868554,
|
||||
"train_scripts\\trainer_FM.py": 1643021959.3577182,
|
||||
"train_scripts\\trainer_naiv512.py": 1642315674.9740853,
|
||||
"utilities\\checkpoint_manager.py": 1611123530.6624403,
|
||||
"utilities\\figure.py": 1611123530.6634378,
|
||||
"utilities\\json_config.py": 1611123530.6614666,
|
||||
"utilities\\learningrate_scheduler.py": 1611123530.675422,
|
||||
"utilities\\logo_class.py": 1633883995.3093486,
|
||||
"utilities\\plot.py": 1641911100.7995758,
|
||||
"utilities\\reporter.py": 1646311333.3067005,
|
||||
"utilities\\save_heatmap.py": 1611123530.679439,
|
||||
"utilities\\sshupload.py": 1649910787.5441866,
|
||||
"utilities\\transfer_checkpoint.py": 1642397157.0163105,
|
||||
"utilities\\utilities.py": 1649907294.9180465,
|
||||
"utilities\\yaml_config.py": 1611123530.6614666,
|
||||
"train_yamls\\train_512FM.yaml": 1643021615.8106658,
|
||||
"train_scripts\\trainer_2layer_FM.py": 1642826548.2530458,
|
||||
"train_yamls\\train_2layer_FM.yaml": 1642411635.5534878,
|
||||
"components\\Generator_reduce.py": 1645020911.0651233,
|
||||
"insightface_func\\face_detect_crop_multi.py": 1643796928.6362474,
|
||||
"insightface_func\\face_detect_crop_single.py": 1638370471.7967434,
|
||||
"insightface_func\\__init__.py": 1624197300.011183,
|
||||
"insightface_func\\utils\\face_align_ffhqandnewarc.py": 1638370471.850638,
|
||||
"losses\\PatchNCE.py": 1647677173.0239084,
|
||||
"losses\\PatchNCE.py": 1647769120.567006,
|
||||
"parsing_model\\model.py": 1626745709.554252,
|
||||
"parsing_model\\resnet.py": 1626745709.554252,
|
||||
"test_scripts\\tester_common copy.py": 1647657822.9472733,
|
||||
"test_scripts\\tester_video.py": 1647657822.9482737,
|
||||
"train_scripts\\trainer_cycleloss.py": 1647657822.9592762,
|
||||
"train_scripts\\trainer_GramFM.py": 1647657822.9582758,
|
||||
"utilities\\ImagenetNorm.py": 1647657822.9642777,
|
||||
"utilities\\reverse2original.py": 1647657822.9662776,
|
||||
"train_yamls\\train_cycleloss.yaml": 1647704989.2310138,
|
||||
"train_yamls\\train_GramFM.yaml": 1647657822.9622767,
|
||||
"train_yamls\\train_512FM_Modulation.yaml": 1647657822.961277,
|
||||
"face_crop.py": 1647657822.9422722,
|
||||
"face_crop_video.py": 1647657822.9422722,
|
||||
"similarity.py": 1647657822.945273,
|
||||
"train_multigpu.py": 1647967698.603863,
|
||||
"components\\arcface_decoder.py": 1647657822.9352713,
|
||||
"test_scripts\\tester_common copy.py": 1625369535.199175,
|
||||
"test_scripts\\tester_video.py": 1649749352.843031,
|
||||
"train_scripts\\trainer_cycleloss.py": 1642580463.495596,
|
||||
"train_scripts\\trainer_GramFM.py": 1643095575.2628715,
|
||||
"utilities\\ImagenetNorm.py": 1642732910.5280058,
|
||||
"utilities\\reverse2original.py": 1648533907.6187606,
|
||||
"train_yamls\\train_cycleloss.yaml": 1647769120.6110919,
|
||||
"train_yamls\\train_GramFM.yaml": 1643398791.363959,
|
||||
"train_yamls\\train_512FM_Modulation.yaml": 1643022022.3165789,
|
||||
"face_crop.py": 1649079350.7075238,
|
||||
"face_crop_video.py": 1649988435.4005315,
|
||||
"similarity.py": 1643269705.1073737,
|
||||
"train_multigpu.py": 1650004781.6307411,
|
||||
"components\\arcface_decoder.py": 1643396144.2575414,
|
||||
"components\\Generator_nobias.py": 1643179001.810856,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1647657822.9402719,
|
||||
"data_tools\\data_loader_VGGFace2HQ_Rec.py": 1647657822.9392715,
|
||||
"test_scripts\\tester_arcface_Rec.py": 1647657822.946273,
|
||||
"test_scripts\\tester_image.py": 1647657822.9472733,
|
||||
"torch_utils\\custom_ops.py": 1647657822.9482737,
|
||||
"torch_utils\\misc.py": 1647657822.9492736,
|
||||
"torch_utils\\persistence.py": 1647657822.9552753,
|
||||
"torch_utils\\training_stats.py": 1647657822.9562755,
|
||||
"torch_utils\\utils_spectrum.py": 1647657822.9562755,
|
||||
"torch_utils\\__init__.py": 1647657822.9482737,
|
||||
"torch_utils\\ops\\bias_act.py": 1647657822.9502747,
|
||||
"torch_utils\\ops\\conv2d_gradfix.py": 1647657822.9512744,
|
||||
"torch_utils\\ops\\conv2d_resample.py": 1647657822.9512744,
|
||||
"torch_utils\\ops\\filtered_lrelu.py": 1647657822.9532747,
|
||||
"torch_utils\\ops\\fma.py": 1647657822.9532747,
|
||||
"torch_utils\\ops\\grid_sample_gradfix.py": 1647657822.9542756,
|
||||
"torch_utils\\ops\\upfirdn2d.py": 1647657822.9552753,
|
||||
"torch_utils\\ops\\__init__.py": 1647657822.9492736,
|
||||
"train_scripts\\trainer_arcface_rec.py": 1647657822.9582758,
|
||||
"train_scripts\\trainer_multigpu_base.py": 1647657822.9602764,
|
||||
"train_scripts\\trainer_multi_gpu.py": 1647657822.9592762,
|
||||
"train_yamls\\train_arcface_rec.yaml": 1647657822.9622767,
|
||||
"train_yamls\\train_multigpu.yaml": 1647657822.963277,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1649177633.2426238,
|
||||
"data_tools\\data_loader_VGGFace2HQ_Rec.py": 1643398754.86898,
|
||||
"test_scripts\\tester_arcface_Rec.py": 1643431261.9333818,
|
||||
"test_scripts\\tester_image.py": 1648028827.923354,
|
||||
"torch_utils\\custom_ops.py": 1640773190.0,
|
||||
"torch_utils\\misc.py": 1640773190.0,
|
||||
"torch_utils\\persistence.py": 1640773190.0,
|
||||
"torch_utils\\training_stats.py": 1640773190.0,
|
||||
"torch_utils\\utils_spectrum.py": 1640773190.0,
|
||||
"torch_utils\\__init__.py": 1640773190.0,
|
||||
"torch_utils\\ops\\bias_act.py": 1640773190.0,
|
||||
"torch_utils\\ops\\conv2d_gradfix.py": 1640773190.0,
|
||||
"torch_utils\\ops\\conv2d_resample.py": 1640773190.0,
|
||||
"torch_utils\\ops\\filtered_lrelu.py": 1640773190.0,
|
||||
"torch_utils\\ops\\fma.py": 1640773190.0,
|
||||
"torch_utils\\ops\\grid_sample_gradfix.py": 1640773190.0,
|
||||
"torch_utils\\ops\\upfirdn2d.py": 1640773190.0,
|
||||
"torch_utils\\ops\\__init__.py": 1640773190.0,
|
||||
"train_scripts\\trainer_arcface_rec.py": 1643399647.0182135,
|
||||
"train_scripts\\trainer_multigpu_base.py": 1644131205.772292,
|
||||
"train_scripts\\trainer_multi_gpu.py": 1648285132.309124,
|
||||
"train_yamls\\train_arcface_rec.yaml": 1643398807.3434353,
|
||||
"train_yamls\\train_multigpu.yaml": 1644549590.0652373,
|
||||
"wandb\\run-20220129_032741-340btp9k\\files\\conda-environment.yaml": 1643398065.409959,
|
||||
"wandb\\run-20220129_032741-340btp9k\\files\\config.yaml": 1643398069.2392955,
|
||||
"wandb\\run-20220129_032939-2nmaozxq\\files\\conda-environment.yaml": 1643398182.647548,
|
||||
@@ -100,93 +100,208 @@
|
||||
"wandb\\run-20220129_034859-2puk6sph\\files\\config.yaml": 1643399477.881678,
|
||||
"wandb\\run-20220129_035624-3hmwgcgw\\files\\conda-environment.yaml": 1643399787.8899708,
|
||||
"wandb\\run-20220129_035624-3hmwgcgw\\files\\config.yaml": 1643426465.6088357,
|
||||
"dnnlib\\util.py": 1647657822.941272,
|
||||
"dnnlib\\__init__.py": 1647657822.941272,
|
||||
"components\\Generator_ori.py": 1647657822.9332705,
|
||||
"losses\\cos.py": 1647657822.9442725,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1647657822.9402719,
|
||||
"speed_test.py": 1647657822.945273,
|
||||
"components\\DeConv_Invo.py": 1647657822.9302697,
|
||||
"dnnlib\\util.py": 1640773190.0,
|
||||
"dnnlib\\__init__.py": 1640773190.0,
|
||||
"components\\Generator_ori.py": 1644689174.414655,
|
||||
"losses\\cos.py": 1644229583.4023254,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1644860106.943826,
|
||||
"speed_test.py": 1648982366.0803514,
|
||||
"components\\DeConv_Invo.py": 1644426607.1588645,
|
||||
"components\\Generator_reduce_up.py": 1644688655.2096283,
|
||||
"components\\Generator_upsample.py": 1647657822.9352713,
|
||||
"components\\misc\\Involution.py": 1647657822.9352713,
|
||||
"train_yamls\\train_Invoup.yaml": 1647657822.9622767,
|
||||
"flops.py": 1647657822.9422722,
|
||||
"detection_test.py": 1647657822.941272,
|
||||
"components\\DeConv_Depthwise.py": 1647657822.9292698,
|
||||
"components\\DeConv_Depthwise1.py": 1647657822.9292698,
|
||||
"components\\Generator_upsample.py": 1644689723.8293872,
|
||||
"components\\misc\\Involution.py": 1644509321.5267963,
|
||||
"train_yamls\\train_Invoup.yaml": 1644689981.9794765,
|
||||
"flops.py": 1649040334.6186154,
|
||||
"detection_test.py": 1644935512.6830947,
|
||||
"components\\DeConv_Depthwise.py": 1645064447.4379447,
|
||||
"components\\DeConv_Depthwise1.py": 1644946969.5054545,
|
||||
"components\\Generator_modulation_depthwise.py": 1644861291.4467516,
|
||||
"components\\Generator_modulation_depthwise_config.py": 1647657822.93227,
|
||||
"components\\Generator_modulation_up.py": 1647657822.9332705,
|
||||
"components\\Generator_oriae_modulation.py": 1647657822.934271,
|
||||
"components\\Generator_ori_config.py": 1647657822.934271,
|
||||
"train_scripts\\trainer_multi_gpu1.py": 1647657822.9602764,
|
||||
"train_yamls\\train_Depthwise.yaml": 1647657822.961277,
|
||||
"train_yamls\\train_depthwise_modulation.yaml": 1647657822.963277,
|
||||
"train_yamls\\train_oriae_modulation.yaml": 1647657822.9642777,
|
||||
"train_distillation_mgpu.py": 1647657822.9562755,
|
||||
"components\\DeConv.py": 1647657822.9292698,
|
||||
"components\\DeConv_Depthwise_ECA.py": 1647657822.9292698,
|
||||
"components\\ECA.py": 1647657822.9302697,
|
||||
"components\\ECA_Depthwise_Conv.py": 1647657822.93127,
|
||||
"components\\Generator_eca_depthwise.py": 1647657822.93227,
|
||||
"losses\\KA.py": 1647657822.9432724,
|
||||
"train_scripts\\trainer_distillation_mgpu.py": 1647657822.9592762,
|
||||
"train_yamls\\train_distillation.yaml": 1647657822.963277,
|
||||
"annotation.py": 1647657822.9172668,
|
||||
"components\\DeConv_ECA_Invo.py": 1647657822.9302697,
|
||||
"components\\DeConv_Invobn.py": 1647657822.9302697,
|
||||
"components\\Generator_Invobn_config.py": 1647657822.93127,
|
||||
"components\\Generator_Invobn_config1.py": 1647657822.93127,
|
||||
"components\\misc\\Involution_BN.py": 1647657822.9362714,
|
||||
"components\\misc\\Involution_ECA.py": 1647657822.9362714,
|
||||
"train_yamls\\train_Invobn_config.yaml": 1647657822.9622767,
|
||||
"components\\Generator_Invobn_config2.py": 1647657822.93227,
|
||||
"components\\Generator_Invobn_config3.py": 1647657822.93227,
|
||||
"components\\Generator_ori_modulation_config.py": 1647657822.934271,
|
||||
"test_scripts\\tester_image_allstep.py": 1647657822.9482737,
|
||||
"train_yamls\\train_ori_modulation_config.yaml": 1647657822.9642777,
|
||||
"test_arcface.py": 1647657822.946273,
|
||||
"arcface_torch\\dataset.py": 1647657822.9222684,
|
||||
"arcface_torch\\eval_ijbc.py": 1647657822.9242685,
|
||||
"arcface_torch\\inference.py": 1647657822.9242685,
|
||||
"arcface_torch\\losses.py": 1647657822.9242685,
|
||||
"arcface_torch\\lr_scheduler.py": 1647657822.9242685,
|
||||
"arcface_torch\\onnx_helper.py": 1647657822.9252684,
|
||||
"arcface_torch\\onnx_ijbc.py": 1647657822.9252684,
|
||||
"arcface_torch\\partial_fc.py": 1647657822.9252684,
|
||||
"arcface_torch\\torch2onnx.py": 1647657822.9262686,
|
||||
"arcface_torch\\train.py": 1647657822.9262686,
|
||||
"arcface_torch\\backbones\\iresnet.py": 1647657822.918267,
|
||||
"arcface_torch\\backbones\\iresnet2060.py": 1647657822.9192681,
|
||||
"arcface_torch\\backbones\\mobilefacenet.py": 1647657822.9192681,
|
||||
"arcface_torch\\backbones\\__init__.py": 1647657822.918267,
|
||||
"arcface_torch\\configs\\3millions.py": 1647657822.9192681,
|
||||
"arcface_torch\\configs\\base.py": 1647657822.9192681,
|
||||
"arcface_torch\\configs\\glint360k_mobileface_lr02_bs4k.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\glint360k_r100_lr02_bs4k_16gpus.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\ms1mv3_mobileface_lr02.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\ms1mv3_r100_lr02.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\ms1mv3_r50_lr02.py": 1647657822.9202676,
|
||||
"arcface_torch\\configs\\webface42m_mobilefacenet_pfc02_bs8k_16gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\webface42m_r100_lr01_pfc02_bs4k_16gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_32gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_8gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs8k_16gpus.py": 1647657822.9212687,
|
||||
"arcface_torch\\configs\\__init__.py": 1647657822.9192681,
|
||||
"arcface_torch\\eval\\verification.py": 1647657822.923268,
|
||||
"arcface_torch\\eval\\__init__.py": 1647657822.923268,
|
||||
"arcface_torch\\utils\\plot.py": 1647657822.927269,
|
||||
"arcface_torch\\utils\\utils_callbacks.py": 1647657822.927269,
|
||||
"arcface_torch\\utils\\utils_config.py": 1647657822.927269,
|
||||
"arcface_torch\\utils\\utils_logging.py": 1647657822.927269,
|
||||
"arcface_torch\\utils\\__init__.py": 1647657822.9262686,
|
||||
"components\\LSTU.py": 1647702612.240765,
|
||||
"test_scripts\\tester_ID_Pose.py": 1647657822.946273,
|
||||
"train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1647657822.9592762,
|
||||
"train_scripts\\trainer_multi_gpu_CUT.py": 1647676964.475,
|
||||
"train_scripts\\trainer_multi_gpu_cycle.py": 1647705628.7020626,
|
||||
"components\\Generator_LSTU_config.py": 1647954099.1135788,
|
||||
"components\\Generator_Res_config.py": 1648006159.4385264,
|
||||
"train_yamls\\train_cycleloss_res.yaml": 1648006232.5734456
|
||||
"components\\Generator_modulation_depthwise_config.py": 1645262162.9779513,
|
||||
"components\\Generator_modulation_up.py": 1644946498.7005584,
|
||||
"components\\Generator_oriae_modulation.py": 1644897798.1987727,
|
||||
"components\\Generator_ori_config.py": 1646329319.6131227,
|
||||
"train_scripts\\trainer_multi_gpu1.py": 1644859528.8428593,
|
||||
"train_yamls\\train_Depthwise.yaml": 1644860961.099242,
|
||||
"train_yamls\\train_depthwise_modulation.yaml": 1645035964.9551077,
|
||||
"train_yamls\\train_oriae_modulation.yaml": 1644897891.2576747,
|
||||
"train_distillation_mgpu.py": 1645554603.908166,
|
||||
"components\\DeConv.py": 1645263338.9001615,
|
||||
"components\\DeConv_Depthwise_ECA.py": 1645265769.1076133,
|
||||
"components\\ECA.py": 1614848426.9604986,
|
||||
"components\\ECA_Depthwise_Conv.py": 1645265754.2023985,
|
||||
"components\\Generator_eca_depthwise.py": 1645266338.9750814,
|
||||
"losses\\KA.py": 1646388425.4841197,
|
||||
"train_scripts\\trainer_distillation_mgpu.py": 1645601961.4139585,
|
||||
"train_yamls\\train_distillation.yaml": 1645600099.540936,
|
||||
"annotation.py": 1648654581.017103,
|
||||
"components\\DeConv_ECA_Invo.py": 1645869347.379311,
|
||||
"components\\DeConv_Invobn.py": 1645862876.018001,
|
||||
"components\\Generator_Invobn_config.py": 1645929418.6924264,
|
||||
"components\\Generator_Invobn_config1.py": 1645862695.8743145,
|
||||
"components\\misc\\Involution_BN.py": 1645867197.3984175,
|
||||
"components\\misc\\Involution_ECA.py": 1645869012.4927464,
|
||||
"train_yamls\\train_Invobn_config.yaml": 1646101598.499709,
|
||||
"components\\Generator_Invobn_config2.py": 1645962618.7056074,
|
||||
"components\\Generator_Invobn_config3.py": 1646302561.1984286,
|
||||
"components\\Generator_ori_modulation_config.py": 1646329636.719998,
|
||||
"test_scripts\\tester_image_allstep.py": 1646312637.9363256,
|
||||
"train_yamls\\train_ori_modulation_config.yaml": 1646330406.200162,
|
||||
"test_arcface.py": 1647448497.69041,
|
||||
"arcface_torch\\dataset.py": 1647445446.261035,
|
||||
"arcface_torch\\eval_ijbc.py": 1647445446.2630043,
|
||||
"arcface_torch\\inference.py": 1647445446.2630043,
|
||||
"arcface_torch\\losses.py": 1647445446.2630043,
|
||||
"arcface_torch\\lr_scheduler.py": 1647445446.2630043,
|
||||
"arcface_torch\\onnx_helper.py": 1647445446.2630043,
|
||||
"arcface_torch\\onnx_ijbc.py": 1647445446.2640254,
|
||||
"arcface_torch\\partial_fc.py": 1647445446.2640254,
|
||||
"arcface_torch\\torch2onnx.py": 1647445446.2640254,
|
||||
"arcface_torch\\train.py": 1647445446.2649992,
|
||||
"arcface_torch\\backbones\\iresnet.py": 1647445446.2580183,
|
||||
"arcface_torch\\backbones\\iresnet2060.py": 1647445446.2580183,
|
||||
"arcface_torch\\backbones\\mobilefacenet.py": 1647445446.2580183,
|
||||
"arcface_torch\\backbones\\__init__.py": 1647445446.25702,
|
||||
"arcface_torch\\configs\\3millions.py": 1647445446.2580183,
|
||||
"arcface_torch\\configs\\base.py": 1647445446.259039,
|
||||
"arcface_torch\\configs\\glint360k_mobileface_lr02_bs4k.py": 1647445446.259039,
|
||||
"arcface_torch\\configs\\glint360k_r100_lr02_bs4k_16gpus.py": 1647445446.259039,
|
||||
"arcface_torch\\configs\\ms1mv3_mobileface_lr02.py": 1647445446.259039,
|
||||
"arcface_torch\\configs\\ms1mv3_r100_lr02.py": 1647445446.259039,
|
||||
"arcface_torch\\configs\\ms1mv3_r50_lr02.py": 1647445446.260039,
|
||||
"arcface_torch\\configs\\webface42m_mobilefacenet_pfc02_bs8k_16gpus.py": 1647445446.260039,
|
||||
"arcface_torch\\configs\\webface42m_r100_lr01_pfc02_bs4k_16gpus.py": 1647445446.260039,
|
||||
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_32gpus.py": 1647445446.260039,
|
||||
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs4k_8gpus.py": 1647445446.260039,
|
||||
"arcface_torch\\configs\\webface42m_r50_lr01_pfc02_bs8k_16gpus.py": 1647445446.260039,
|
||||
"arcface_torch\\configs\\__init__.py": 1647445446.2580183,
|
||||
"arcface_torch\\eval\\verification.py": 1647445446.2620306,
|
||||
"arcface_torch\\eval\\__init__.py": 1647445446.2620306,
|
||||
"arcface_torch\\utils\\plot.py": 1647445446.2649992,
|
||||
"arcface_torch\\utils\\utils_callbacks.py": 1647445446.2649992,
|
||||
"arcface_torch\\utils\\utils_config.py": 1647445446.2649992,
|
||||
"arcface_torch\\utils\\utils_logging.py": 1647445446.2659965,
|
||||
"arcface_torch\\utils\\__init__.py": 1647445446.2649992,
|
||||
"components\\LSTU.py": 1648482475.4378786,
|
||||
"test_scripts\\tester_ID_Pose.py": 1646558809.85301,
|
||||
"train_scripts\\trainer_distillation_mgpu_withrec_importweight.py": 1646391740.1106014,
|
||||
"train_scripts\\trainer_multi_gpu_CUT.py": 1647769120.5968685,
|
||||
"train_scripts\\trainer_multi_gpu_cycle.py": 1648313934.2140906,
|
||||
"components\\Generator_LSTU_config.py": 1648028831.4087331,
|
||||
"components\\Generator_Res_config.py": 1648053382.053794,
|
||||
"train_yamls\\train_cycleloss_res.yaml": 1648103614.4515965,
|
||||
"clear_dataset.py": 1648920044.7807672,
|
||||
"id_cos.py": 1648569510.5593822,
|
||||
"translation_list2json.py": 1648106406.478007,
|
||||
"components\\Generator_ResSkip_config.py": 1648526573.2056787,
|
||||
"components\\Generator_Res_config1.py": 1648092270.7609806,
|
||||
"components\\Generator_Res_config2.py": 1648103885.6257715,
|
||||
"test_scripts\\tester_image_list.py": 1648145245.0948818,
|
||||
"test_scripts\\tester_image_nofusion.py": 1648096849.0402405,
|
||||
"train_yamls\\train_cycleloss_resskip.yaml": 1648313962.968481,
|
||||
"check_list.txt": 1648657338.5051336,
|
||||
"test_imgs_list.txt": 1649574560.1498592,
|
||||
"vggface2hq_failed.txt": 1648926017.999394,
|
||||
"arcface_torch\\requirement.txt": 1647445446.2640254,
|
||||
"wandb\\run-20220129_032741-340btp9k\\files\\requirements.txt": 1643398065.409959,
|
||||
"wandb\\run-20220129_032939-2nmaozxq\\files\\requirements.txt": 1643398182.647548,
|
||||
"wandb\\run-20220129_033051-21z19tyg\\files\\requirements.txt": 1643398254.926299,
|
||||
"wandb\\run-20220129_033202-16la4gpu\\files\\requirements.txt": 1643398325.8784783,
|
||||
"wandb\\run-20220129_034327-1bmseytq\\files\\requirements.txt": 1643399010.865907,
|
||||
"wandb\\run-20220129_034859-2puk6sph\\files\\requirements.txt": 1643399343.3508356,
|
||||
"wandb\\run-20220129_035624-3hmwgcgw\\files\\requirements.txt": 1643399787.8869605,
|
||||
"components\\Generator_featout_config.py": 1648877243.4813964,
|
||||
"components\\Generator_ResSkip_config1.py": 1648530486.7970698,
|
||||
"components\\LSTU_Config.py": 1648528200.7229428,
|
||||
"components\\Nonstau_Discriminator.py": 1648476236.8430562,
|
||||
"components\\Nonstau_Discriminator_FM.py": 1649833901.6153808,
|
||||
"metrics\\equivariance.py": 1640773190.0,
|
||||
"metrics\\frechet_inception_distance.py": 1640773190.0,
|
||||
"metrics\\inception_score.py": 1640773190.0,
|
||||
"metrics\\kernel_inception_distance.py": 1640773190.0,
|
||||
"metrics\\metric_main.py": 1640773190.0,
|
||||
"metrics\\metric_utils.py": 1640773190.0,
|
||||
"metrics\\perceptual_path_length.py": 1640773190.0,
|
||||
"metrics\\precision_recall.py": 1640773190.0,
|
||||
"test_scripts\\tester_image_list_w_mask.py": 1649074097.8263073,
|
||||
"test_scripts\\tester_image_w_mask.py": 1648529828.1194224,
|
||||
"train_scripts\\trainer_mgpu_fm.py": 1648878513.5066502,
|
||||
"train_scripts\\trainer_multi_gpu_cycle_nonstatue_dis.py": 1648559801.6695006,
|
||||
"train_yamls\\train_cycleloss_fm_nonstatu.yaml": 1648572102.661056,
|
||||
"train_yamls\\train_cycleloss_resskip_nonstatu.yaml": 1648527833.9132054,
|
||||
"components\\Generator_Res_config3.py": 1648628068.1252878,
|
||||
"data_tools\\data_loader_FFHQ_multigpu.py": 1648640139.408,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\align_and_crop_dir.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\generate_mask.py": 1648652827.224725,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\test_enhance_dir_align.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\test_enhance_dir_unalign.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\test_enhance_single_unalign.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\train.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\base_dataset.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\celebahqmask_dataset.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\ffhq_dataset.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\image_folder.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\single_dataset.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\data\\__init__.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\base_model.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\blocks.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\enhance_model.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\loss.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\networks.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\parse_model.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\psfrnet.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\models\\__init__.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\base_options.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\test_options.py": 1648647365.0084722,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\train_options.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\options\\__init__.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\utils\\logger.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\utils\\timer.py": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\utils\\utils.py": 1648653194.9661705,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\requirements.txt": 1640774152.0,
|
||||
"face_parse\\PSFRGAN-master\\PSFRGAN-master\\check_points\\experiment_name\\test_opt.txt": 1648653211.0647676,
|
||||
"components\\Generator_maskhead_config.py": 1648919192.4128494,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu_w_mask.py": 1648950658.8917239,
|
||||
"train_scripts\\trainer_mgpu_fm_w_mask.py": 1649842622.1032736,
|
||||
"train_scripts\\trainer_mgpu_maskloss.py": 1650003673.00616,
|
||||
"train_yamls\\train_maskhead_fm_nonstatu.yaml": 1648918857.2325976,
|
||||
"train_yamls\\train_maskhead_maskloss.yaml": 1648919545.869737,
|
||||
"dataset.check.py": 1648925868.9100032,
|
||||
"components\\Generator_involution_maskhead_config.py": 1648919192.4128494,
|
||||
"components\\Generator_VGGStyle_maskhead_config.py": 1649177890.0438528,
|
||||
"train_yamls\\train_maskhead_fm_vggstyle.yaml": 1649177995.515654,
|
||||
"train_yamls\\train_maskloss.yaml": 1648920140.455,
|
||||
"components\\Generator_maskhead_config1.py": 1649248730.7064345,
|
||||
"train_yamls\\train_maskhead_hififace.yaml": 1649345218.4856737,
|
||||
"filter.py": 1649836163.761168,
|
||||
"test_json.py": 1649232568.5223434,
|
||||
"components\\Generator_2maskhead_config copy.py": 1649816572.443475,
|
||||
"components\\Generator_2maskhead_config.py": 1649816572.443475,
|
||||
"components\\Generator_maskhead_config2.py": 1649833973.4127545,
|
||||
"components\\Generator_starganv2.py": 1649845840.4322417,
|
||||
"face_enhancer\\gfpgan\\train.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\utils.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\version.py": 1648618484.3356814,
|
||||
"face_enhancer\\gfpgan\\__init__.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\archs\\arcface_arch.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\archs\\gfpganv1_arch.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\archs\\gfpganv1_clean_arch.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\archs\\gfpgan_bilinear_arch.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\archs\\stylegan2_bilinear_arch.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\archs\\stylegan2_clean_arch.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\archs\\__init__.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\data\\ffhq_degradation_dataset.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\data\\__init__.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\models\\gfpgan_model.py": 1644942770.0,
|
||||
"face_enhancer\\gfpgan\\models\\__init__.py": 1644942770.0,
|
||||
"face_enhancer\\scripts\\convert_gfpganv_to_clean.py": 1644942770.0,
|
||||
"face_enhancer\\scripts\\parse_landmark.py": 1644942770.0,
|
||||
"test_scripts\\tester_image_list_w_2mask.py": 1649768648.90081,
|
||||
"test_scripts\\tester_image_w_2mask.py": 1649729361.2799067,
|
||||
"test_scripts\\tester_image_w_mask_gfpgan.py": 1649872098.8747168,
|
||||
"test_scripts\\tester_video_gfpgan.py": 1649907748.9359775,
|
||||
"train_scripts\\trainer_mgpu_2maskloss.py": 1649699504.6697042,
|
||||
"train_yamls\\train_1maskhead.yaml": 1650003571.40854,
|
||||
"train_yamls\\train_2maskhead.yaml": 1649953481.822017,
|
||||
"train_yamls\\train_maskhead_hififace1.yaml": 1649471700.4954374,
|
||||
"components\\Generator_2mask.py": 1649953828.8860412
|
||||
}
|
||||
+2
-1
@@ -2,7 +2,8 @@
|
||||
"white_list": {
|
||||
"extension": [
|
||||
"py",
|
||||
"yaml"
|
||||
"yaml",
|
||||
"txt"
|
||||
],
|
||||
"file": [],
|
||||
"path": []
|
||||
|
||||
@@ -8,6 +8,15 @@
|
||||
"ckp_path": "train_logs",
|
||||
"logfilename": "filestate_machine0.json"
|
||||
},
|
||||
{
|
||||
"ip": "119.29.91.52",
|
||||
"user": "ubuntu",
|
||||
"port": 22,
|
||||
"passwd": "zpKlOW0sMlyt!xhE",
|
||||
"path": "/home/ubuntu/CXH/simswap_plus",
|
||||
"ckp_path": "train_logs",
|
||||
"logfilename": "filestate_machine3.json"
|
||||
},
|
||||
{
|
||||
"ip": "2001:da8:8000:6880:f284:d61c:3c76:f9cb",
|
||||
"user": "ps",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Simswap++
|
||||
|
||||
## Dependencies
|
||||
- moviepy
|
||||
- python >= 3.7
|
||||
- yaml (pip install pyyaml)
|
||||
- paramiko (For ssh file transportation)
|
||||
|
||||
+2
-2
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 27th February 2022 11:03:58 am
|
||||
# Last Modified: Wednesday, 30th March 2022 11:36:20 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -33,7 +33,7 @@ def str2bool(v):
|
||||
def getParameters():
|
||||
parser = argparse.ArgumentParser()
|
||||
# general
|
||||
parser.add_argument('--image_dir', type=str, default="G:\\VGGFace2-HQ\\VGGface2_None_norm_512_true_bygfpgan")
|
||||
parser.add_argument('--image_dir', type=str, default="G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan")
|
||||
parser.add_argument('--savetxt', type=str, default="./check_list.txt")
|
||||
parser.add_argument('--winWidth', type=int, default=512)
|
||||
parser.add_argument('--winHeight', type=int, default=512)
|
||||
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"breakpoint": [
|
||||
31,
|
||||
110
|
||||
54,
|
||||
101
|
||||
]
|
||||
}
|
||||
+1278
File diff suppressed because it is too large
Load Diff
+7
-4
@@ -5,7 +5,7 @@
|
||||
# Created Date: Thursday March 24th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 24th March 2022 3:32:40 pm
|
||||
# Last Modified: Sunday, 3rd April 2022 1:20:44 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -18,11 +18,14 @@ if __name__ == "__main__":
|
||||
savePath = "./vggface2hq_failed.txt"
|
||||
env_config = readConfig('env/env.json')
|
||||
env_config = env_config["path"]
|
||||
dataset_root = env_config["dataset_paths"]["vggface2_hq"]
|
||||
dataset_root = env_config["dataset_paths"]["vggface2_hq"]["images"]
|
||||
# dataset_root = "G:/VGGFace2-HQ/newversion"
|
||||
print(dataset_root)
|
||||
|
||||
with open(savePath,'r') as logf:
|
||||
for line in logf:
|
||||
img_path = os.path.join(dataset_root,line[:-1]).replace("\\","/")
|
||||
img_path = os.path.join(dataset_root,line.replace("\n","")).replace("\\","/")
|
||||
try:
|
||||
os.rename(img_path,img_path+".deleted")
|
||||
except: pass
|
||||
except Exception as e:
|
||||
print(e)
|
||||
@@ -0,0 +1,304 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th April 2022 7:03:46 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
|
||||
normalize="in", downsample=False):
|
||||
super().__init__()
|
||||
self.actv = actv
|
||||
self.normalize = normalize
|
||||
self.downsample = downsample
|
||||
self.learned_sc = dim_in != dim_out
|
||||
self.equal_var = math.sqrt(2)
|
||||
self._build_weights(dim_in, dim_out)
|
||||
|
||||
def _build_weights(self, dim_in, dim_out):
|
||||
self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
|
||||
if self.normalize.lower() == "in":
|
||||
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
elif self.normalize.lower() == "bn":
|
||||
self.norm1 = nn.BatchNorm2d(dim_in)
|
||||
self.norm2 = nn.BatchNorm2d(dim_in)
|
||||
if self.learned_sc:
|
||||
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
|
||||
|
||||
def _shortcut(self, x):
|
||||
if self.learned_sc:
|
||||
x = self.conv1x1(x)
|
||||
if self.downsample:
|
||||
x = F.avg_pool2d(x, 2)
|
||||
return x
|
||||
|
||||
def _residual(self, x):
|
||||
if self.normalize:
|
||||
x = self.norm1(x)
|
||||
x = self.actv(x)
|
||||
x = self.conv1(x)
|
||||
if self.downsample:
|
||||
x = F.avg_pool2d(x, 2)
|
||||
if self.normalize:
|
||||
x = self.norm2(x)
|
||||
x = self.actv(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self._shortcut(x) + self._residual(x)
|
||||
return x /self.equal_var # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
|
||||
def __init__(self, style_dim, num_features):
|
||||
super().__init__()
|
||||
self.norm = nn.InstanceNorm2d(num_features, affine=False)
|
||||
self.fc = nn.Linear(style_dim, num_features*2)
|
||||
|
||||
def forward(self, x, s):
|
||||
h = self.fc(s)
|
||||
h = h.view(h.size(0), h.size(1), 1, 1)
|
||||
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
||||
return (1 + gamma) * self.norm(x) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
|
||||
def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"):
|
||||
super().__init__()
|
||||
self.actv = actv
|
||||
self.normalize = normalize
|
||||
self.learned_sc = dim_in != dim_out
|
||||
self.equal_var = math.sqrt(2)
|
||||
self._build_weights(dim_in, dim_out)
|
||||
|
||||
def _build_weights(self, dim_in, dim_out):
|
||||
self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
|
||||
if self.normalize.lower() == "in":
|
||||
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
|
||||
elif self.normalize.lower() == "bn":
|
||||
self.norm1 = nn.BatchNorm2d(dim_in)
|
||||
self.norm2 = nn.BatchNorm2d(dim_out)
|
||||
if self.learned_sc:
|
||||
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
|
||||
|
||||
def _shortcut(self, x):
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
if self.learned_sc:
|
||||
x = self.conv1x1(x)
|
||||
return x
|
||||
|
||||
def _residual(self, x):
|
||||
x = self.norm1(x)
|
||||
x = self.actv(x)
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = self.actv(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
out = self._residual(x)
|
||||
out = (out + self._shortcut(x)) / self.equal_var
|
||||
return out
|
||||
|
||||
class AdainResBlk(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, style_dim=512,
|
||||
actv=nn.LeakyReLU(0.2), upsample=False):
|
||||
super().__init__()
|
||||
self.actv = actv
|
||||
self.upsample = upsample
|
||||
self.learned_sc = dim_in != dim_out
|
||||
self.equal_var = math.sqrt(2)
|
||||
self._build_weights(dim_in, dim_out, style_dim)
|
||||
|
||||
def _build_weights(self, dim_in, dim_out, style_dim=64):
|
||||
self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
|
||||
self.norm1 = AdaIN(style_dim, dim_in)
|
||||
self.norm2 = AdaIN(style_dim, dim_out)
|
||||
if self.learned_sc:
|
||||
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
|
||||
|
||||
def _shortcut(self, x):
|
||||
if self.upsample:
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
if self.learned_sc:
|
||||
x = self.conv1x1(x)
|
||||
return x
|
||||
|
||||
def _residual(self, x, s):
|
||||
x = self.norm1(x, s)
|
||||
x = self.actv(x)
|
||||
if self.upsample:
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x, s)
|
||||
x = self.actv(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, s):
|
||||
out = self._residual(x, s)
|
||||
out = (out + self._shortcut(x)) / self.equal_var
|
||||
return out
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
id_dim = kwargs["id_dim"]
|
||||
k_size = kwargs["g_kernel_size"]
|
||||
res_num = kwargs["res_num"]
|
||||
in_channel = kwargs["in_channel"]
|
||||
up_mode = kwargs["up_mode"]
|
||||
norm = kwargs["norm"].lower()
|
||||
|
||||
aggregator = kwargs["aggregator"]
|
||||
res_mode = kwargs["res_mode"]
|
||||
|
||||
padding_size= int((k_size -1)/2)
|
||||
padding_type= 'reflect'
|
||||
|
||||
if norm.lower() == "in":
|
||||
norm_out = nn.InstanceNorm2d(in_channel, affine=True)
|
||||
norm_mask= nn.InstanceNorm2d
|
||||
elif norm.lower() == "bn":
|
||||
norm_out = nn.BatchNorm2d(in_channel)
|
||||
norm_mask = nn.BatchNorm2d
|
||||
|
||||
|
||||
activation = nn.LeakyReLU(0.2)
|
||||
# activation = nn.ReLU()
|
||||
|
||||
self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
|
||||
# self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(64), activation)
|
||||
### downsample
|
||||
self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 128
|
||||
|
||||
self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 64
|
||||
|
||||
self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 32
|
||||
|
||||
self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 16
|
||||
|
||||
# self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
|
||||
|
||||
|
||||
### resnet blocks
|
||||
BN = []
|
||||
for i in range(res_num):
|
||||
BN += [
|
||||
AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
|
||||
self.BottleNeck = nn.Sequential(*BN)
|
||||
|
||||
# self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
|
||||
|
||||
# self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 1
|
||||
|
||||
self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 32
|
||||
|
||||
self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 64
|
||||
|
||||
# self.maskhead = nn.Sequential(
|
||||
# nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
# norm_mask, # 64
|
||||
# activation,
|
||||
# nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
|
||||
# nn.Sigmoid())
|
||||
self.maskhead_lr = nn.Sequential(
|
||||
nn.UpsamplingNearest2d(scale_factor = 2),
|
||||
nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
|
||||
norm_mask(in_channel, affine=True), # 32
|
||||
activation,
|
||||
nn.UpsamplingNearest2d(scale_factor = 2),
|
||||
nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
norm_mask(in_channel//4, affine=True), # 64
|
||||
activation
|
||||
)
|
||||
self.maskhead_hr = nn.Sequential(
|
||||
nn.UpsamplingNearest2d(scale_factor = 2),
|
||||
nn.Conv2d(in_channel//4, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False),
|
||||
norm_mask(in_channel//16, affine=True), # 128
|
||||
activation,
|
||||
nn.UpsamplingNearest2d(scale_factor = 2),
|
||||
nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1),
|
||||
nn.Sigmoid() # 256
|
||||
)
|
||||
self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//4, 1, kernel_size=1, stride=1),
|
||||
nn.Sigmoid())
|
||||
|
||||
# self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
|
||||
# self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
|
||||
self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
|
||||
|
||||
# self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
|
||||
self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)
|
||||
# ResUpBlk(in_channel, in_channel, normalize="in") # 512
|
||||
|
||||
|
||||
|
||||
self.to_rgb = nn.Sequential(
|
||||
norm_out,
|
||||
activation,
|
||||
nn.Conv2d(in_channel, 3, 3, 1, 1))
|
||||
|
||||
# self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
|
||||
# nn.Conv2d(64, 3, kernel_size=7, padding=0))
|
||||
|
||||
|
||||
# self.__weights_init__()
|
||||
|
||||
# def __weights_init__(self):
|
||||
# for layer in self.encoder:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
# for layer in self.encoder2:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
def forward(self, img, id):
|
||||
res = self.from_rgb(img)
|
||||
res = self.down1(res)
|
||||
skip = self.down2(res)
|
||||
res = self.down3(skip)
|
||||
res = self.down4(res)
|
||||
mask_feat= self.maskhead_lr(res)
|
||||
|
||||
for i in range(len(self.BottleNeck)):
|
||||
res = self.BottleNeck[i](res, id)
|
||||
res = self.up4(res,id)
|
||||
res = self.up3(res,id)
|
||||
mask_lr= self.maskhead_out(mask_feat)
|
||||
# res = (1-mask) * self.sigma(skip) + mask * res
|
||||
res = (1-mask_lr) * skip + mask_lr * res
|
||||
res = self.up2(res) # + skip
|
||||
res = self.up1(res)
|
||||
res = self.to_rgb(res)
|
||||
mask_hr=self.maskhead_hr(mask_feat)
|
||||
res = (1-mask_hr) * img + mask_hr * res
|
||||
return res, mask_lr, mask_hr
|
||||
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 15th April 2022 12:30:27 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool
    downsampling.

    Output is (shortcut + residual) / sqrt(2) to keep unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        # A 1x1 shortcut projection is only needed when channels change.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        else:
            # Fail fast: previously an unknown mode only surfaced later as an
            # AttributeError on self.norm1 inside forward().
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # keep activations at unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization: scales and shifts the
    instance-normalized feature map with an affine transform predicted
    from a style vector."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Predict per-channel (gamma, beta) from the style code.
        stats = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = stats.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that upsamples 2x (nearest) on both branches.

    Output is (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x):
        out = self.actv(self.norm1(x))
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out))
        return self.conv2(out)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block with AdaIN conditioning on a style code and optional
    2x nearest upsampling; output scaled by 1/sqrt(2) for unit variance."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out, s))
        return self.conv2(out)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Mask-guided identity-swap generator.

    Encodes the input image with ResBlk downsamplers, conditions the
    bottleneck and upsampling path on an identity code via AdaIN blocks,
    predicts soft blending masks at two resolutions from the bottleneck
    features, and composites the generated content back onto the input
    image with those masks.
    """

    def __init__(
        self,
        **kwargs
    ):
        # Configuration arrives as a flat kwargs dict; required keys below.
        super().__init__()

        id_dim = kwargs["id_dim"]          # dimension of the identity/style code
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]        # number of bottleneck AdaIN blocks
        in_channel = kwargs["in_channel"]  # base channel width
        up_mode = kwargs["up_mode"]        # NOTE(review): read but not used below
        norm = kwargs["norm"].lower()

        aggregator = kwargs["aggregator"]  # NOTE(review): read but not used below
        res_mode = kwargs["res_mode"]      # NOTE(review): read but not used below

        padding_size= int((k_size -1)/2)   # NOTE(review): computed but not used below
        padding_type= 'reflect'            # NOTE(review): assigned but not used below

        # norm_out is the pre-activation norm before to_rgb; norm_mask is the
        # norm *class* instantiated inside the mask heads. NOTE(review): for a
        # value other than "in"/"bn" both stay undefined and the code below
        # raises NameError -- confirm whether other values should be rejected.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
            norm_mask= nn.InstanceNorm2d
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
            norm_mask = nn.BatchNorm2d


        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        # nn.BatchNorm2d(64), activation)
        ### downsample -- trailing numbers are presumably the feature-map
        ### resolutions for a 512x512 input (TODO confirm)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256

        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128

        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64

        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32

        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16

        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8


        ### resnet blocks: identity-conditioned bottleneck at lowest resolution
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        # Sequential is used only as a container; forward() indexes it so each
        # block can receive the extra identity argument.
        self.BottleNeck = nn.Sequential(*BN)

        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32

        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64

        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128

        # self.maskhead = nn.Sequential(
        #     nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
        #     norm_mask, # 64
        #     activation,
        #     nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
        #     nn.Sigmoid())
        # Shared mask trunk: upsamples bottleneck features 8x; feeds both
        # maskhead_out (low-res mask) and maskhead_hr (full-res mask).
        self.maskhead_lr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel, affine=True), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//4, affine=True), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1),
            norm_mask(in_channel//8, affine=True), # 128
            activation,
        )
        # Full-resolution mask head: two more 2x upsamples down to 1 channel.
        self.maskhead_hr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel//16, affine=True), # 256
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid() # 512
        )
        # 1x1 head producing the low-resolution mask from the shared features.
        self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1),
                                          nn.Sigmoid())

        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # Final two upsampling stages are identity-free residual up blocks.
        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)

        # self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512



        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                     nn.Conv2d(64, 3, kernel_size=7, padding=0))


        # self.__weights_init__()

    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, img, id):
        """Swap identity `id` into `img`.

        Args:
            img: input image tensor, (B, 3, H, W).
            id:  identity/style code fed to the AdaIN blocks -- presumably
                 shape (B, id_dim); TODO confirm against caller.

        Returns:
            (res, mask_lr, mask_hr): the composited image plus the low- and
            high-resolution soft blending masks.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # kept for the mask-guided skip blend below
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask_feat= self.maskhead_lr(res)   # shared mask features

        # Sequential cannot forward the extra identity arg, so index manually.
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        mask_lr= self.maskhead_out(mask_feat)
        # res = (1-mask) * self.sigma(skip) + mask * res
        # Blend decoder features with the encoder skip under the low-res mask.
        res = (1-mask_lr) * skip + mask_lr * res
        res = self.up2(res) # + skip
        res = self.up1(res)
        res = self.to_rgb(res)
        mask_hr=self.maskhead_hr(mask_feat)
        # Composite the generated image onto the input under the full-res mask.
        res = (1-mask_hr) * img + mask_hr * res
        return res, mask_lr, mask_hr
|
||||
@@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th April 2022 12:45:55 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool
    downsampling.

    Output is (shortcut + residual) / sqrt(2) to keep unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        # A 1x1 shortcut projection is only needed when channels change.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        else:
            # Fail fast: previously an unknown mode only surfaced later as an
            # AttributeError on self.norm1 inside forward().
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # keep activations at unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization: scales and shifts the
    instance-normalized feature map with an affine transform predicted
    from a style vector."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Predict per-channel (gamma, beta) from the style code.
        stats = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = stats.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that upsamples 2x (nearest) on both branches.

    Output is (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x):
        out = self.actv(self.norm1(x))
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out))
        return self.conv2(out)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block with AdaIN conditioning on a style code and optional
    2x nearest upsampling; output scaled by 1/sqrt(2) for unit variance."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out, s))
        return self.conv2(out)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Mask-guided identity-swap generator (variant with AdaIN-conditioned
    final upsampling stages).

    Encodes the input image with ResBlk downsamplers, conditions the
    bottleneck and the entire upsampling path on an identity code via AdaIN
    blocks, predicts soft blending masks at two resolutions from the
    bottleneck features, and composites the generated content back onto
    the input image with those masks.
    """

    def __init__(
        self,
        **kwargs
    ):
        # Configuration arrives as a flat kwargs dict; required keys below.
        super().__init__()

        id_dim = kwargs["id_dim"]          # dimension of the identity/style code
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]        # number of bottleneck AdaIN blocks
        in_channel = kwargs["in_channel"]  # base channel width
        up_mode = kwargs["up_mode"]        # NOTE(review): read but not used below
        norm = kwargs["norm"].lower()

        aggregator = kwargs["aggregator"]  # NOTE(review): read but not used below
        res_mode = kwargs["res_mode"]      # NOTE(review): read but not used below

        padding_size= int((k_size -1)/2)   # NOTE(review): computed but not used below
        padding_type= 'reflect'            # NOTE(review): assigned but not used below

        # norm_out is the pre-activation norm before to_rgb; norm_mask is the
        # norm *class* instantiated inside the mask heads. NOTE(review): for a
        # value other than "in"/"bn" both stay undefined and the code below
        # raises NameError -- confirm whether other values should be rejected.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
            norm_mask= nn.InstanceNorm2d
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
            norm_mask = nn.BatchNorm2d


        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        # nn.BatchNorm2d(64), activation)
        ### downsample -- trailing numbers are presumably the feature-map
        ### resolutions for a 512x512 input (TODO confirm)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256

        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128

        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64

        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32

        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16

        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8


        ### resnet blocks: identity-conditioned bottleneck at lowest resolution
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        # Sequential is used only as a container; forward() indexes it so each
        # block can receive the extra identity argument.
        self.BottleNeck = nn.Sequential(*BN)

        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32

        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64

        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128

        # self.maskhead = nn.Sequential(
        #     nn.Conv2d(in_channel*2, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
        #     norm_mask, # 64
        #     activation,
        #     nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
        #     nn.Sigmoid())
        # Shared mask trunk: upsamples bottleneck features 8x; feeds both
        # maskhead_out (low-res mask) and maskhead_hr (full-res mask).
        self.maskhead_lr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel, affine=True), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//4, affine=True), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1),
            norm_mask(in_channel//8, affine=True), # 128
            activation,
        )
        # Full-resolution mask head: two more 2x upsamples down to 1 channel.
        self.maskhead_hr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1,bias=False),
            norm_mask(in_channel//16, affine=True), # 256
            activation,
            nn.UpsamplingNearest2d(scale_factor = 2),
            nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid() # 512
        )
        # 1x1 head producing the low-resolution mask from the shared features.
        self.maskhead_out = nn.Sequential(nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1),
                                          nn.Sigmoid())

        # self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # In this variant the final two upsampling stages are also
        # identity-conditioned (cf. the ResUpBlk variant kept commented out).
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)

        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        # self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512



        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        #                     nn.Conv2d(64, 3, kernel_size=7, padding=0))


        # self.__weights_init__()

    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, img, id):
        """Swap identity `id` into `img`.

        Args:
            img: input image tensor, (B, 3, H, W).
            id:  identity/style code fed to the AdaIN blocks -- presumably
                 shape (B, id_dim); TODO confirm against caller.

        Returns:
            (res, mask_lr, mask_hr): the composited image plus the low- and
            high-resolution soft blending masks.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # kept for the mask-guided skip blend below
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask_feat= self.maskhead_lr(res)   # shared mask features

        # Sequential cannot forward the extra identity arg, so index manually.
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        mask_lr= self.maskhead_out(mask_feat)
        # res = (1-mask) * self.sigma(skip) + mask * res
        # Blend decoder features with the encoder skip under the low-res mask.
        res = (1-mask_lr) * skip + mask_lr * res
        res = self.up2(res,id) # + skip
        res = self.up1(res,id)
        res = self.to_rgb(res)
        mask_hr=self.maskhead_hr(mask_feat)
        # Composite the generated image onto the input under the full-res mask.
        res = (1-mask_hr) * img + mask_hr * res
        return res, mask_lr, mask_hr
|
||||
@@ -0,0 +1,453 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 18th April 2022 10:20:12 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from components.ModulatedDWConv import ModulatedDWConv2d
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block using depthwise-separable 3x3 convs,
    with optional 2x average-pool downsampling.

    Output is (shortcut + residual) / sqrt(2) to keep unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        # A 1x1 shortcut projection is only needed when channels change.
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Depthwise 3x3 followed by pointwise 1x1 (separable convolution).
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, 3, 1, 1, groups=dim_in),
            nn.Conv2d(dim_in, dim_in, 1, 1)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, 3, 1, 1, groups=dim_in),
            nn.Conv2d(dim_in, dim_out, 1, 1)
        )
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        else:
            # Fail fast: previously an unknown mode only surfaced later as an
            # AttributeError on self.norm1 inside forward().
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # keep activations at unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization: scales and shifts the
    instance-normalized feature map with an affine transform predicted
    from a style vector."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Predict per-channel (gamma, beta) from the style code.
        stats = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = stats.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that upsamples 2x (nearest) on both branches.

    Output is (residual + shortcut) / sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x):
        out = self.actv(self.norm1(x))
        out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out))
        return self.conv2(out)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
# class ResUpBlk(nn.Module):
|
||||
# def __init__(self, dim_in, dim_out,actv=nn.LeakyReLU(0.2),normalize="in"):
|
||||
# super().__init__()
|
||||
# self.actv = actv
|
||||
# self.normalize = normalize
|
||||
# self.learned_sc = dim_in != dim_out
|
||||
# self.equal_var = math.sqrt(2)
|
||||
# self._build_weights(dim_in, dim_out)
|
||||
|
||||
# def _build_weights(self, dim_in, dim_out):
|
||||
# self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
|
||||
# self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
|
||||
# if self.normalize.lower() == "in":
|
||||
# self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
# self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
|
||||
# elif self.normalize.lower() == "bn":
|
||||
# self.norm1 = nn.BatchNorm2d(dim_in)
|
||||
# self.norm2 = nn.BatchNorm2d(dim_out)
|
||||
# if self.learned_sc:
|
||||
# self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
|
||||
|
||||
# def _shortcut(self, x):
|
||||
# x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
# if self.learned_sc:
|
||||
# x = self.conv1x1(x)
|
||||
# return x
|
||||
|
||||
# def _residual(self, x):
|
||||
# x = self.norm1(x)
|
||||
# x = self.actv(x)
|
||||
# x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
# x = self.conv1(x)
|
||||
# x = self.norm2(x)
|
||||
# x = self.actv(x)
|
||||
# x = self.conv2(x)
|
||||
# return x
|
||||
|
||||
# def forward(self, x):
|
||||
# out = self._residual(x)
|
||||
# out = (out + self._shortcut(x)) / self.equal_var
|
||||
# return out
|
||||
|
||||
class AdainResBlk(nn.Module):
    """AdaIN-conditioned residual block built from depthwise-separable
    convolutions, with optional 2x nearest upsampling; output scaled by
    1/sqrt(2) for unit variance."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        # Depthwise 3x3 followed by pointwise 1x1 (separable convolution).
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, 3, 1, 1, groups=dim_in),
            nn.Conv2d(dim_in, dim_out, 1, 1)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(dim_out, dim_out, 3, 1, 1, groups=dim_out),
            nn.Conv2d(dim_out, dim_out, 1)
        )

        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.actv(self.norm2(out, s))
        return self.conv2(out)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class ModulatedResBlk(nn.Module):
    """Residual block whose convolutions are style-modulated depthwise
    convs (ModulatedDWConv2d); optional 2x nearest upsampling, output
    scaled by 1/sqrt(2) for unit variance."""

    def __init__(self,
                 dim_in,
                 dim_out,
                 style_dim=512,
                 actv=nn.LeakyReLU(0.2),
                 upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = ModulatedDWConv2d(dim_in, dim_out, style_dim, 3)
        self.conv2 = ModulatedDWConv2d(dim_out, dim_out, style_dim, 3)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        # No explicit normalization: modulation happens inside the convs.
        out = self.actv(x)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out, s)
        out = self.actv(out)
        return self.conv2(out, s)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Mask-guided face generator with a fully style-conditioned decoder.

    Encoder: ResBlk downsampling stack (full res -> 16px). The bottleneck
    and every decoder stage are AdainResBlk, conditioned on the identity
    embedding ``id``. The mask branch grows from the 16px encoder features:
    a shared upsampling trunk (``maskhead_lr``, 16 -> 128px) feeds a 1x1
    head (``maskhead_out``) producing the low-res blending mask, and a
    further upsampling head (``maskhead_hr``, 128 -> 512px) producing the
    full-resolution blending mask.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        id_dim = kwargs["id_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        in_channel = kwargs["in_channel"]
        up_mode = kwargs["up_mode"]        # unused; read kept for config compatibility
        norm = kwargs["norm"].lower()

        aggregator = kwargs["aggregator"]  # unused; read kept for config compatibility
        res_mode = kwargs["res_mode"]      # unused; read kept for config compatibility

        padding_size = int((k_size - 1) / 2)  # unused
        padding_type = 'reflect'              # unused

        if norm == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
            norm_mask = nn.InstanceNorm2d
        elif norm == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
            norm_mask = nn.BatchNorm2d
        else:
            # Fail fast: the original fell through and later died with NameError.
            raise ValueError(f"unsupported norm mode: {kwargs['norm']!r}")

        activation = nn.LeakyReLU(0.2)

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)

        ### downsample
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16

        ### style-conditioned bottleneck
        self.BottleNeck = nn.Sequential(*[
            AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)
            for _ in range(res_num)])

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128

        # Shared mask trunk: 16px encoder features -> 128px mask features.
        self.maskhead_lr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel, affine=True),     # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//4, affine=True),  # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel//4, in_channel//8, kernel_size=3, stride=1, padding=1),
            norm_mask(in_channel//8, affine=True),  # 128
            activation,
        )

        # 128px mask features -> full-resolution blending mask.
        self.maskhead_hr = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel//8, in_channel//16, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//16, affine=True),  # 256
            activation,
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel//16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # 512
        )

        # 128px mask features -> low-resolution blending mask.
        self.maskhead_out = nn.Sequential(
            nn.Conv2d(in_channel//8, 1, kernel_size=1, stride=1),
            nn.Sigmoid())

        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)  # 256
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)    # 512

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Swap identity ``id`` into ``img``.

        Returns:
            (result, mask_lr, mask_hr): swapped image plus the low- and
            high-resolution blending masks.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)            # 128px features kept for mask-guided blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask_feat = self.maskhead_lr(res)  # shared mask trunk off the 16px features

        for blk in self.BottleNeck:
            res = blk(res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)

        mask_lr = self.maskhead_out(mask_feat)
        res = (1 - mask_lr) * skip + mask_lr * res   # blend with encoder skip
        res = self.up2(res, id)
        res = self.up1(res, id)
        res = self.to_rgb(res)
        mask_hr = self.maskhead_hr(mask_feat)
        res = (1 - mask_hr) * img + mask_hr * res    # blend with the input image
        return res, mask_lr, mask_hr
|
||||
@@ -0,0 +1,293 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 13th April 2022 10:22:52 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    Residual branch: norm -> actv -> 3x3 conv -> (pool) -> norm -> actv ->
    3x3 conv. The shortcut applies a 1x1 conv when dim_in != dim_out.
    The sum is divided by sqrt(2) to preserve unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="bn", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        else:
            # Fail fast: the original silently skipped building the norm
            # layers here and crashed later in _residual with AttributeError.
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        # Norm layers always exist after _build_weights, so no truthiness
        # check on the (always non-empty) normalize string is needed.
        x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive instance normalization: a style vector drives per-channel affine."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Statistics are style-driven, so the norm itself carries no affine.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Map the style code to per-channel scale/shift parameters.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that upsamples both branches 2x (nearest-neighbour)."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            # 1x1 projection so the shortcut matches the output channels.
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        # Sum the branches and rescale to preserve unit variance.
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block with style-conditioned AdaIN norms and optional 2x upsample."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            # 1x1 projection so the shortcut matches the output channels.
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h, s))
        return self.conv2(h)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Mask-guided U-Net style face generator.

    Encoder: stack of ResBlk downsampling stages. The bottleneck and the
    first three decoder stages are style-conditioned (AdainResBlk) on the
    identity embedding ``id``; the last two decoder stages are plain
    ResUpBlk. Two sigmoid mask heads blend decoder features with the
    encoder skip (low-res mask) and the final RGB with the input image
    (high-res mask).
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        id_dim = kwargs["id_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        in_channel = kwargs["in_channel"]
        up_mode = kwargs["up_mode"]        # unused; read kept for config compatibility
        norm = kwargs["norm"].lower()

        aggregator = kwargs["aggregator"]  # unused; read kept for config compatibility
        res_mode = kwargs["res_mode"]      # unused; read kept for config compatibility

        padding_size = int((k_size - 1) / 2)  # unused
        padding_type = 'reflect'              # unused

        activation = nn.LeakyReLU(0.2)

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)

        ### downsample
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16

        ### style-conditioned bottleneck
        self.BottleNeck = nn.Sequential(*[
            AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)
            for _ in range(res_num)])

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128

        # Low-resolution blending mask, predicted from 128px decoder features.
        self.maskhead_lr = nn.Sequential(
            nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel),
            activation,
            nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)

        # High-resolution blending mask, predicted from full-res features.
        self.maskhead_hr = nn.Sequential(
            nn.Conv2d(in_channel, in_channel//8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//8),
            activation,
            nn.Conv2d(in_channel//8, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        if norm == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        else:
            # Fail fast: the original fell through and crashed with NameError.
            raise ValueError(f"unsupported norm mode: {kwargs['norm']!r}")

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Swap identity ``id`` into ``img``.

        Returns:
            (result, mask, mask_hr): swapped image plus the low- and
            high-resolution blending masks.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # 128px features kept for mask-guided blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)

        for blk in self.BottleNeck:
            res = blk(res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)

        mask = self.maskhead_lr(res)
        res = (1 - mask) * skip + mask * res       # blend with encoder skip
        res = self.up2(res)
        res = self.up1(res)
        mask_hr = self.maskhead_hr(res)
        res = self.to_rgb(res)
        res = (1 - mask_hr) * img + mask_hr * res  # blend with the input image
        return res, mask, mask_hr
|
||||
@@ -0,0 +1,293 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 13th April 2022 10:22:52 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    Residual branch: norm -> actv -> 3x3 conv -> (pool) -> norm -> actv ->
    3x3 conv. The shortcut applies a 1x1 conv when dim_in != dim_out.
    The sum is divided by sqrt(2) to preserve unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="bn", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        else:
            # Fail fast: the original silently skipped building the norm
            # layers here and crashed later in _residual with AttributeError.
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        # Norm layers always exist after _build_weights, so no truthiness
        # check on the (always non-empty) normalize string is needed.
        x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive instance normalization: a style vector drives per-channel affine."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Statistics are style-driven, so the norm itself carries no affine.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Map the style code to per-channel scale/shift parameters.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that upsamples both branches 2x (nearest-neighbour)."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            # 1x1 projection so the shortcut matches the output channels.
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        # Sum the branches and rescale to preserve unit variance.
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block with style-conditioned AdaIN norms and optional 2x upsample."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            # 1x1 projection so the shortcut matches the output channels.
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h, s))
        return self.conv2(h)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Mask-guided U-Net style face generator.

    Encoder: stack of ResBlk downsampling stages. The bottleneck and the
    first three decoder stages are style-conditioned (AdainResBlk) on the
    identity embedding ``id``; the last two decoder stages are plain
    ResUpBlk. Two sigmoid mask heads blend decoder features with the
    encoder skip (low-res mask) and the final RGB with the input image
    (high-res mask).
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        id_dim = kwargs["id_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        in_channel = kwargs["in_channel"]
        up_mode = kwargs["up_mode"]        # unused; read kept for config compatibility
        norm = kwargs["norm"].lower()

        aggregator = kwargs["aggregator"]  # unused; read kept for config compatibility
        res_mode = kwargs["res_mode"]      # unused; read kept for config compatibility

        padding_size = int((k_size - 1) / 2)  # unused
        padding_type = 'reflect'              # unused

        activation = nn.LeakyReLU(0.2)

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)

        ### downsample
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16

        ### style-conditioned bottleneck
        self.BottleNeck = nn.Sequential(*[
            AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)
            for _ in range(res_num)])

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128

        # Low-resolution blending mask, predicted from 128px decoder features.
        self.maskhead_lr = nn.Sequential(
            nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel),
            activation,
            nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)

        # High-resolution blending mask, predicted from full-res features.
        self.maskhead_hr = nn.Sequential(
            nn.Conv2d(in_channel, in_channel//8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//8),
            activation,
            nn.Conv2d(in_channel//8, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        if norm == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        else:
            # Fail fast: the original fell through and crashed with NameError.
            raise ValueError(f"unsupported norm mode: {kwargs['norm']!r}")

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Swap identity ``id`` into ``img``.

        Returns:
            (result, mask, mask_hr): swapped image plus the low- and
            high-resolution blending masks.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # 128px features kept for mask-guided blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)

        for blk in self.BottleNeck:
            res = blk(res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)

        mask = self.maskhead_lr(res)
        res = (1 - mask) * skip + mask * res       # blend with encoder skip
        res = self.up2(res)
        res = self.up1(res)
        mask_hr = self.maskhead_hr(res)
        res = self.to_rgb(res)
        res = (1 - mask_hr) * img + mask_hr * res  # blend with the input image
        return res, mask, mask_hr
|
||||
@@ -0,0 +1,276 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 29th March 2022 12:02:53 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from components.LSTU import LSTU
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    Residual branch: norm -> actv -> 3x3 conv -> (pool) -> norm -> actv ->
    3x3 conv. The shortcut applies a 1x1 conv when dim_in != dim_out.
    The sum is divided by sqrt(2) to preserve unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        else:
            # Fail fast: the original silently skipped building the norm
            # layers here and crashed later in _residual with AttributeError.
            raise ValueError(f"unsupported normalize mode: {self.normalize!r}")
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        # Norm layers always exist after _build_weights, so no truthiness
        # check on the (always non-empty) normalize string is needed.
        x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive instance normalization: a style vector drives per-channel affine."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Statistics are style-driven, so the norm itself carries no affine.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Map the style code to per-channel scale/shift parameters.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that upsamples both branches 2x (nearest-neighbour)."""

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        mode = self.normalize.lower()
        if mode == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif mode == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            # 1x1 projection so the shortcut matches the output channels.
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        # Sum the branches and rescale to preserve unit variance.
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization is style-conditioned (AdaIN).

    Optionally upsamples 2x.  Output is (residual + shortcut) / sqrt(2);
    the shortcut gains a 1x1 conv only when dim_in != dim_out.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        """Build convs, AdaIN norms and the learned shortcut projection."""
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        """Identity path: optional 2x upsample, then optional 1x1 projection."""
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        """AdaIN -> act -> (upsample) -> conv1 -> AdaIN -> act -> conv2."""
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h, s))
        return self.conv2(h)

    def forward(self, x, s):
        """Blend the style-conditioned residual with the shortcut (unit variance)."""
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Identity-conditioned encoder/decoder generator (LSTU-skip variant).

    Pipeline: from_rgb -> five downsampling ResBlk stages -> ``res_num``
    AdainResBlk bottleneck blocks modulated by the identity code ->
    upsampling AdainResBlk stages, with the down2 skip fused back in by an
    LSTU module -> to_rgb.

    Expected kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm ("in" or "bn"), aggregator, res_mode.
    forward(img, id) returns (rgb, mask) where mask comes from the LSTU.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        id_dim = kwargs["id_dim"]            # identity/style code dimension
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]          # number of bottleneck blocks
        in_channel = kwargs["in_channel"]    # base channel width
        up_mode = kwargs["up_mode"]          # NOTE(review): read but not used below
        norm = kwargs["norm"]                # encoder normalization kind

        aggregator = kwargs["aggregator"]    # NOTE(review): read but not used below
        res_mode = kwargs["res_mode"]        # NOTE(review): read but not used below

        padding_size= int((k_size -1)/2)     # NOTE(review): unused
        padding_type= 'reflect'              # NOTE(review): unused

        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        # nn.BatchNorm2d(64), activation)
        ### downsample (trailing numbers are output sizes for a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256

        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128

        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64

        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32

        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16

        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8

        ### resnet blocks (identity-modulated bottleneck)
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)

        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32

        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64

        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128

        # Fuses the down2 skip with the decoder feature and emits a blend mask.
        self.lstu = LSTU(in_channel*2,norm)

        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel*2, in_channel, normalize="in") # 256

        # self.lstu = nn.Sequential(nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
        # nn.BatchNorm2d(in_channel),
        # activation,
        # nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
        # nn.Sigmoid()
        # )

        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512

        # NOTE(review): a norm value other than "in"/"bn" leaves norm_out
        # undefined and raises NameError below — confirm callers only pass these.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        # nn.Conv2d(64, 3, kernel_size=7, padding=0))


        # self.__weights_init__()

    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, img, id):
        """Render ``img`` with identity code ``id``; returns (rgb, mask)."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # kept as the LSTU skip connection
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        # res = self.down6(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        # res = self.up6(res,id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        res,mask = self.lstu(skip, res)   # fuse skip; mask is the blend map
        res = self.up2(res,id) # + skip
        res = self.up1(res,id)
        res = self.to_rgb(res)
        return res,mask
|
||||
@@ -0,0 +1,285 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 29th March 2022 1:08:05 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x downsampling.

    Output is (shortcut + residual) / sqrt(2) to preserve unit variance.
    The shortcut uses a 1x1 conv only when dim_in != dim_out.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        """Create convs, optional norm layers and the learned shortcut."""
        norm_kind = self.normalize.lower()
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        # Both norms see dim_in channels: each runs *before* its conv.
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _has_norm(self):
        """True when norm1/norm2 exist.

        BUGFIX: the original tested ``if self.normalize:`` in ``_residual``,
        which is True for *any* non-empty string, so an unrecognized value
        (e.g. "none") crashed with AttributeError because the norm layers
        were never built.  Behavior for "in"/"bn" is unchanged.
        """
        return self.normalize.lower() in ("in", "bn")

    def _shortcut(self, x):
        """Identity path: optional 1x1 projection, then optional 2x avg-pool."""
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        """Main path: (norm -> act -> conv) twice, pooling after conv1."""
        if self._has_norm():
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self._has_norm():
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        """Return (shortcut(x) + residual(x)) / sqrt(2)."""
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization: re-styles normalized features.

    A style code ``s`` of shape (B, style_dim) is projected to per-channel
    (gamma, beta); output is ``(1 + gamma) * IN(x) + beta``.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Parameter-free IN: the affine part comes from the style projection.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        """Normalize ``x`` and modulate it with style code ``s``."""
        style = self.fc(s)
        style = style.view(style.size(0), style.size(1), 1, 1)
        gamma, beta = style.chunk(2, dim=1)
        return (gamma + 1) * self.norm(x) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that upsamples its input 2x (nearest neighbour).

    Output is (residual + shortcut) / sqrt(2) to preserve unit variance.
    The shortcut uses a 1x1 conv only when dim_in != dim_out.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        """Build convs, norms and the learned shortcut projection."""
        norm_kind = self.normalize.lower()
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        # norm1 precedes conv1 (dim_in channels); norm2 precedes conv2 (dim_out).
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        """Upsample 2x, then project channels if needed."""
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(up) if self.learned_sc else up

    def _residual(self, x):
        """norm -> act -> upsample -> conv1 -> norm -> act -> conv2."""
        h = self.actv(self.norm1(x))
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        """Return the variance-preserving sum of residual and shortcut paths."""
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization is style-conditioned (AdaIN).

    Optionally upsamples 2x.  Output is (residual + shortcut) / sqrt(2);
    the shortcut gains a 1x1 conv only when dim_in != dim_out.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        """Build convs, AdaIN norms and the learned shortcut projection."""
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        """Identity path: optional 2x upsample, then optional 1x1 projection."""
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        """AdaIN -> act -> (upsample) -> conv1 -> AdaIN -> act -> conv2."""
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h, s))
        return self.conv2(h)

    def forward(self, x, s):
        """Blend the style-conditioned residual with the shortcut (unit variance)."""
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Identity-conditioned encoder/decoder generator (configurable LSTU).

    Same topology as the LSTU variant, but the skip-fusion module is resolved
    dynamically: kwargs name a module under ``components`` and a class inside
    it, imported at construction time.

    Expected kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm ("in" or "bn"), lstu_script, lstu_class, aggregator, res_mode.
    forward(img, id) returns the rgb tensor only.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        id_dim = kwargs["id_dim"]            # identity/style code dimension
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]          # number of bottleneck blocks
        in_channel = kwargs["in_channel"]    # base channel width
        up_mode = kwargs["up_mode"]          # NOTE(review): read but not used below
        norm = kwargs["norm"]                # encoder normalization kind

        lstu_script = kwargs["lstu_script"]  # module name under `components`
        lstu_class = kwargs["lstu_class"]    # class name inside that module

        aggregator = kwargs["aggregator"]    # NOTE(review): read but not used below
        res_mode = kwargs["res_mode"]        # NOTE(review): read but not used below

        padding_size= int((k_size -1)/2)     # NOTE(review): unused
        padding_type= 'reflect'              # NOTE(review): unused

        # Resolve the configured skip-fusion (LSTU) implementation at runtime.
        # NOTE(review): `fromlist=True` is unconventional (normally a list of
        # names), but any truthy value makes __import__ return the submodule.
        script_name = "components." + lstu_script
        package = __import__(script_name, fromlist=True)
        lstu_class = getattr(package, lstu_class)


        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        # nn.BatchNorm2d(64), activation)
        ### downsample (trailing numbers are output sizes for a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256

        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128

        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64

        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32

        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16

        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8

        ### resnet blocks (identity-modulated bottleneck)
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)

        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32

        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64

        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128

        # Skip-fusion module resolved above; emits (fused_feature, mask).
        self.lstu = lstu_class(in_channel*2,norm)

        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel*2, in_channel, normalize="in") # 256

        # self.lstu = nn.Sequential(nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
        # nn.BatchNorm2d(in_channel),
        # activation,
        # nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
        # nn.Sigmoid()
        # )

        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512

        # NOTE(review): a norm value other than "in"/"bn" leaves norm_out
        # undefined and raises NameError below — confirm callers only pass these.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        # nn.Conv2d(64, 3, kernel_size=7, padding=0))


        # self.__weights_init__()

    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, img, id):
        """Render ``img`` with identity code ``id``; returns the rgb tensor."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # kept as the LSTU skip connection
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        # res = self.down6(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        # res = self.up6(res,id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        # NOTE(review): the mask is computed here but deliberately not
        # returned by this variant — confirm downstream code does not need it.
        res,mask = self.lstu(skip, res)
        res = self.up2(res,id) # + skip
        res = self.up1(res,id)
        res = self.to_rgb(res)
        return res
|
||||
@@ -0,0 +1,285 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 30th March 2022 4:14:27 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x downsampling.

    Output is (shortcut + residual) / sqrt(2) to preserve unit variance.
    The shortcut uses a 1x1 conv only when dim_in != dim_out.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        """Create convs, optional norm layers and the learned shortcut."""
        norm_kind = self.normalize.lower()
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        # Both norms see dim_in channels: each runs *before* its conv.
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _has_norm(self):
        """True when norm1/norm2 exist.

        BUGFIX: the original tested ``if self.normalize:`` in ``_residual``,
        which is True for *any* non-empty string, so an unrecognized value
        (e.g. "none") crashed with AttributeError because the norm layers
        were never built.  Behavior for "in"/"bn" is unchanged.
        """
        return self.normalize.lower() in ("in", "bn")

    def _shortcut(self, x):
        """Identity path: optional 1x1 projection, then optional 2x avg-pool."""
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        """Main path: (norm -> act -> conv) twice, pooling after conv1."""
        if self._has_norm():
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self._has_norm():
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        """Return (shortcut(x) + residual(x)) / sqrt(2)."""
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization: re-styles normalized features.

    A style code ``s`` of shape (B, style_dim) is projected to per-channel
    (gamma, beta); output is ``(1 + gamma) * IN(x) + beta``.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Parameter-free IN: the affine part comes from the style projection.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        """Normalize ``x`` and modulate it with style code ``s``."""
        style = self.fc(s)
        style = style.view(style.size(0), style.size(1), 1, 1)
        gamma, beta = style.chunk(2, dim=1)
        return (gamma + 1) * self.norm(x) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that upsamples its input 2x (nearest neighbour).

    Output is (residual + shortcut) / sqrt(2) to preserve unit variance.
    The shortcut uses a 1x1 conv only when dim_in != dim_out.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        """Build convs, norms and the learned shortcut projection."""
        norm_kind = self.normalize.lower()
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        # norm1 precedes conv1 (dim_in channels); norm2 precedes conv2 (dim_out).
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        """Upsample 2x, then project channels if needed."""
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(up) if self.learned_sc else up

    def _residual(self, x):
        """norm -> act -> upsample -> conv1 -> norm -> act -> conv2."""
        h = self.actv(self.norm1(x))
        h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        """Return the variance-preserving sum of residual and shortcut paths."""
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization is style-conditioned (AdaIN).

    Optionally upsamples 2x.  Output is (residual + shortcut) / sqrt(2);
    the shortcut gains a 1x1 conv only when dim_in != dim_out.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        """Build convs, AdaIN norms and the learned shortcut projection."""
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        """Identity path: optional 2x upsample, then optional 1x1 projection."""
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        """AdaIN -> act -> (upsample) -> conv1 -> AdaIN -> act -> conv2."""
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h, s))
        return self.conv2(h)

    def forward(self, x, s):
        """Blend the style-conditioned residual with the shortcut (unit variance)."""
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Identity-conditioned encoder/decoder generator (skip-fusion disabled).

    Identical construction to the configurable-LSTU variant, but the LSTU call
    in ``forward`` is commented out, so the decoder runs without the skip
    fusion and only the rgb tensor is returned.

    Expected kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm ("in" or "bn"), lstu_script, lstu_class, aggregator, res_mode.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        id_dim = kwargs["id_dim"]            # identity/style code dimension
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]          # number of bottleneck blocks
        in_channel = kwargs["in_channel"]    # base channel width
        up_mode = kwargs["up_mode"]          # NOTE(review): read but not used below
        norm = kwargs["norm"]                # encoder normalization kind

        lstu_script = kwargs["lstu_script"]  # module name under `components`
        lstu_class = kwargs["lstu_class"]    # class name inside that module

        aggregator = kwargs["aggregator"]    # NOTE(review): read but not used below
        res_mode = kwargs["res_mode"]        # NOTE(review): read but not used below

        padding_size= int((k_size -1)/2)     # NOTE(review): unused
        padding_type= 'reflect'              # NOTE(review): unused

        # Resolve the configured skip-fusion (LSTU) implementation at runtime.
        # NOTE(review): `fromlist=True` is unconventional (normally a list of
        # names), but any truthy value makes __import__ return the submodule.
        script_name = "components." + lstu_script
        package = __import__(script_name, fromlist=True)
        lstu_class = getattr(package, lstu_class)


        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
        # self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
        # nn.BatchNorm2d(64), activation)
        ### downsample (trailing numbers are output sizes for a 512x512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)# 256

        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)# 128

        self.down3 = ResBlk(in_channel*2, in_channel*4,normalize=norm, downsample=True)# 64

        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)# 32

        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)# 16

        # self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8

        ### resnet blocks (identity-modulated bottleneck)
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)

        # self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 32

        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True) # 64

        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True) # 128

        # NOTE(review): constructed but not called in forward (the lstu line
        # there is commented out) — dead weight in this variant.
        self.lstu = lstu_class(in_channel*2,norm)

        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel*2, in_channel, normalize="in") # 256

        # self.lstu = nn.Sequential(nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
        # nn.BatchNorm2d(in_channel),
        # activation,
        # nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
        # nn.Sigmoid()
        # )

        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)
        # ResUpBlk(in_channel, in_channel, normalize="in") # 512

        # NOTE(review): a norm value other than "in"/"bn" leaves norm_out
        # undefined and raises NameError below — confirm callers only pass these.
        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

        # self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
        # nn.Conv2d(64, 3, kernel_size=7, padding=0))


        # self.__weights_init__()

    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, img, id):
        """Render ``img`` with identity code ``id``; returns the rgb tensor."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # skip retained but unused (lstu disabled below)
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        # res = self.down6(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        # res = self.up6(res,id)
        res = self.up5(res,id)
        res = self.up4(res,id)
        res = self.up3(res,id)
        # res,mask = self.lstu(skip, res)
        res = self.up2(res,id) # + skip
        res = self.up1(res,id)
        res = self.to_rgb(res)
        return res
|
||||
@@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 6th April 2022 12:55:51 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization: re-styles normalized features.

    A style code ``s`` of shape (B, style_dim) is projected to per-channel
    (gamma, beta); output is ``(1 + gamma) * IN(x) + beta``.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        # Parameter-free IN: the affine part comes from the style projection.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        """Normalize ``x`` and modulate it with style code ``s``."""
        style = self.fc(s)
        style = style.view(style.size(0), style.size(1), 1, 1)
        gamma, beta = style.chunk(2, dim=1)
        return (gamma + 1) * self.norm(x) + beta
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization is style-conditioned (AdaIN).

    Optionally upsamples 2x.  Output is (residual + shortcut) / sqrt(2);
    the shortcut gains a 1x1 conv only when dim_in != dim_out.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        """Build convs, AdaIN norms and the learned shortcut projection."""
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        """Identity path: optional 2x upsample, then optional 1x1 projection."""
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        """AdaIN -> act -> (upsample) -> conv1 -> AdaIN -> act -> conv2."""
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h, s))
        return self.conv2(h)

    def forward(self, x, s):
        """Blend the style-conditioned residual with the shortcut (unit variance)."""
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainUpBlock(nn.Module):
    """Plain (non-residual) upsampling block: 2x nearest -> conv -> AdaIN -> act."""

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.actv = actv
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        """Build the conv and its style-conditioned norm."""
        self.conv = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.norm = AdaIN(style_dim, dim_out)

    def forward(self, x, s):
        """Upsample ``x`` 2x, convolve, then apply AdaIN with style ``s``."""
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        h = self.norm(self.conv(up), s)
        return self.actv(h)
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
id_dim = kwargs["id_dim"]
|
||||
k_size = kwargs["g_kernel_size"]
|
||||
res_num = kwargs["res_num"]
|
||||
in_channel = kwargs["in_channel"]
|
||||
up_mode = kwargs["up_mode"]
|
||||
norm = kwargs["norm"]
|
||||
|
||||
aggregator = kwargs["aggregator"]
|
||||
res_mode = kwargs["res_mode"]
|
||||
|
||||
padding_size= int((k_size -1)/2)
|
||||
padding_type= 'reflect'
|
||||
|
||||
|
||||
activation = nn.LeakyReLU(0.2)
|
||||
# activation = nn.ReLU()
|
||||
|
||||
self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)
|
||||
# self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(64), activation)
|
||||
### downsample
|
||||
self.down1 = nn.Sequential(nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=2, padding=1, bias=False), # 256
|
||||
nn.BatchNorm2d(in_channel), activation)
|
||||
|
||||
self.down2 = nn.Sequential(nn.Conv2d(in_channel, in_channel*2, kernel_size=3, stride=2, padding=1, bias=False), # 128
|
||||
nn.BatchNorm2d(in_channel*2), activation)
|
||||
|
||||
self.down3 = nn.Sequential(nn.Conv2d(in_channel*2, in_channel*4, kernel_size=3, stride=2, padding=1, bias=False), # 64
|
||||
nn.BatchNorm2d(in_channel*4), activation)
|
||||
|
||||
self.down4 = nn.Sequential(nn.Conv2d(in_channel*4, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False), # 32
|
||||
nn.BatchNorm2d(in_channel*8), activation)
|
||||
|
||||
self.down5 = nn.Sequential(nn.Conv2d(in_channel*8, in_channel*8, kernel_size=3, stride=2, padding=1, bias=False), # 32
|
||||
nn.BatchNorm2d(in_channel*8), activation)
|
||||
|
||||
# self.down6 = ResBlk(in_channel*8, in_channel*8, normalize=True, downsample=True)# 8
|
||||
self.maskhead = nn.Sequential(
|
||||
nn.UpsamplingNearest2d(scale_factor = 2),
|
||||
nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1,bias=False),
|
||||
nn.BatchNorm2d(in_channel), # 32
|
||||
activation,
|
||||
nn.UpsamplingNearest2d(scale_factor = 2),
|
||||
nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel//2), # 64
|
||||
activation,
|
||||
nn.UpsamplingNearest2d(scale_factor = 2),
|
||||
nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
### resnet blocks
|
||||
BN = []
|
||||
for i in range(res_num):
|
||||
BN += [
|
||||
AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
|
||||
self.BottleNeck = nn.Sequential(*BN)
|
||||
|
||||
# self.up6 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True) # 16
|
||||
|
||||
self.up5 = AdainUpBlock(in_channel*8, in_channel*8, style_dim=id_dim) # 32
|
||||
|
||||
self.up4 = AdainUpBlock(in_channel*8, in_channel*4, style_dim=id_dim) # 64
|
||||
|
||||
self.up3 = AdainUpBlock(in_channel*4, in_channel*2, style_dim=id_dim) # 128
|
||||
|
||||
self.up2 = AdainUpBlock(in_channel*2, in_channel, style_dim=id_dim)
|
||||
|
||||
self.up1 = AdainUpBlock(in_channel, in_channel, style_dim=id_dim)
|
||||
# ResUpBlk(in_channel, in_channel, normalize="in") # 512
|
||||
|
||||
|
||||
self.to_rgb = nn.Sequential(nn.ReflectionPad2d(1),
|
||||
nn.Conv2d(in_channel, 3, kernel_size=3, padding=0))
|
||||
|
||||
|
||||
# self.__weights_init__()
|
||||
|
||||
# def __weights_init__(self):
|
||||
# for layer in self.encoder:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
# for layer in self.encoder2:
|
||||
# if isinstance(layer,nn.Conv2d):
|
||||
# nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
def forward(self, img, id):
    """Encode `img`, condition the bottleneck/decoder on the code `id`,
    and decode back to an image.

    Returns (rgb, mask) where `mask` is the sigmoid output of
    `self.maskhead`, used to blend the `down2` skip tensor with the
    decoded features.
    """
    res = self.from_rgb(img)
    res = self.down1(res)
    skip = self.down2(res)   # saved for the mask-guided skip connection
    res = self.down3(skip)
    res = self.down4(res)
    res = self.down5(res)
    mask = self.maskhead(res)   # soft mask at the resolution of `skip`
    # Bottleneck blocks are called one by one because each takes (x, id).
    for i in range(len(self.BottleNeck)):
        res = self.BottleNeck[i](res, id)
    res = self.up5(res, id)
    res = self.up4(res, id)
    res = self.up3(res, id)
    # Blend decoded features with the encoder skip under the learned mask.
    res = (1 - mask) * skip + mask * res
    res = self.up2(res, id)  # + skip
    res = self.up1(res, id)
    res = self.to_rgb(res)
    return res, mask
|
||||
@@ -0,0 +1,298 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 2nd April 2022 1:27:23 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x average-pool
    downsampling.

    normalize selects the normalization layers: "in" (InstanceNorm2d),
    "bn" (BatchNorm2d); any other value disables normalization.
    The sum of the shortcut and residual paths is divided by sqrt(2) so
    the output keeps roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out   # shortcut needs a 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Both norms act on dim_in features: conv2 (dim_in -> dim_out) is
        # the last op of the residual path.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _has_norm(self):
        # BUGFIX: the original gated on `if self.normalize:` (string
        # truthiness), so any unrecognised non-empty value (e.g. "none")
        # raised AttributeError because no norm layers were built.
        return self.normalize.lower() in ("in", "bn")

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        # Pre-activation ordering: norm -> act -> conv (twice), with the
        # downsample between the two convolutions.
        if self._has_norm():
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self._has_norm():
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    Instance-normalizes x (no affine parameters), then applies a
    per-channel scale and shift predicted from the style vector s by a
    single linear layer.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        # fc emits 2*num_features values: first half gamma, second half beta.
        stats = self.fc(s)
        stats = stats.view(stats.size(0), stats.size(1), 1, 1)  # broadcast over H, W
        gamma, beta = stats.chunk(2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that doubles the spatial resolution with
    nearest-neighbour upsampling on both paths.

    normalize: "in" -> InstanceNorm2d, "bn" -> BatchNorm2d. The output is
    divided by sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out   # 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        norm_kind = self.normalize.lower()
        if norm_kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif norm_kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        # Upsample first so both paths return the same spatial size.
        out = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x):
        out = self.actv(self.norm1(x))
        out = self.conv1(F.interpolate(out, scale_factor=2, mode='nearest'))
        out = self.conv2(self.actv(self.norm2(out)))
        return out

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization layers are AdaIN, conditioned on
    a style vector; optionally doubles the spatial resolution.

    The output is divided by sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out   # 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        out = F.interpolate(x, scale_factor=2, mode='nearest') if self.upsample else x
        return self.conv1x1(out) if self.learned_sc else out

    def _residual(self, x, s):
        out = self.actv(self.norm1(x, s))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode='nearest')
        out = self.conv1(out)
        out = self.conv2(self.actv(self.norm2(out, s)))
        return out

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Encoder/decoder generator whose skip fusion is delegated to a
    dynamically imported LSTU module.

    Required kwargs: id_dim, g_kernel_size, res_num, in_channel, up_mode,
    norm ("in" or "bn"), lstu_script, lstu_class, aggregator, res_mode.
    g_kernel_size / up_mode / aggregator / res_mode are read but unused in
    this variant.
    """

    def __init__(self, **kwargs):
        super().__init__()

        id_dim = kwargs["id_dim"]            # identity/style vector width
        k_size = kwargs["g_kernel_size"]     # unused below
        res_num = kwargs["res_num"]          # number of bottleneck blocks
        in_channel = kwargs["in_channel"]    # base channel width
        up_mode = kwargs["up_mode"]          # unused below
        norm = kwargs["norm"]                # "in" or "bn"

        lstu_script = kwargs["lstu_script"]
        lstu_class = kwargs["lstu_class"]

        aggregator = kwargs["aggregator"]    # unused below
        res_mode = kwargs["res_mode"]        # unused below

        padding_size = int((k_size - 1) / 2)  # unused below
        padding_type = 'reflect'              # unused below

        # Resolve the skip-fusion (LSTU) class from components.<lstu_script>.
        script_name = "components." + lstu_script
        package = __import__(script_name, fromlist=True)
        lstu_class = getattr(package, lstu_class)

        activation = nn.LeakyReLU(0.2)
        # activation = nn.ReLU()

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)

        ### downsample path; resolution comments assume a 512x512 input
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16

        # Mask head: three stride-2 deconvolutions ending in a sigmoid.
        # NOTE(review): maskhead is built but never called in forward() of
        # this variant (the LSTU produces the mask); it still contributes
        # parameters to the state_dict.
        # BUGFIX: the 2nd and 3rd BatchNorm2d were built with `in_channel`
        # features while their inputs have in_channel//2 and 1 channels
        # respectively, which would crash at the first call.
        self.maskhead = nn.Sequential(
            nn.ConvTranspose2d(in_channel*8, in_channel, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(in_channel), # 32
            activation,
            nn.ConvTranspose2d(in_channel, in_channel//2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//2), # 64
            activation,
            nn.ConvTranspose2d(in_channel//2, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(1), # 128
            nn.Sigmoid()
        )

        ### identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)

        ### identity-conditioned decoder
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128

        # Learned fusion of the down2 skip tensor with the decoded features.
        self.lstu = lstu_class(in_channel*2, norm)

        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)

        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id, feat_out=False):
        """Generate an image from `img` conditioned on the code `id`.

        When feat_out is True, return the raw encoder bottleneck features
        instead of running the decoder.
        """
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # saved for the LSTU fusion
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        if feat_out:
            return res
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        res, mask = self.lstu(skip, res)   # mask is computed but not returned
        res = self.up2(res, id)  # + skip
        res = self.up1(res, id)
        res = self.to_rgb(res)
        return res
|
||||
@@ -0,0 +1,280 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 3rd April 2022 1:06:31 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x avg-pool downsampling.

    normalize: "in" -> InstanceNorm2d, "bn" -> BatchNorm2d.
    NOTE(review): any other non-empty value builds no norm layers but still
    enters the `if self.normalize:` branches in _residual -> AttributeError.
    The output is divided by sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out   # shortcut needs a 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Both norms act on dim_in features: conv2 (dim_in -> dim_out) is
        # the last op of the residual path.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        # Pre-activation ordering: norm -> act -> conv (twice), with the
        # downsample between the two convolutions.
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    Instance-normalizes x (no affine parameters), then applies a
    per-channel scale and shift predicted from the style vector s by a
    single linear layer.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        # fc emits 2*num_features values: first half gamma, second half beta.
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1, 1)  # broadcast over H, W
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that doubles the spatial resolution with
    nearest-neighbour upsampling on both paths.

    normalize: "in" -> InstanceNorm2d, "bn" -> BatchNorm2d (the norms are
    always applied in _residual; other values would fail there with
    AttributeError). Output divided by sqrt(2) for roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out   # 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        # Upsample first so both paths return the same spatial size.
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        out = self._residual(x)
        out = (out + self._shortcut(x)) / self.equal_var
        return out
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization layers are AdaIN, conditioned on
    a style vector s; optionally doubles the spatial resolution (nearest
    upsampling). Output divided by sqrt(2) for roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out   # 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / self.equal_var
        return out
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Mask-guided encoder/decoder generator.

    The encoder is a stack of downsampling ResBlk's; the code `id` is
    injected in the bottleneck and decoder through AdainResBlk's;
    `maskhead` predicts a soft mask used to blend decoder features with
    the `down2` encoder skip. forward() returns (rgb, mask).
    """

    def __init__(self, **kwargs):
        """Build the network from a config dict.

        Required kwargs: id_dim, g_kernel_size, res_num, in_channel,
        up_mode, norm, aggregator, res_mode. g_kernel_size, up_mode,
        aggregator and res_mode are read but unused in this variant.
        norm must be "in" or "bn"; any other value leaves `norm_out`
        unbound and raises NameError below.
        """
        super().__init__()

        id_dim = kwargs["id_dim"]            # identity/style vector width
        k_size = kwargs["g_kernel_size"]     # unused below
        res_num = kwargs["res_num"]          # number of bottleneck blocks
        in_channel = kwargs["in_channel"]    # base channel width
        up_mode = kwargs["up_mode"]          # unused below
        norm = kwargs["norm"]                # "in" or "bn"

        aggregator = kwargs["aggregator"]    # unused below
        res_mode = kwargs["res_mode"]        # unused below

        padding_size = int((k_size - 1) / 2)  # unused below
        padding_type = 'reflect'              # unused below

        activation = nn.LeakyReLU(0.2)

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)

        ### downsample path; resolution comments assume a 512x512 input
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16

        # Mask head: 3x (nearest-upsample + conv), ending in a sigmoid, so
        # the mask matches the resolution of the `down2` skip tensor.
        self.maskhead = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//2), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Sigmoid()
        )

        ### identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)

        ### identity-conditioned decoder
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)

        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Return (rgb, mask) generated from `img` conditioned on `id`."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # saved for the mask-guided blend
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask = self.maskhead(res)   # soft mask at the resolution of `skip`
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        # Blend decoder features with the encoder skip under the mask.
        res = (1 - mask) * skip + mask * res
        res = self.up2(res, id)  # + skip
        res = self.up1(res, id)
        res = self.to_rgb(res)
        return res, mask
|
||||
@@ -0,0 +1,280 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 3rd April 2022 1:06:31 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x avg-pool downsampling.

    normalize: "in" -> InstanceNorm2d, "bn" -> BatchNorm2d.
    NOTE(review): any other non-empty value builds no norm layers but still
    enters the `if self.normalize:` branches in _residual -> AttributeError.
    The output is divided by sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out   # shortcut needs a 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Both norms act on dim_in features: conv2 (dim_in -> dim_out) is
        # the last op of the residual path.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        # Pre-activation ordering: norm -> act -> conv (twice), with the
        # downsample between the two convolutions.
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    Instance-normalizes x (no affine parameters), then applies a
    per-channel scale and shift predicted from the style vector s by a
    single linear layer.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        # fc emits 2*num_features values: first half gamma, second half beta.
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1, 1)  # broadcast over H, W
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that doubles the spatial resolution with
    nearest-neighbour upsampling on both paths.

    normalize: "in" -> InstanceNorm2d, "bn" -> BatchNorm2d (the norms are
    always applied in _residual; other values would fail there with
    AttributeError). Output divided by sqrt(2) for roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out   # 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        # Upsample first so both paths return the same spatial size.
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x):
        x = self.norm1(x)
        x = self.actv(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        out = self._residual(x)
        out = (out + self._shortcut(x)) / self.equal_var
        return out
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization layers are AdaIN, conditioned on
    a style vector s; optionally doubles the spatial resolution (nearest
    upsampling). Output divided by sqrt(2) for roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out   # 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / self.equal_var
        return out
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Mask-guided encoder/decoder generator.

    The encoder is a stack of downsampling ResBlk's; the code `id` is
    injected in the bottleneck and decoder through AdainResBlk's;
    `maskhead` predicts a soft mask used to blend decoder features with
    the `down2` encoder skip. forward() returns (rgb, mask).
    """

    def __init__(self, **kwargs):
        """Build the network from a config dict.

        Required kwargs: id_dim, g_kernel_size, res_num, in_channel,
        up_mode, norm, aggregator, res_mode. g_kernel_size, up_mode,
        aggregator and res_mode are read but unused in this variant.
        norm must be "in" or "bn"; any other value leaves `norm_out`
        unbound and raises NameError below.
        """
        super().__init__()

        id_dim = kwargs["id_dim"]            # identity/style vector width
        k_size = kwargs["g_kernel_size"]     # unused below
        res_num = kwargs["res_num"]          # number of bottleneck blocks
        in_channel = kwargs["in_channel"]    # base channel width
        up_mode = kwargs["up_mode"]          # unused below
        norm = kwargs["norm"]                # "in" or "bn"

        aggregator = kwargs["aggregator"]    # unused below
        res_mode = kwargs["res_mode"]        # unused below

        padding_size = int((k_size - 1) / 2)  # unused below
        padding_type = 'reflect'              # unused below

        activation = nn.LeakyReLU(0.2)

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)

        ### downsample path; resolution comments assume a 512x512 input
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16

        # Mask head: 3x (nearest-upsample + conv), ending in a sigmoid, so
        # the mask matches the resolution of the `down2` skip tensor.
        self.maskhead = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel), # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel//2), # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Sigmoid()
        )

        ### identity-conditioned bottleneck
        BN = []
        for i in range(res_num):
            BN += [
                AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)]
        self.BottleNeck = nn.Sequential(*BN)

        ### identity-conditioned decoder
        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128
        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)

        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Return (rgb, mask) generated from `img` conditioned on `id`."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)   # saved for the mask-guided blend
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        mask = self.maskhead(res)   # soft mask at the resolution of `skip`
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        # Blend decoder features with the encoder skip under the mask.
        res = (1 - mask) * skip + mask * res
        res = self.up2(res, id)  # + skip
        res = self.up1(res, id)
        res = self.to_rgb(res)
        return res, mask
|
||||
@@ -0,0 +1,279 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 6th April 2022 8:38:50 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x avg-pool downsampling.

    normalize: "in" -> InstanceNorm2d, "bn" -> BatchNorm2d (NOTE: this copy
    defaults to "bn", unlike the sibling definitions which default to "in").
    NOTE(review): any other non-empty value builds no norm layers but still
    enters the `if self.normalize:` branches in _residual -> AttributeError.
    The output is divided by sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="bn", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out   # shortcut needs a 1x1 conv when widths differ
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        # Both norms act on dim_in features: conv2 (dim_in -> dim_out) is
        # the last op of the residual path.
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize.lower() == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        elif self.normalize.lower() == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_in)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        # Pre-activation ordering: norm -> act -> conv (twice), with the
        # downsample between the two convolutions.
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive instance normalization: per-channel scale/shift from a style code."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Project the style vector to (gamma, beta) pairs, one per channel.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that doubles spatial resolution via nearest upsampling.

    Output is (residual + shortcut) / sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out  # 1x1 conv when channels change
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        kind = self.normalize.lower()
        if kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(up) if self.learned_sc else up

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = self.conv1(F.interpolate(h, scale_factor=2, mode='nearest'))
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization is conditioned on a style code.

    Optionally doubles the spatial resolution. Output is
    (residual + shortcut) / sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out  # 1x1 conv when channels change
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h, s))
        return self.conv2(h)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """U-Net style generator conditioned on an identity code.

    forward(img, id) -> (generated image, soft blending mask).
    Configuration arrives via **kwargs: id_dim, g_kernel_size, res_num,
    in_channel, up_mode, norm, aggregator, res_mode.
    """

    def __init__(self, **kwargs):
        super().__init__()

        id_dim      = kwargs["id_dim"]
        k_size      = kwargs["g_kernel_size"]
        res_num     = kwargs["res_num"]
        in_channel  = kwargs["in_channel"]
        up_mode     = kwargs["up_mode"]      # NOTE(review): read but unused here
        norm        = kwargs["norm"]

        aggregator  = kwargs["aggregator"]   # NOTE(review): read but unused here
        res_mode    = kwargs["res_mode"]     # NOTE(review): read but unused here

        padding_size = int((k_size - 1) / 2)  # NOTE(review): unused
        padding_type = 'reflect'              # NOTE(review): unused

        activation = nn.LeakyReLU(0.2)

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)

        ### downsample (comments give feature-map resolution for 512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.sigma = ResBlk(in_channel*2, in_channel*2)  # refines the skip branch
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16

        ### identity-conditioned bottleneck blocks
        BN = [AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)
              for _ in range(res_num)]
        self.BottleNeck = nn.Sequential(*BN)

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128

        # Predicts a 1-channel soft mask used to blend the skip branch.
        self.maskhead = nn.Sequential(
            nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel),  # 64
            activation,
            nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        self.up2 = AdainResBlk(in_channel*2, in_channel, style_dim=id_dim, upsample=True)
        self.up1 = AdainResBlk(in_channel, in_channel, style_dim=id_dim, upsample=True)

        if norm.lower() == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm.lower() == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        else:
            # Fail fast: previously an unknown norm string crashed just below
            # with a confusing NameError on norm_out.
            raise ValueError("unsupported norm: %r" % (norm,))

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Return (generated image, blending mask) for `img` and identity `id`."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # 128-res features kept for blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)

        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        mask = self.maskhead(res)
        # Blend the refined skip branch with the decoded features.
        res = (1 - mask) * self.sigma(skip) + mask * res
        res = self.up2(res, id)
        res = self.up1(res, id)
        res = self.to_rgb(res)
        return res, mask
|
||||
@@ -0,0 +1,283 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 13th April 2022 3:12:53 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x downsampling.

    Output is (shortcut + residual) / sqrt(2) to keep roughly unit variance.
    `normalize` selects "in" (InstanceNorm), "bn" (BatchNorm), or a falsy
    value for no normalization.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="bn", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out  # 1x1 conv needed when channels change
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            kind = self.normalize.lower()
            if kind == "in":
                self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
                self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
            elif kind == "bn":
                self.norm1 = nn.BatchNorm2d(dim_in)
                self.norm2 = nn.BatchNorm2d(dim_in)
            else:
                # Fail fast: previously an unknown norm string left norm1/norm2
                # undefined and only raised AttributeError at the first forward().
                raise ValueError("unsupported normalize: %r" % (self.normalize,))
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)  # still dim_in channels: conv2 expands afterwards
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive instance normalization: per-channel scale/shift from a style code."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Project the style vector to (gamma, beta) pairs, one per channel.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that doubles spatial resolution via nearest upsampling.

    Output is (residual + shortcut) / sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out  # 1x1 conv when channels change
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        kind = self.normalize.lower()
        if kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(up) if self.learned_sc else up

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = self.conv1(F.interpolate(h, scale_factor=2, mode='nearest'))
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization is conditioned on a style code.

    Optionally doubles the spatial resolution. Output is
    (residual + shortcut) / sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out  # 1x1 conv when channels change
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h, s))
        return self.conv2(h)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """U-Net style generator conditioned on an identity code.

    Variant where the last two upsampling stages (up2/up1) are plain
    BatchNorm residual up-blocks instead of AdaIN blocks.
    forward(img, id) -> (generated image, soft blending mask).
    """

    def __init__(self, **kwargs):
        super().__init__()

        id_dim      = kwargs["id_dim"]
        k_size      = kwargs["g_kernel_size"]
        res_num     = kwargs["res_num"]
        in_channel  = kwargs["in_channel"]
        up_mode     = kwargs["up_mode"]      # NOTE(review): read but unused here
        norm        = kwargs["norm"].lower()

        aggregator  = kwargs["aggregator"]   # NOTE(review): read but unused here
        res_mode    = kwargs["res_mode"]     # NOTE(review): read but unused here

        padding_size = int((k_size - 1) / 2)  # NOTE(review): unused
        padding_type = 'reflect'              # NOTE(review): unused

        activation = nn.LeakyReLU(0.2)

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)

        ### downsample (comments give feature-map resolution for 512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16

        ### identity-conditioned bottleneck blocks
        BN = [AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)
              for _ in range(res_num)]
        self.BottleNeck = nn.Sequential(*BN)

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128

        # Predicts a 1-channel soft mask used to blend the skip branch.
        self.maskhead = nn.Sequential(
            nn.Conv2d(in_channel*2, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel),  # 64
            activation,
            nn.Conv2d(in_channel, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        # Identity-free upsampling tail (deliberately BatchNorm, not `norm`).
        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize="bn")
        self.up1 = ResUpBlk(in_channel, in_channel, normalize="bn")

        if norm == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
        elif norm == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
        else:
            # Fail fast: previously an unknown norm string crashed just below
            # with a confusing NameError on norm_out.
            raise ValueError("unsupported norm: %r" % (norm,))

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Return (generated image, blending mask) for `img` and identity `id`."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # 128-res features kept for blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)

        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)
        mask = self.maskhead(res)
        # Blend the encoder skip with the decoded features.
        res = (1 - mask) * skip + mask * res
        res = self.up2(res)
        res = self.up1(res)
        res = self.to_rgb(res)
        return res, mask
|
||||
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator_Invobn_config1.py
|
||||
# Created Date: Saturday February 26th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 13th April 2022 6:30:26 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x downsampling.

    Output is (shortcut + residual) / sqrt(2) to keep roughly unit variance.
    `normalize` selects "in" (InstanceNorm), "bn" (BatchNorm), or a falsy
    value for no normalization.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out  # 1x1 conv needed when channels change
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            kind = self.normalize.lower()
            if kind == "in":
                self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
                self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
            elif kind == "bn":
                self.norm1 = nn.BatchNorm2d(dim_in)
                self.norm2 = nn.BatchNorm2d(dim_in)
            else:
                # Fail fast: previously an unknown norm string left norm1/norm2
                # undefined and only raised AttributeError at the first forward().
                raise ValueError("unsupported normalize: %r" % (self.normalize,))
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)  # still dim_in channels: conv2 expands afterwards
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive instance normalization: per-channel scale/shift from a style code."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # Project the style vector to (gamma, beta) pairs, one per channel.
        params = self.fc(s).view(s.size(0), -1, 1, 1)
        gamma, beta = params.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
|
||||
|
||||
class ResUpBlk(nn.Module):
    """Residual block that doubles spatial resolution via nearest upsampling.

    Output is (residual + shortcut) / sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize="in"):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.learned_sc = dim_in != dim_out  # 1x1 conv when channels change
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        kind = self.normalize.lower()
        if kind == "in":
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_out, affine=True)
        elif kind == "bn":
            self.norm1 = nn.BatchNorm2d(dim_in)
            self.norm2 = nn.BatchNorm2d(dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(up) if self.learned_sc else up

    def _residual(self, x):
        h = self.actv(self.norm1(x))
        h = self.conv1(F.interpolate(h, scale_factor=2, mode='nearest'))
        h = self.actv(self.norm2(h))
        return self.conv2(h)

    def forward(self, x):
        return (self._residual(x) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
class AdainResBlk(nn.Module):
    """Residual block whose normalization is conditioned on a style code.

    Optionally doubles the spatial resolution. Output is
    (residual + shortcut) / sqrt(2) to keep roughly unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out  # 1x1 conv when channels change
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv1x1(x) if self.learned_sc else x

    def _residual(self, x, s):
        h = self.actv(self.norm1(x, s))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2, mode='nearest')
        h = self.conv1(h)
        h = self.actv(self.norm2(h, s))
        return self.conv2(h)

    def forward(self, x, s):
        return (self._residual(x, s) + self._shortcut(x)) / self.equal_var
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """U-Net style generator conditioned on an identity code.

    Variant where the mask head runs directly on the 16x16 bottleneck
    features and upsamples 8x internally, and the tail up-blocks (up2/up1)
    are identity-free residual up-blocks.
    forward(img, id) -> (generated image, soft blending mask).
    """

    def __init__(self, **kwargs):
        super().__init__()

        id_dim      = kwargs["id_dim"]
        k_size      = kwargs["g_kernel_size"]
        res_num     = kwargs["res_num"]
        in_channel  = kwargs["in_channel"]
        up_mode     = kwargs["up_mode"]      # NOTE(review): read but unused here
        norm        = kwargs["norm"].lower()

        aggregator  = kwargs["aggregator"]   # NOTE(review): read but unused here
        res_mode    = kwargs["res_mode"]     # NOTE(review): read but unused here

        padding_size = int((k_size - 1) / 2)  # NOTE(review): unused
        padding_type = 'reflect'              # NOTE(review): unused

        if norm == "in":
            norm_out = nn.InstanceNorm2d(in_channel, affine=True)
            norm_mask = nn.InstanceNorm2d
        elif norm == "bn":
            norm_out = nn.BatchNorm2d(in_channel)
            norm_mask = nn.BatchNorm2d
        else:
            # Fail fast: previously an unknown norm string crashed later with
            # confusing NameErrors on norm_out / norm_mask.
            raise ValueError("unsupported norm: %r" % (norm,))

        activation = nn.LeakyReLU(0.2)

        self.from_rgb = nn.Conv2d(3, in_channel, 1, 1, 0)

        ### downsample (comments give feature-map resolution for 512 input)
        self.down1 = ResBlk(in_channel, in_channel, normalize=norm, downsample=True)      # 256
        self.down2 = ResBlk(in_channel, in_channel*2, normalize=norm, downsample=True)    # 128
        self.down3 = ResBlk(in_channel*2, in_channel*4, normalize=norm, downsample=True)  # 64
        self.down4 = ResBlk(in_channel*4, in_channel*8, normalize=norm, downsample=True)  # 32
        self.down5 = ResBlk(in_channel*8, in_channel*8, normalize=norm, downsample=True)  # 16

        ### identity-conditioned bottleneck blocks
        BN = [AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=False)
              for _ in range(res_num)]
        self.BottleNeck = nn.Sequential(*BN)

        self.up5 = AdainResBlk(in_channel*8, in_channel*8, style_dim=id_dim, upsample=True)  # 32
        self.up4 = AdainResBlk(in_channel*8, in_channel*4, style_dim=id_dim, upsample=True)  # 64
        self.up3 = AdainResBlk(in_channel*4, in_channel*2, style_dim=id_dim, upsample=True)  # 128

        # Mask head: 8x nearest upsampling from the 16x16 bottleneck to the
        # 128-res skip resolution, ending in a 1-channel sigmoid mask.
        self.maskhead = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel*8, in_channel, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel, affine=True),  # 32
            activation,
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel, in_channel//4, kernel_size=3, stride=1, padding=1, bias=False),
            norm_mask(in_channel//4, affine=True),  # 64
            activation,
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_channel//4, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

        # Identity-free upsampling tail.
        self.up2 = ResUpBlk(in_channel*2, in_channel, normalize=norm)
        self.up1 = ResUpBlk(in_channel, in_channel, normalize=norm)

        self.to_rgb = nn.Sequential(
            norm_out,
            activation,
            nn.Conv2d(in_channel, 3, 3, 1, 1))

    def forward(self, img, id):
        """Return (generated image, blending mask) for `img` and identity `id`."""
        res = self.from_rgb(img)
        res = self.down1(res)
        skip = self.down2(res)  # 128-res features kept for blending
        res = self.down3(skip)
        res = self.down4(res)
        res = self.down5(res)
        # Predict the blending mask from the bottleneck features
        # (maskhead upsamples it to the skip resolution).
        mask = self.maskhead(res)
        for i in range(len(self.BottleNeck)):
            res = self.BottleNeck[i](res, id)
        res = self.up5(res, id)
        res = self.up4(res, id)
        res = self.up3(res, id)

        # Blend the encoder skip with the decoded features.
        res = (1 - mask) * skip + mask * res
        res = self.up2(res)
        res = self.up1(res)
        res = self.to_rgb(res)
        return res, mask
|
||||
+96
-24
@@ -5,43 +5,115 @@
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 13th February 2022 2:03:21 am
|
||||
# Last Modified: Monday, 28th March 2022 11:47:55 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
# class LSTU(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# in_channel,
|
||||
# out_channel,
|
||||
# latent_channel,
|
||||
# scale = 4
|
||||
# ):
|
||||
# super().__init__()
|
||||
# sig = nn.Sigmoid()
|
||||
# self.relu = nn.ReLU(True)
|
||||
|
||||
# self.up_sample = nn.Sequential(nn.Conv2d(latent_channel, out_channel/4, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(out_channel/4),
|
||||
# self.relu,
|
||||
# nn.Conv2d(latent_channel/4, out_channel, kernel_size=3, stride=1, padding=1),
|
||||
# )
|
||||
|
||||
# self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
# self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
# self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True))
|
||||
|
||||
# def forward(self, encoder_in, bottleneck_in):
|
||||
# h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel`
|
||||
# h_bar_l = self.conv11(h_hat_l_1)
|
||||
# f_l = self.forget_gate(h_hat_l_1)
|
||||
# r_l = self.reset_gate (h_hat_l_1)
|
||||
# h_hat_l = (1-f_l)*h_bar_l + f_l* encoder_in
|
||||
# x_hat_l = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1
|
||||
# return x_hat_l
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
    """Pre-activation residual block with optional 2x downsampling.

    Output is (shortcut + residual) / sqrt(2) to keep roughly unit variance.
    `normalize` selects "in" (InstanceNorm), "bn" (BatchNorm), or a falsy
    value for no normalization.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize="in", downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out  # 1x1 conv needed when channels change
        self.equal_var = math.sqrt(2)
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            kind = self.normalize.lower()
            if kind == "in":
                self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
                self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
            elif kind == "bn":
                self.norm1 = nn.BatchNorm2d(dim_in)
                self.norm2 = nn.BatchNorm2d(dim_in)
            else:
                # Fail fast: previously an unknown norm string left norm1/norm2
                # undefined and only raised AttributeError at the first forward().
                raise ValueError("unsupported normalize: %r" % (self.normalize,))
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)  # still dim_in channels: conv2 expands afterwards
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / self.equal_var  # unit variance
|
||||
|
||||
class LSTU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
latent_channel,
|
||||
scale = 4
|
||||
norm
|
||||
):
|
||||
super().__init__()
|
||||
sig = nn.Sigmoid()
|
||||
self.relu = nn.ReLU(True)
|
||||
self.sig = nn.Sigmoid()
|
||||
|
||||
self.up_sample = nn.Sequential(nn.ConvTranspose2d(latent_channel, out_channel, kernel_size=4, stride=scale, padding=0, bias=False),
|
||||
nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True))
|
||||
self.mask_head = ResBlk(in_channel, 1, normalize=norm)
|
||||
# self.forget_gate = ResBlk(in_channel,in_channel, normalize=norm)
|
||||
|
||||
def forward(self, encoder_in, bottleneck_in):
|
||||
h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel`
|
||||
h_bar_l = self.conv11(h_hat_l_1)
|
||||
f_l = self.forget_gate(h_hat_l_1)
|
||||
r_l = self.reset_gate (h_hat_l_1)
|
||||
h_hat_l = (1-f_l)*h_bar_l + f_l* encoder_in
|
||||
x_hat_l = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1
|
||||
return x_hat_l
|
||||
def forward(self, encoder_in, decoder_in):
|
||||
mask = self.sig(self.mask_head(decoder_in)) # upsample and make `channel` identical to `out_channel`
|
||||
# enc_feat= self.forget_gate(encoder_in)
|
||||
out = (1-mask)*encoder_in + mask * decoder_in
|
||||
return out, mask
|
||||
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator.py
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 29th March 2022 12:20:26 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
# class LSTU(nn.Module):
|
||||
# def __init__(
|
||||
# self,
|
||||
# in_channel,
|
||||
# out_channel,
|
||||
# latent_channel,
|
||||
# scale = 4
|
||||
# ):
|
||||
# super().__init__()
|
||||
# sig = nn.Sigmoid()
|
||||
# self.relu = nn.ReLU(True)
|
||||
|
||||
# self.up_sample = nn.Sequential(nn.Conv2d(latent_channel, out_channel/4, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(out_channel/4),
|
||||
# self.relu,
|
||||
# nn.Conv2d(latent_channel/4, out_channel, kernel_size=3, stride=1, padding=1),
|
||||
# )
|
||||
|
||||
# self.forget_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
# self.reset_gate = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False),
|
||||
# nn.BatchNorm2d(out_channel), sig)
|
||||
|
||||
# self.conv11 = nn.Sequential(nn.Conv2d(out_channel, out_channel, kernel_size=1, bias=True))
|
||||
|
||||
# def forward(self, encoder_in, bottleneck_in):
|
||||
# h_hat_l_1 = self.up_sample(bottleneck_in) # upsample and make `channel` identical to `out_channel`
|
||||
# h_bar_l = self.conv11(h_hat_l_1)
|
||||
# f_l = self.forget_gate(h_hat_l_1)
|
||||
# r_l = self.reset_gate (h_hat_l_1)
|
||||
# h_hat_l = (1-f_l)*h_bar_l + f_l* encoder_in
|
||||
# x_hat_l = r_l* self.relu(h_hat_l) + (1-r_l)* h_hat_l_1
|
||||
# return x_hat_l
|
||||
|
||||
|
||||
class ResBlk(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
|
||||
normalize="in", downsample=False):
|
||||
super().__init__()
|
||||
self.actv = actv
|
||||
self.normalize = normalize
|
||||
self.downsample = downsample
|
||||
self.learned_sc = dim_in != dim_out
|
||||
self.equal_var = math.sqrt(2)
|
||||
self._build_weights(dim_in, dim_out)
|
||||
|
||||
def _build_weights(self, dim_in, dim_out):
|
||||
self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
|
||||
if self.normalize.lower() == "in":
|
||||
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
elif self.normalize.lower() == "bn":
|
||||
self.norm1 = nn.BatchNorm2d(dim_in)
|
||||
self.norm2 = nn.BatchNorm2d(dim_in)
|
||||
|
||||
if self.learned_sc:
|
||||
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
|
||||
|
||||
def _shortcut(self, x):
|
||||
if self.learned_sc:
|
||||
x = self.conv1x1(x)
|
||||
if self.downsample:
|
||||
x = F.avg_pool2d(x, 2)
|
||||
return x
|
||||
|
||||
def _residual(self, x):
|
||||
if self.normalize:
|
||||
x = self.norm1(x)
|
||||
x = self.actv(x)
|
||||
x = self.conv1(x)
|
||||
if self.downsample:
|
||||
x = F.avg_pool2d(x, 2)
|
||||
if self.normalize:
|
||||
x = self.norm2(x)
|
||||
x = self.actv(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self._shortcut(x) + self._residual(x)
|
||||
return x /self.equal_var # unit variance
|
||||
|
||||
class LSTU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
norm
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# self.mask_head = ResBlk(in_channel, 1, normalize=norm)
|
||||
self.mask_head = nn.Sequential(nn.Conv2d(in_channel, in_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(in_channel//2),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Conv2d(in_channel//2, 1, kernel_size=3, stride=1, padding=1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
# self.forget_gate = ResBlk(in_channel,in_channel, normalize=norm)
|
||||
|
||||
def forward(self, encoder_in, decoder_in):
|
||||
mask = self.mask_head(decoder_in) # upsample and make `channel` identical to `out_channel`
|
||||
# enc_feat= self.forget_gate(encoder_in)
|
||||
out = (1-mask)*encoder_in + mask * decoder_in
|
||||
return out, mask
|
||||
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: ModulatedDWConv.py
|
||||
# Created Date: Monday April 18th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 18th April 2022 10:33:48 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Modified from: https://github.com/bes-dev/MobileStyleGAN.pytorch
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class ModulatedDWConv2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels_in,
|
||||
channels_out,
|
||||
style_dim,
|
||||
kernel_size,
|
||||
demodulate=True
|
||||
):
|
||||
super().__init__()
|
||||
# create conv
|
||||
self.weight_dw = nn.Parameter(
|
||||
torch.randn(channels_in, 1, kernel_size, kernel_size)
|
||||
)
|
||||
self.weight_permute = nn.Parameter(
|
||||
torch.randn(channels_out, channels_in, 1, 1)
|
||||
)
|
||||
# create modulation network
|
||||
self.modulation = nn.Linear(style_dim, channels_in, bias=True)
|
||||
self.modulation.bias.data.fill_(1.0)
|
||||
# create demodulation parameters
|
||||
self.demodulate = demodulate
|
||||
if self.demodulate:
|
||||
self.register_buffer("style_inv", torch.randn(1, 1, channels_in, 1, 1))
|
||||
# some service staff
|
||||
self.scale = 1.0 / math.sqrt(channels_in * kernel_size ** 2)
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
def forward(self, x, style):
|
||||
modulation = self.get_modulation(style)
|
||||
x = modulation * x
|
||||
x = F.conv2d(x, self.weight_dw, padding=self.padding, groups=x.size(1))
|
||||
x = F.conv2d(x, self.weight_permute)
|
||||
if self.demodulate:
|
||||
demodulation = self.get_demodulation(style)
|
||||
x = demodulation * x
|
||||
return x
|
||||
|
||||
def get_modulation(self, style):
|
||||
style = self.modulation(style).view(style.size(0), -1, 1, 1)
|
||||
modulation = self.scale * style
|
||||
return modulation
|
||||
|
||||
def get_demodulation(self, style):
|
||||
w = (self.weight_dw.transpose(0, 1) * self.weight_permute).unsqueeze(0)
|
||||
norm = torch.rsqrt((self.scale * self.style_inv * w).pow(2).sum([2, 3, 4]) + 1e-8)
|
||||
demodulation = norm
|
||||
return demodulation.view(*demodulation.size(), 1, 1)
|
||||
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Nonstau_Discriminator.py
|
||||
# Created Date: Monday March 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 28th March 2022 10:03:56 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class ResBlk(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
|
||||
normalize="in", downsample=False):
|
||||
super().__init__()
|
||||
self.actv = actv
|
||||
self.normalize = normalize
|
||||
self.downsample = downsample
|
||||
self.learned_sc = dim_in != dim_out
|
||||
self.var = math.sqrt(2)
|
||||
self._build_weights(dim_in, dim_out)
|
||||
|
||||
def _build_weights(self, dim_in, dim_out):
|
||||
self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
|
||||
if self.normalize.lower() == "in":
|
||||
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
elif self.normalize.lower() == "bn":
|
||||
self.norm1 = nn.BatchNorm2d(dim_in)
|
||||
self.norm2 = nn.BatchNorm2d(dim_in)
|
||||
elif self.normalize.lower() == "none":
|
||||
self.normalize = False
|
||||
if self.learned_sc:
|
||||
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
|
||||
|
||||
def _shortcut(self, x):
|
||||
if self.learned_sc:
|
||||
x = self.conv1x1(x)
|
||||
if self.downsample:
|
||||
x = F.avg_pool2d(x, 2)
|
||||
return x
|
||||
|
||||
def _residual(self, x):
|
||||
if self.normalize:
|
||||
x = self.norm1(x)
|
||||
x = self.actv(x)
|
||||
x = self.conv1(x)
|
||||
if self.downsample:
|
||||
x = F.avg_pool2d(x, 2)
|
||||
if self.normalize:
|
||||
x = self.norm2(x)
|
||||
x = self.actv(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self._shortcut(x) + self._residual(x)
|
||||
return x / self.var # unit variance
|
||||
|
||||
class Discriminator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
img_size = kwargs["img_size"]
|
||||
num_domains = 1
|
||||
max_conv_dim = kwargs["max_conv_dim"]
|
||||
norm = kwargs["norm"]
|
||||
dim_in = 2**14 // img_size
|
||||
blocks = []
|
||||
blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
|
||||
|
||||
repeat_num = int(np.log2(img_size)) - 2
|
||||
for _ in range(repeat_num):
|
||||
dim_out = min(dim_in*2, max_conv_dim)
|
||||
blocks += [ResBlk(dim_in, dim_out, normalize=norm, downsample=True)]
|
||||
dim_in = dim_out
|
||||
|
||||
blocks += [nn.LeakyReLU(0.2)]
|
||||
blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
|
||||
blocks += [nn.LeakyReLU(0.2)]
|
||||
blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)]
|
||||
self.main = nn.Sequential(*blocks)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.main(x)
|
||||
out = out.view(out.size(0), -1) # (batch, num_domains)
|
||||
return out
|
||||
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Nonstau_Discriminator.py
|
||||
# Created Date: Monday March 28th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 13th April 2022 3:11:40 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class ResBlk(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
|
||||
normalize="in", downsample=False):
|
||||
super().__init__()
|
||||
self.actv = actv
|
||||
self.normalize = normalize
|
||||
self.downsample = downsample
|
||||
self.learned_sc = dim_in != dim_out
|
||||
self.var = math.sqrt(2)
|
||||
self._build_weights(dim_in, dim_out)
|
||||
|
||||
def _build_weights(self, dim_in, dim_out):
|
||||
self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
|
||||
if self.normalize.lower() == "in":
|
||||
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
|
||||
elif self.normalize.lower() == "bn":
|
||||
self.norm1 = nn.BatchNorm2d(dim_in)
|
||||
self.norm2 = nn.BatchNorm2d(dim_in)
|
||||
elif self.normalize.lower() == "none":
|
||||
self.normalize = False
|
||||
if self.learned_sc:
|
||||
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
|
||||
|
||||
def _shortcut(self, x):
|
||||
if self.learned_sc:
|
||||
x = self.conv1x1(x)
|
||||
if self.downsample:
|
||||
x = F.avg_pool2d(x, 2)
|
||||
return x
|
||||
|
||||
def _residual(self, x):
|
||||
if self.normalize:
|
||||
x = self.norm1(x)
|
||||
x = self.actv(x)
|
||||
x = self.conv1(x)
|
||||
if self.downsample:
|
||||
x = F.avg_pool2d(x, 2)
|
||||
if self.normalize:
|
||||
x = self.norm2(x)
|
||||
x = self.actv(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self._shortcut(x) + self._residual(x)
|
||||
return x / self.var # unit variance
|
||||
|
||||
class Discriminator(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
img_size = kwargs["img_size"]
|
||||
num_domains = 1
|
||||
max_conv_dim = kwargs["max_conv_dim"]
|
||||
norm = kwargs["norm"].lower()
|
||||
dim_in = 2**14 // img_size
|
||||
blocks = []
|
||||
blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
|
||||
|
||||
repeat_num = int(np.log2(img_size)) - 2
|
||||
for _ in range(repeat_num-2):
|
||||
dim_out = min(dim_in*2, max_conv_dim)
|
||||
blocks += [ResBlk(dim_in, dim_out, normalize=norm, downsample=True)]
|
||||
dim_in = dim_out
|
||||
blocks1 = []
|
||||
for _ in range(2): # 16
|
||||
dim_out = min(dim_in*2, max_conv_dim)
|
||||
blocks1 += [ResBlk(dim_in, dim_out, normalize=norm, downsample=True)]
|
||||
dim_in = dim_out
|
||||
|
||||
blocks1 += [nn.LeakyReLU(0.2)]
|
||||
blocks1 += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
|
||||
blocks1 += [nn.LeakyReLU(0.2)]
|
||||
blocks1 += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)]
|
||||
self.main = nn.Sequential(*blocks)
|
||||
self.tail = nn.Sequential(*blocks1)
|
||||
|
||||
def get_feature(self,x):
|
||||
mid = self.main(x)
|
||||
return mid
|
||||
|
||||
def forward(self, x):
|
||||
mid = self.main(x)
|
||||
out = self.tail(mid)
|
||||
out = out.view(out.size(0), -1) # (batch, num_domains)
|
||||
return out,mid
|
||||
@@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: data_loader_VGGFace2HQ copy.py
|
||||
# Created Date: Sunday February 6th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 15th February 2022 1:50:19 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torch.utils import data
|
||||
from torchvision import transforms as T
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
class InfiniteSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
||||
assert len(dataset) > 0
|
||||
assert num_replicas > 0
|
||||
assert 0 <= rank < num_replicas
|
||||
assert 0 <= window_size <= 1
|
||||
super().__init__(dataset)
|
||||
self.dataset = dataset
|
||||
self.rank = rank
|
||||
self.num_replicas = num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self.window_size = window_size
|
||||
|
||||
def __iter__(self):
|
||||
order = np.arange(len(self.dataset))
|
||||
rnd = None
|
||||
window = 0
|
||||
if self.shuffle:
|
||||
rnd = np.random.RandomState(self.seed)
|
||||
rnd.shuffle(order)
|
||||
window = int(np.rint(order.size * self.window_size))
|
||||
|
||||
idx = 0
|
||||
while True:
|
||||
i = idx % order.size
|
||||
if idx % self.num_replicas == self.rank:
|
||||
yield order[i]
|
||||
if window >= 2:
|
||||
j = (i - rnd.randint(window)) % order.size
|
||||
order[i], order[j] = order[j], order[i]
|
||||
idx += 1
|
||||
|
||||
class data_prefetcher():
|
||||
def __init__(self, loader, cur_gpu):
|
||||
torch.cuda.set_device(cur_gpu) # must add this line to avoid excessive use of GPU 0 by the prefetcher
|
||||
self.loader = loader
|
||||
self.dataiter = iter(loader)
|
||||
self.stream = torch.cuda.Stream(device=cur_gpu)
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1)
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1)
|
||||
self.cur_gpu = cur_gpu
|
||||
# With Amp, it isn't necessary to manually convert data to half.
|
||||
# if args.fp16:
|
||||
# self.mean = self.mean.half()
|
||||
# self.std = self.std.half()
|
||||
# self.num_images = loader.__len__()
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
# try:
|
||||
self.src_image1, self.src_image2 = next(self.dataiter)
|
||||
# except StopIteration:
|
||||
# self.dataiter = iter(self.loader)
|
||||
# self.src_image1, self.src_image2 = next(self.dataiter)
|
||||
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True)
|
||||
self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
|
||||
self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True)
|
||||
self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
|
||||
# With Amp, it isn't necessary to manually convert data to half.
|
||||
# if args.fp16:
|
||||
# self.next_input = self.next_input.half()
|
||||
# else:
|
||||
# self.next_input = self.next_input.float()
|
||||
# self.next_input = self.next_input.sub_(self.mean).div_(self.std)
|
||||
def next(self):
|
||||
torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream)
|
||||
src_image1 = self.src_image1
|
||||
src_image2 = self.src_image2
|
||||
self.preload()
|
||||
return src_image1, src_image2
|
||||
|
||||
# def __len__(self):
|
||||
# """Return the number of images."""
|
||||
# return self.num_images
|
||||
|
||||
class VGGFace2HQDataset(data.Dataset):
|
||||
"""Dataset class for the Artworks dataset and content dataset."""
|
||||
|
||||
def __init__(self,
|
||||
image_dir,
|
||||
img_transform,
|
||||
subffix='jpg',
|
||||
random_seed=1234):
|
||||
"""Initialize and preprocess the VGGFace2 HQ dataset."""
|
||||
self.image_dir = image_dir
|
||||
self.img_transform = img_transform
|
||||
self.subffix = subffix
|
||||
self.dataset = []
|
||||
self.random_seed = random_seed
|
||||
self.preprocess()
|
||||
self.num_images = len(self.dataset)
|
||||
|
||||
def preprocess(self):
|
||||
"""Preprocess the VGGFace2 HQ dataset."""
|
||||
print("processing VGGFace2 HQ dataset images...")
|
||||
|
||||
temp_path = os.path.join(self.image_dir,'*/')
|
||||
pathes = glob.glob(temp_path)
|
||||
self.dataset = []
|
||||
for dir_item in pathes:
|
||||
join_path = glob.glob(os.path.join(dir_item,'*.jpg'))
|
||||
print("processing %s"%dir_item,end='\r')
|
||||
temp_list = []
|
||||
for item in join_path:
|
||||
temp_list.append(item)
|
||||
self.dataset.append(temp_list)
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.dataset)
|
||||
print('Finished preprocessing the VGGFace2 HQ dataset, total dirs number: %d...'%len(self.dataset))
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Return two src domain images and two dst domain images."""
|
||||
dir_tmp1 = self.dataset[index]
|
||||
dir_tmp1_len = len(dir_tmp1)
|
||||
|
||||
filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
|
||||
filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
|
||||
image1 = self.img_transform(Image.open(filename1))
|
||||
image2 = self.img_transform(Image.open(filename2))
|
||||
return image1, image2
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of images."""
|
||||
return self.num_images
|
||||
|
||||
def GetLoader( dataset_roots,
|
||||
rank,
|
||||
num_gpus,
|
||||
batch_size=16,
|
||||
**kwargs
|
||||
):
|
||||
"""Build and return a data loader."""
|
||||
|
||||
data_root = dataset_roots
|
||||
random_seed = kwargs["random_seed"]
|
||||
num_workers = kwargs["dataloader_workers"]
|
||||
|
||||
c_transforms = []
|
||||
|
||||
c_transforms.append(T.ToTensor())
|
||||
c_transforms = T.Compose(c_transforms)
|
||||
|
||||
content_dataset = VGGFace2HQDataset(
|
||||
data_root,
|
||||
c_transforms,
|
||||
"jpg",
|
||||
random_seed)
|
||||
device = torch.device('cuda', rank)
|
||||
sampler = InfiniteSampler(dataset=content_dataset, rank=rank, num_replicas=num_gpus, seed=random_seed)
|
||||
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
|
||||
drop_last=False,shuffle=False,num_workers=num_workers,pin_memory=True, sampler=sampler)
|
||||
# content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
|
||||
# drop_last=False,shuffle=True,num_workers=num_workers,pin_memory=True)
|
||||
prefetcher = data_prefetcher(content_data_loader,device)
|
||||
return prefetcher
|
||||
|
||||
def denorm(x):
|
||||
out = (x + 1) / 2
|
||||
return out.clamp_(0, 1)
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday February 6th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 15th February 2022 1:50:19 am
|
||||
# Last Modified: Wednesday, 6th April 2022 12:53:53 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -108,7 +108,7 @@ class VGGFace2HQDataset(data.Dataset):
|
||||
subffix='jpg',
|
||||
random_seed=1234):
|
||||
"""Initialize and preprocess the VGGFace2 HQ dataset."""
|
||||
self.image_dir = image_dir
|
||||
self.image_dir = image_dir["images"]
|
||||
self.img_transform = img_transform
|
||||
self.subffix = subffix
|
||||
self.dataset = []
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: data_loader_VGGFace2HQ copy.py
|
||||
# Created Date: Sunday February 6th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 3rd April 2022 9:48:23 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torch.utils import data
|
||||
from torchvision import transforms as T
|
||||
import cv2
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
class InfiniteSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
||||
assert len(dataset) > 0
|
||||
assert num_replicas > 0
|
||||
assert 0 <= rank < num_replicas
|
||||
assert 0 <= window_size <= 1
|
||||
super().__init__(dataset)
|
||||
self.dataset = dataset
|
||||
self.rank = rank
|
||||
self.num_replicas = num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self.window_size = window_size
|
||||
|
||||
def __iter__(self):
|
||||
order = np.arange(len(self.dataset))
|
||||
rnd = None
|
||||
window = 0
|
||||
if self.shuffle:
|
||||
rnd = np.random.RandomState(self.seed)
|
||||
rnd.shuffle(order)
|
||||
window = int(np.rint(order.size * self.window_size))
|
||||
|
||||
idx = 0
|
||||
while True:
|
||||
i = idx % order.size
|
||||
if idx % self.num_replicas == self.rank:
|
||||
yield order[i]
|
||||
if window >= 2:
|
||||
j = (i - rnd.randint(window)) % order.size
|
||||
order[i], order[j] = order[j], order[i]
|
||||
idx += 1
|
||||
|
||||
class data_prefetcher():
|
||||
def __init__(self, loader, cur_gpu):
|
||||
torch.cuda.set_device(cur_gpu) # must add this line to avoid excessive use of GPU 0 by the prefetcher
|
||||
self.loader = loader
|
||||
self.dataiter = iter(loader)
|
||||
self.stream = torch.cuda.Stream(device=cur_gpu)
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1)
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1)
|
||||
self.cur_gpu = cur_gpu
|
||||
# With Amp, it isn't necessary to manually convert data to half.
|
||||
# if args.fp16:
|
||||
# self.mean = self.mean.half()
|
||||
# self.std = self.std.half()
|
||||
# self.num_images = loader.__len__()
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
# try:
|
||||
self.src_image1, self.src_image2, self.mask = next(self.dataiter)
|
||||
# except StopIteration:
|
||||
# self.dataiter = iter(self.loader)
|
||||
# self.src_image1, self.src_image2 = next(self.dataiter)
|
||||
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True)
|
||||
self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
|
||||
self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True)
|
||||
self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
|
||||
self.mask = self.mask.cuda(device= self.cur_gpu, non_blocking=True)
|
||||
self.mask = self.mask/255.0
|
||||
# With Amp, it isn't necessary to manually convert data to half.
|
||||
# if args.fp16:
|
||||
# self.next_input = self.next_input.half()
|
||||
# else:
|
||||
# self.next_input = self.next_input.float()
|
||||
# self.next_input = self.next_input.sub_(self.mean).div_(self.std)
|
||||
def next(self):
|
||||
torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream)
|
||||
src_image1 = self.src_image1
|
||||
src_image2 = self.src_image2
|
||||
mask = self.mask
|
||||
self.preload()
|
||||
return src_image1, src_image2, mask
|
||||
|
||||
# def __len__(self):
|
||||
# """Return the number of images."""
|
||||
# return self.num_images
|
||||
|
||||
class VGGFace2HQDataset(data.Dataset):
|
||||
"""Dataset class for the Artworks dataset and content dataset."""
|
||||
|
||||
def __init__(self,
|
||||
image_dir,
|
||||
mask_dir,
|
||||
img_transform,
|
||||
subffix='jpg',
|
||||
random_seed=1234):
|
||||
"""Initialize and preprocess the VGGFace2 HQ dataset."""
|
||||
self.image_dir = image_dir
|
||||
self.mask_dir = mask_dir
|
||||
self.img_transform = img_transform
|
||||
self.subffix = subffix
|
||||
self.dataset = []
|
||||
self.random_seed = random_seed
|
||||
self.preprocess()
|
||||
self.num_images = len(self.dataset)
|
||||
|
||||
def preprocess(self):
|
||||
"""Preprocess the VGGFace2 HQ dataset."""
|
||||
print("processing VGGFace2 HQ dataset images...")
|
||||
|
||||
temp_path = os.path.join(self.image_dir,'*/')
|
||||
pathes = glob.glob(temp_path)
|
||||
self.dataset = []
|
||||
for dir_item in pathes:
|
||||
join_path = glob.glob(os.path.join(dir_item,'*.jpg'))
|
||||
print("processing %s"%dir_item,end='\r')
|
||||
dir_path = os.path.dirname(join_path[1])
|
||||
dir_name = os.path.join(self.mask_dir, os.path.basename(dir_path))
|
||||
# print(dir_name)
|
||||
temp_list = []
|
||||
for item in join_path:
|
||||
img_name = os.path.basename(item)
|
||||
img_name, _ = os.path.splitext(img_name)
|
||||
temp_list.append({
|
||||
"i":item,
|
||||
"m":os.path.join(dir_name, img_name + ".png")
|
||||
})
|
||||
self.dataset.append(temp_list)
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.dataset)
|
||||
print('Finished preprocessing the VGGFace2 HQ dataset, total dirs number: %d...'%len(self.dataset))
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Return two src domain images and two dst domain images."""
|
||||
dir_tmp1 = self.dataset[index]
|
||||
dir_tmp1_len = len(dir_tmp1)
|
||||
|
||||
filename1 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
|
||||
filename2 = dir_tmp1[random.randint(0,dir_tmp1_len-1)]
|
||||
image1 = self.img_transform(Image.open(filename1["i"]))
|
||||
image2 = self.img_transform(Image.open(filename2["i"]))
|
||||
mask = torch.from_numpy(cv2.imread(filename1["m"],0)).unsqueeze(0)
|
||||
return image1, image2, mask
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of images."""
|
||||
return self.num_images
|
||||
|
||||
def GetLoader( dataset_roots,
|
||||
rank,
|
||||
num_gpus,
|
||||
batch_size=16,
|
||||
**kwargs
|
||||
):
|
||||
"""Build and return a data loader."""
|
||||
|
||||
data_root = dataset_roots["images"]
|
||||
mask_root = dataset_roots["masks"]
|
||||
random_seed = kwargs["random_seed"]
|
||||
num_workers = kwargs["dataloader_workers"]
|
||||
|
||||
c_transforms = []
|
||||
|
||||
c_transforms.append(T.ToTensor())
|
||||
c_transforms = T.Compose(c_transforms)
|
||||
|
||||
content_dataset = VGGFace2HQDataset(
|
||||
data_root,
|
||||
mask_root,
|
||||
c_transforms,
|
||||
"jpg",
|
||||
random_seed)
|
||||
device = torch.device('cuda', rank)
|
||||
sampler = InfiniteSampler(dataset=content_dataset, rank=rank, num_replicas=num_gpus, seed=random_seed)
|
||||
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
|
||||
drop_last=False,shuffle=False,num_workers=num_workers,pin_memory=True, sampler=sampler)
|
||||
# content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
|
||||
# drop_last=False,shuffle=True,num_workers=num_workers,pin_memory=True)
|
||||
prefetcher = data_prefetcher(content_data_loader,device)
|
||||
return prefetcher
|
||||
|
||||
def denorm(x):
|
||||
out = (x + 1) / 2
|
||||
return out.clamp_(0, 1)
|
||||
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: dataset.check.py
|
||||
# Created Date: Sunday April 3rd 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 3rd April 2022 2:57:48 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import glob
|
||||
from utilities.json_config import readConfig, writeConfig
|
||||
|
||||
# dataset = "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan"
|
||||
# mask_dir= "G:/VGGFace2-HQ/VGGface2_HQ_original_aligned_mask"
|
||||
|
||||
savePath = "./vggface2hq_failed.txt"
|
||||
env_config = readConfig('env/env.json')
|
||||
env_config = env_config["path"]
|
||||
dataset = env_config["dataset_paths"]["vggface2_hq"]["images"]
|
||||
mask_dir = env_config["dataset_paths"]["vggface2_hq"]["masks"]
|
||||
|
||||
temp_path = os.path.join(dataset,'*/')
|
||||
pathes = glob.glob(temp_path)
|
||||
for dir_item in pathes:
|
||||
join_path = glob.glob(os.path.join(dir_item,'*.jpg'))
|
||||
print("processing %s"%dir_item,end='\r')
|
||||
dir_path = os.path.dirname(join_path[1])
|
||||
dir_name = os.path.join(mask_dir, os.path.basename(dir_path))
|
||||
# print(dir_name)
|
||||
temp_list = []
|
||||
for item in join_path:
|
||||
img_name = os.path.basename(item)
|
||||
img_name, _ = os.path.splitext(img_name)
|
||||
mask_name = os.path.join(dir_name, img_name + ".png")
|
||||
if not os.path.exists(mask_name):
|
||||
print(mask_name)
|
||||
@@ -0,0 +1,113 @@
|
||||
张雨绮 --> 1
|
||||
刘亦菲 --> 2
|
||||
周杰伦 --> 3
|
||||
关晓彤 --> 4
|
||||
林志玲 --> 13
|
||||
103许岚解压密码coshunter.top --> 18
|
||||
19 --> [YOUMI灏よ湝鑽焆 2020.09.03 VOL.521 濡插繁_Toxic [61P508MB]
|
||||
金晨 --> 20
|
||||
容祖儿 --> 21
|
||||
张歆艺 --> 22
|
||||
杨颖 --> 23
|
||||
倪妮 --> 24
|
||||
|
||||
G:/4K/90 佟丽娅 --> H:/face_data/VGGFace2_HQ/27
|
||||
273 赵粤 --> 28
|
||||
G:/4K/291-不支持在线打开,请下载后解压玥儿玥er --> H:/face_data/VGGFace2_HQ/29
|
||||
G:/4K/114 张紫宁 --> H:/face_data/VGGFace2_HQ/30
|
||||
G:/4K/120 舒淇 --> H:/face_data/VGGFace2_HQ/31
|
||||
G:/4K/124 张子枫 --> H:/face_data/VGGFace2_HQ/32
|
||||
G:/4K/173 王心凌 --> H:/face_data/VGGFace2_HQ/33
|
||||
G:/4K/191 郑爽 --> H:/face_data/VGGFace2_HQ/34
|
||||
G:/4K/286尤妮丝b/防何蟹文件夹(勿在线解压) 图包 02/防何蟹文件夹(勿在线解压) 图包 02 --> H:/face_data/VGGFace2_HQ/35
|
||||
G:/4K/辛芷蕾 --> H:/face_data/VGGFace2_HQ/36
|
||||
G:/4K/197 邓家佳 --> H:/face_data/VGGFace2_HQ/37
|
||||
G:/4K/195 袁姗姗 --> H:/face_data/VGGFace2_HQ/38
|
||||
G:/4K/249 希林娜依.高 --> H:/face_data/VGGFace2_HQ/39
|
||||
G:/4K/素材合/母其弥雅原片素材 --> H:/face_data/VGGFace2_HQ/40
|
||||
G:/4K/素材合/张含韵系列 --> H:/face_data/VGGFace2_HQ/41
|
||||
G:/4K/231胡静 --> H:/face_data/VGGFace2_HQ/42
|
||||
G:/4K/259 周洁琼 --> H:/face_data/VGGFace2_HQ/43
|
||||
G:/4K/190 李晟 --> H:/face_data/VGGFace2_HQ/44
|
||||
G:/4K/193 张檬 --> H:/face_data/VGGFace2_HQ/45
|
||||
G:/4K/174 潘之琳 --> H:/face_data/VGGFace2_HQ/46
|
||||
G:/4K/196 陈都灵 --> H:/face_data/VGGFace2_HQ/47
|
||||
G:/4K/236 李亚男 --> H:/face_data/VGGFace2_HQ/48
|
||||
G:/4K/221 江疏影 --> H:/face_data/VGGFace2_HQ/49
|
||||
G:/4K/253 刘羽琦 --> H:/face_data/VGGFace2_HQ/50
|
||||
G:/4K/282 李艺彤 --> H:/face_data/VGGFace2_HQ/51
|
||||
G:/4K/沈梦辰系列 --> H:/face_data/VGGFace2_HQ/52
|
||||
G:/4K/293 张怡 --> H:/face_data/VGGFace2_HQ/53
|
||||
G:/4K/私多个文件/11.7 --> H:/face_data/VGGFace2_HQ/54
|
||||
G:/4K/私多个文件/2016.3.31木子私房/原片 --> H:/face_data/VGGFace2_HQ/55
|
||||
G:/4K/私多个文件/李可人 原片 --> H:/face_data/VGGFace2_HQ/56
|
||||
G:/4K/私多个文件/刘雨欣 --> H:/face_data/VGGFace2_HQ/57
|
||||
G:/4K/私多个文件/新建文件夹 --> H:/face_data/VGGFace2_HQ/58
|
||||
G:/4K/私多个文件/新雯 --> H:/face_data/VGGFace2_HQ/58
|
||||
G:/4K/私多个文件/原片 --> H:/face_data/VGGFace2_HQ/59
|
||||
G:/4K/私多个文件/原片22 --> H:/face_data/VGGFace2_HQ/60
|
||||
G:/4K/私多个文件/原图 --> H:/face_data/VGGFace2_HQ/61
|
||||
G:/4K/徐璐系列 --> H:/face_data/VGGFace2_HQ/62
|
||||
G:/4K/徐冬冬 --> H:/face_data/VGGFace2_HQ/63
|
||||
G:/4K/徐静蕾 --> H:/face_data/VGGFace2_HQ/64
|
||||
G:/4K/张靓颖 --> H:/face_data/VGGFace2_HQ/65
|
||||
G:/4K/私多个文件/原片22 --> H:/face_data/hg
|
||||
G:/4K/私多个文件/原片22 --> H:/face_data/hg1
|
||||
G:/4K/张靓颖 --> H:/face_data/VGGFace2_HQ/65
|
||||
G:/4K/凉高定妆 --> H:/face_data/VGGFace2_HQ/66
|
||||
G:/4K/林志玲-初整 --> H:/face_data/VGGFace2_HQ/67
|
||||
G:/4K/000未整理/张小斐 --> H:/face_data/VGGFace2_HQ/68
|
||||
G:/4K/000未整理/马思纯系列1 --> H:/face_data/VGGFace2_HQ/69
|
||||
G:/4K/000未整理/王丽坤 --> H:/face_data/VGGFace2_HQ/70
|
||||
G:/4K/000未整理/刘涛全部 --> H:/face_data/VGGFace2_HQ/71
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/张子枫 --> H:/face_data/VGGFace2_HQ/72
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/李小璐 --> H:/face_data/VGGFace2_HQ/73
|
||||
G:/4K/295-不支持在线打开,请下载后解压镜酱 --> H:/face_data/VGGFace2_HQ/74
|
||||
G:/4K/297-不支持在线打开,请下载后解压一只云烧 --> H:/face_data/VGGFace2_HQ/74
|
||||
G:/4K/295-不支持在线打开,请下载后解压镜酱 --> H:/face_data/VGGFace2_HQ/75
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/海报艺人可用素材/李世鹏可用 --> H:/face_data/VGGFace2_HQ/76
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/海报艺人可用素材/李菲儿可用 --> H:/face_data/VGGFace2_HQ/77
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/程砚秋 --> H:/face_data/VGGFace2_HQ/78
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/白舟夏绿 --> H:/face_data/VGGFace2_HQ/79
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/果果白舟夏绿 --> H:/face_data/VGGFace2_HQ/80
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/陶西安谧 --> H:/face_data/VGGFace2_HQ/81
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/祖儿紫枫 --> H:/face_data/VGGFace2_HQ/82
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/海报备份/祖儿紫枫2 --> H:/face_data/VGGFace2_HQ/83
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照0905-0912 --> H:/face_data/VGGFace2_HQ/84
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照0913-0922 --> H:/face_data/VGGFace2_HQ/85
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照0923-1004 --> H:/face_data/VGGFace2_HQ/86
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1005-1010 --> H:/face_data/VGGFace2_HQ/87
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1011-1017 --> H:/face_data/VGGFace2_HQ/88
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1018-1101 --> H:/face_data/VGGFace2_HQ/89
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1102-1114 --> H:/face_data/VGGFace2_HQ/90
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1115-1122 --> H:/face_data/VGGFace2_HQ/91
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1123-1204 --> H:/face_data/VGGFace2_HQ/92
|
||||
G:/4K/000未整理/打包区/李小璐宋祖儿张子枫等打包/剧照1205-1212 --> H:/face_data/VGGFace2_HQ/93
|
||||
G:/4K/000未整理/黄奕1 --> H:/face_data/VGGFace2_HQ/94
|
||||
G:/4K/000未整理/阚清子 --> H:/face_data/VGGFace2_HQ/95
|
||||
G:/4K/000未整理/何穗热裤李小璐等 --> H:/face_data/VGGFace2_HQ/96
|
||||
G:/4K/000未整理/景甜热裤泳池 --> H:/face_data/VGGFace2_HQ/97
|
||||
H:/face_data/VGGFace2_HQ/77 --> H:/face_data/VGGFace2_HQ/98
|
||||
G:/4K/16 吴映洁空姐制服 --> H:/face_data/VGGFace2_HQ/98
|
||||
G:/4K/2 --> H:/face_data/VGGFace2_HQ/99
|
||||
G:/4K/108 张韶涵 --> H:/face_data/VGGFace2_HQ\108 张韶涵
|
||||
G:/4K/110 张小斐 --> H:/face_data/VGGFace2_HQ\110 张小斐
|
||||
G:/4K/111 张小斐 --> H:/face_data/VGGFace2_HQ\111 张小斐
|
||||
G:/4K/119 宋茜 --> H:/face_data/VGGFace2_HQ\119 宋茜
|
||||
G:/4K/120 舒淇 --> H:/face_data/VGGFace2_HQ\120 舒淇
|
||||
G:/4K/125 舒淇合集 --> H:/face_data/VGGFace2_HQ\125 舒淇合集
|
||||
G:/4K/各地素人拍摄/新建文件夹 (6) --> H:/face_data/VGGFace2_HQ\新建文件夹 (6)
|
||||
G:/4K/各地素人拍摄/原片(6) --> H:/face_data/VGGFace2_HQ\原片(6)
|
||||
G:/4K/各地素人拍摄/20161113舞剧系拍摄jpg --> H:/face_data/VGGFace2_HQ\20161113舞剧系拍摄jpg
|
||||
G:/4K/各地素人拍摄/155 --> H:/face_data/VGGFace2_HQ\155
|
||||
G:/4K/各地素人拍摄/17 --> H:/face_data/VGGFace2_HQ\17
|
||||
G:/4K/249刘飞er解压密码coshunter.com或coshunter.top --> H:/face_data/VGGFace2_HQ\249刘飞er解压密码coshunter.com或coshunter
|
||||
H:/face_data/cw --> H:/face_data/VGGFace2_HQ\cw
|
||||
H:/face_data/cw --> H:/face_data/VGGFace2_HQ\cw
|
||||
G:/4K/柳岩棚拍 --> H:/face_data/VGGFace2_HQ\柳岩棚拍
|
||||
G:/4K/216 卢靖姗 --> H:/face_data/VGGFace2_HQ\216 卢靖姗
|
||||
G:/4K/217 卢靖姗 --> H:/face_data/VGGFace2_HQ\217 卢靖姗
|
||||
G:/4K/白百合 --> H:/face_data/VGGFace2_HQ\白百合
|
||||
G:/4K/白百何 --> H:/face_data/VGGFace2_HQ\白百何
|
||||
G:/4K/张雨绮 --> H:/face_data/VGGFace2_HQ\张雨绮
|
||||
G:/4K/179 冯薪朵 --> H:/face_data/VGGFace2_HQ\179 冯薪朵
|
||||
+44
-13
@@ -5,7 +5,7 @@
|
||||
# Created Date: Tuesday February 1st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 2nd February 2022 4:13:28 pm
|
||||
# Last Modified: Sunday, 24th April 2022 2:01:47 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -26,6 +26,8 @@ import tkinter.ttk as ttk
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from insightface_func.face_detect_crop_multi import Face_detect_crop
|
||||
|
||||
class TextRedirector(object):
|
||||
@@ -113,6 +115,7 @@ class Application(tk.Frame):
|
||||
label_frame.columnconfigure(0, weight=1)
|
||||
label_frame.columnconfigure(1, weight=1)
|
||||
label_frame.columnconfigure(2, weight=1)
|
||||
label_frame.columnconfigure(3, weight=1)
|
||||
|
||||
tk.Label(label_frame, text="Crop Size:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=0,sticky=tk.EW)
|
||||
@@ -123,6 +126,9 @@ class Application(tk.Frame):
|
||||
tk.Label(label_frame, text="Target Format:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=2,sticky=tk.EW)
|
||||
|
||||
tk.Label(label_frame, text="Blurry Thredhold:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=3,sticky=tk.EW)
|
||||
|
||||
#################################################################################################
|
||||
|
||||
test_frame = tk.Frame(self.master)
|
||||
@@ -151,8 +157,10 @@ class Application(tk.Frame):
|
||||
self.format_com["value"] = ["png","jpg"]
|
||||
self.format_com.current(0)
|
||||
|
||||
|
||||
|
||||
self.thredhold = tkinter.StringVar()
|
||||
tk.Entry(test_frame, textvariable= self.thredhold, font=font_list)\
|
||||
.grid(row=0,column=3,sticky=tk.EW)
|
||||
self.thredhold.set("70")
|
||||
#################################################################################################
|
||||
scale_frame = tk.Frame(self.master)
|
||||
scale_frame.pack(fill="both", padx=5,pady=5)
|
||||
@@ -200,8 +208,10 @@ class Application(tk.Frame):
|
||||
|
||||
def select_task(self):
|
||||
path = askdirectory()
|
||||
print("Selected source directory: %s"%path)
|
||||
self.img_path.set(path)
|
||||
|
||||
if os.path.isdir(path):
|
||||
print("Selected source directory: %s"%path)
|
||||
self.img_path.set(path)
|
||||
|
||||
def Select_Target(self):
|
||||
thread_update = threading.Thread(target=self.select_target_task)
|
||||
@@ -209,8 +219,9 @@ class Application(tk.Frame):
|
||||
|
||||
def select_target_task(self):
|
||||
path = askdirectory()
|
||||
print("Selected target directory: %s"%path)
|
||||
self.save_path.set(path)
|
||||
if os.path.isdir(path):
|
||||
print("Selected target directory: %s"%path)
|
||||
self.save_path.set(path)
|
||||
|
||||
def Crop(self):
|
||||
thread_update = threading.Thread(target=self.crop_task)
|
||||
@@ -218,39 +229,59 @@ class Application(tk.Frame):
|
||||
|
||||
def crop_task(self):
|
||||
mode = self.align_com.get()
|
||||
if mode == "VGGFace":
|
||||
mode = "None"
|
||||
crop_size = int(self.test_com.get())
|
||||
|
||||
path = self.img_path.get()
|
||||
tg_path = self.save_path.get()
|
||||
blur_t = self.thredhold.get()
|
||||
basepath = os.path.splitext(os.path.basename(path))[0]
|
||||
tg_path = os.path.join("H:/face_data/VGGFace2_HQ",basepath)
|
||||
print("target path: ",tg_path)
|
||||
if not os.path.exists(tg_path):
|
||||
os.makedirs(tg_path)
|
||||
tg_format = self.format_com.get()
|
||||
min_scale = float(self.min_scale.get())
|
||||
blur_t = 100.0
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
blur_t = float(blur_t)
|
||||
print("Blurry thredhold %f"%blur_t)
|
||||
self.detect.prepare(ctx_id = 0, det_thresh=0.6,\
|
||||
det_size=(640,640),mode = mode,crop_size=crop_size,ratio=min_scale)
|
||||
log_file = "./dataset_readme.txt"
|
||||
with open(log_file,'a+') as logf: # ,encoding='UTF-8'
|
||||
logf.writelines("%s --> %s\n"%(path,tg_path))
|
||||
if path and tg_path:
|
||||
imgs_list = []
|
||||
if os.path.isdir(path):
|
||||
print("Input a dir....")
|
||||
imgs = glob.glob(os.path.join(path,"*"))
|
||||
for item in imgs:
|
||||
# imgs = glob.glob(os.path.join(path,"**"))
|
||||
for item in glob.iglob(os.path.join(path,"**"),recursive=True):
|
||||
imgs_list.append(item)
|
||||
# print(imgs_list)
|
||||
index = 0
|
||||
for img in imgs_list:
|
||||
print(img)
|
||||
attr_img_ori= cv2.imread(img)
|
||||
try:
|
||||
attr_img_ori = cv2.imdecode(np.fromfile(img, dtype=np.uint8),-1)
|
||||
except:
|
||||
print("Illegal file!")
|
||||
continue
|
||||
# attr_img_ori= cv2.imread(img)
|
||||
try:
|
||||
attr_img_align_crop, _ = self.detect.get(attr_img_ori)
|
||||
sub_index = 0
|
||||
if len(attr_img_align_crop) < 1:
|
||||
print("Small face")
|
||||
for face_i in attr_img_align_crop:
|
||||
imageVar = cv2.Laplacian(face_i, cv2.CV_64F).var()
|
||||
f_path =os.path.join(tg_path, str(index).zfill(6)+"_%d.%s"%(sub_index,tg_format))
|
||||
# print("save path: ",f_path)
|
||||
if imageVar < blur_t:
|
||||
print("Over blurry image!")
|
||||
continue
|
||||
# face_i = cv2.putText(face_i, '%.1f'%imageVar,(50, 50), font, 0.8, (15, 9, 255), 2)
|
||||
cv2.imwrite(f_path,face_i)
|
||||
# cv2.imwrite(f_path,face_i)
|
||||
cv2.imencode('.png',face_i)[1].tofile(f_path)
|
||||
sub_index += 1
|
||||
index += 1
|
||||
except:
|
||||
|
||||
@@ -0,0 +1,279 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: face_crop.py
|
||||
# Created Date: Tuesday February 1st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 22nd April 2022 8:43:40 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import sys
|
||||
import glob
|
||||
import json
|
||||
import tkinter
|
||||
from tkinter.filedialog import askdirectory
|
||||
|
||||
import threading
|
||||
import tkinter as tk
|
||||
import tkinter.ttk as ttk
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from insightface_func.face_detect_crop_multi import Face_detect_crop
|
||||
|
||||
class TextRedirector(object):
|
||||
def __init__(self, widget, tag="stdout"):
|
||||
self.widget = widget
|
||||
self.tag = tag
|
||||
|
||||
def write(self, str):
|
||||
self.widget.configure(state="normal")
|
||||
self.widget.insert("end", str, (self.tag,))
|
||||
self.widget.configure(state="disabled")
|
||||
self.widget.see(tk.END)
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
#############################################################
|
||||
# Main Class
|
||||
#############################################################
|
||||
|
||||
class Application(tk.Frame):
|
||||
|
||||
|
||||
def __init__(self, master=None):
|
||||
tk.Frame.__init__(self, master,bg='black')
|
||||
# self.font_size = 16
|
||||
self.font_list = ("Times New Roman",14)
|
||||
self.padx = 5
|
||||
self.pady = 5
|
||||
self.window_init()
|
||||
|
||||
def __label_text__(self, usr, root):
|
||||
return "User Name: %s\nWorkspace: %s"%(usr, root)
|
||||
|
||||
def window_init(self):
|
||||
cwd = os.getcwd()
|
||||
self.master.title('Face Crop - %s'%cwd)
|
||||
# self.master.iconbitmap('./utilities/_logo.ico')
|
||||
self.master.geometry("{}x{}".format(640, 600))
|
||||
|
||||
font_list = self.font_list
|
||||
|
||||
#################################################################################################
|
||||
list_frame = tk.Frame(self.master)
|
||||
list_frame.pack(fill="both", padx=5,pady=5)
|
||||
list_frame.columnconfigure(0, weight=1)
|
||||
list_frame.columnconfigure(1, weight=1)
|
||||
list_frame.columnconfigure(2, weight=1)
|
||||
|
||||
self.img_path = tkinter.StringVar()
|
||||
|
||||
tk.Label(list_frame, text="Image/Video Path:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=0,sticky=tk.EW)
|
||||
|
||||
tk.Entry(list_frame, textvariable= self.img_path, font=font_list)\
|
||||
.grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
|
||||
tk.Button(list_frame, text = "Select Path", font=font_list,
|
||||
command = self.Select, bg='#F4A460', fg='#F5F5F5')\
|
||||
.grid(row=0,column=2,sticky=tk.EW)
|
||||
#################################################################################################
|
||||
list_frame1 = tk.Frame(self.master)
|
||||
list_frame1.pack(fill="both", padx=5,pady=5)
|
||||
list_frame1.columnconfigure(0, weight=1)
|
||||
list_frame1.columnconfigure(1, weight=1)
|
||||
list_frame1.columnconfigure(2, weight=1)
|
||||
|
||||
self.save_path = tkinter.StringVar()
|
||||
|
||||
tk.Label(list_frame1, text="Target Path:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=0,sticky=tk.EW)
|
||||
|
||||
tk.Entry(list_frame1, textvariable= self.save_path, font=font_list)\
|
||||
.grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
|
||||
tk.Button(list_frame1, text = "Select Path", font=font_list,
|
||||
command = self.Select_Target, bg='#F4A460', fg='#F5F5F5')\
|
||||
.grid(row=0,column=2,sticky=tk.EW)
|
||||
|
||||
#################################################################################################
|
||||
label_frame = tk.Frame(self.master)
|
||||
label_frame.pack(fill="both", padx=5,pady=5)
|
||||
label_frame.columnconfigure(0, weight=1)
|
||||
label_frame.columnconfigure(1, weight=1)
|
||||
label_frame.columnconfigure(2, weight=1)
|
||||
|
||||
tk.Label(label_frame, text="Crop Size:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=0,sticky=tk.EW)
|
||||
|
||||
tk.Label(label_frame, text="Align Mode:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
tk.Label(label_frame, text="Target Format:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=2,sticky=tk.EW)
|
||||
|
||||
#################################################################################################
|
||||
|
||||
test_frame = tk.Frame(self.master)
|
||||
test_frame.pack(fill="both", padx=5,pady=5)
|
||||
test_frame.columnconfigure(0, weight=1)
|
||||
test_frame.columnconfigure(1, weight=1)
|
||||
test_frame.columnconfigure(2, weight=1)
|
||||
|
||||
self.test_var = tkinter.StringVar()
|
||||
|
||||
self.test_com = ttk.Combobox(test_frame, textvariable=self.test_var)
|
||||
self.test_com.grid(row=0,column=0,sticky=tk.EW)
|
||||
self.test_com["value"] = [256,512,768,1024]
|
||||
self.test_com.current(1)
|
||||
|
||||
self.align_var = tkinter.StringVar()
|
||||
self.align_com = ttk.Combobox(test_frame, textvariable=self.align_var)
|
||||
self.align_com.grid(row=0,column=1,sticky=tk.EW)
|
||||
self.align_com["value"] = ["VGGFace","ffhq"]
|
||||
self.align_com.current(0)
|
||||
|
||||
self.format_var = tkinter.StringVar()
|
||||
|
||||
self.format_com = ttk.Combobox(test_frame, textvariable=self.format_var)
|
||||
self.format_com.grid(row=0,column=2,sticky=tk.EW)
|
||||
self.format_com["value"] = ["png","jpg"]
|
||||
self.format_com.current(0)
|
||||
|
||||
|
||||
|
||||
#################################################################################################
|
||||
scale_frame = tk.Frame(self.master)
|
||||
scale_frame.pack(fill="both", padx=5,pady=5)
|
||||
scale_frame.columnconfigure(0, weight=2)
|
||||
label_frame.columnconfigure(1, weight=1)
|
||||
# label_frame.columnconfigure(2, weight=1)
|
||||
|
||||
tk.Label(scale_frame, text="Min Size:",font=font_list,justify="left")\
|
||||
.grid(row=0,column=0,sticky=tk.EW)
|
||||
self.min_scale = tkinter.StringVar()
|
||||
tk.Scale(scale_frame, from_=0.5, to=2.0, length=500, orient=tk.HORIZONTAL, variable= self.min_scale,\
|
||||
font=font_list, resolution=0.1).grid(row=0,column=1,sticky=tk.EW)
|
||||
|
||||
#################################################################################################
|
||||
test_frame1 = tk.Frame(self.master)
|
||||
test_frame1.pack(fill="both", padx=5,pady=5)
|
||||
test_frame1.columnconfigure(0, weight=1)
|
||||
# test_frame1.columnconfigure(1, weight=1)
|
||||
|
||||
test_update_button = tk.Button(test_frame1, text = "Crop",
|
||||
font=font_list, command = self.Crop, bg='#F4A460', fg='#F5F5F5')
|
||||
test_update_button.grid(row=0,column=0,sticky=tk.EW)
|
||||
|
||||
|
||||
|
||||
#################################################################################################
|
||||
|
||||
text = tk.Text(self.master, wrap="word")
|
||||
text.pack(fill="both",expand="yes", padx=5,pady=5)
|
||||
|
||||
|
||||
sys.stdout = TextRedirector(text, "stdout")
|
||||
|
||||
self.init_algorithm()
|
||||
self.master.protocol("WM_DELETE_WINDOW", self.on_closing)
|
||||
|
||||
def init_algorithm(self):
|
||||
self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
|
||||
|
||||
|
||||
# def __scaning_logs__(self):
|
||||
def Select(self):
|
||||
thread_update = threading.Thread(target=self.select_task)
|
||||
thread_update.start()
|
||||
|
||||
def select_task(self):
|
||||
path = askdirectory()
|
||||
print("Selected source directory: %s"%path)
|
||||
self.img_path.set(path)
|
||||
|
||||
def Select_Target(self):
|
||||
thread_update = threading.Thread(target=self.select_target_task)
|
||||
thread_update.start()
|
||||
|
||||
def select_target_task(self):
|
||||
path = askdirectory()
|
||||
print("Selected target directory: %s"%path)
|
||||
self.save_path.set(path)
|
||||
|
||||
def Crop(self):
|
||||
thread_update = threading.Thread(target=self.crop_task)
|
||||
thread_update.start()
|
||||
|
||||
def crop_task(self):
|
||||
mode = self.align_com.get()
|
||||
crop_size = int(self.test_com.get())
|
||||
|
||||
path = self.img_path.get()
|
||||
tg_path = self.save_path.get()
|
||||
if not os.path.exists(tg_path):
|
||||
os.makedirs(tg_path)
|
||||
tg_format = self.format_com.get()
|
||||
min_scale = float(self.min_scale.get())
|
||||
blur_t = 10.0
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
self.detect.prepare(ctx_id = 0, det_thresh=0.6,\
|
||||
det_size=(640,640),mode = mode,crop_size=crop_size,ratio=min_scale)
|
||||
if path and tg_path:
|
||||
imgs_list = []
|
||||
if os.path.isdir(path):
|
||||
print("Input a dir....")
|
||||
imgs = glob.glob(os.path.join(path,"*"))
|
||||
for item in imgs:
|
||||
imgs_list.append(item)
|
||||
# print(imgs_list)
|
||||
index = 0
|
||||
for img in imgs_list:
|
||||
print(img)
|
||||
attr_img_ori= cv2.imread(img)
|
||||
try:
|
||||
attr_img_align_crop, _ = self.detect.get(attr_img_ori)
|
||||
sub_index = 0
|
||||
if len(attr_img_align_crop) < 1:
|
||||
print("Small face")
|
||||
for face_i in attr_img_align_crop:
|
||||
imageVar = cv2.Laplacian(face_i, cv2.CV_64F).var()
|
||||
f_path =os.path.join(tg_path, str(index).zfill(6)+"_%d.%s"%(sub_index,tg_format))
|
||||
if imageVar < blur_t:
|
||||
print("Over blurry image!")
|
||||
continue
|
||||
# face_i = cv2.putText(face_i, '%.1f'%imageVar,(50, 50), font, 0.8, (15, 9, 255), 2)
|
||||
cv2.imwrite(f_path,face_i)
|
||||
sub_index += 1
|
||||
index += 1
|
||||
except:
|
||||
print("Detect no face!")
|
||||
continue
|
||||
else:
|
||||
print("Input an image....")
|
||||
imgs_list.append(path)
|
||||
print("Process finished!")
|
||||
else:
|
||||
print("Pathes are invalid!")
|
||||
|
||||
def on_closing(self):
|
||||
|
||||
# self.__save_config__()
|
||||
self.master.destroy()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = Application()
|
||||
app.mainloop()
|
||||
+3
-3
@@ -5,7 +5,7 @@
|
||||
# Created Date: Tuesday February 1st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 2nd February 2022 11:17:04 pm
|
||||
# Last Modified: Friday, 15th April 2022 10:07:15 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -27,7 +27,7 @@ def getParameters():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-p', '--save_path', type=str, default="./output/",
|
||||
help="The root path for saving cropped images")
|
||||
parser.add_argument('-v', '--video', type=str, default="G:\\4K\\05.mp4",
|
||||
parser.add_argument('-v', '--video', type=str, default="G:\\4K\\Faith.Makes.Great.2021\\40.mp4",
|
||||
help="The path for input video")
|
||||
parser.add_argument('-c', '--crop_size', type=int, default=512,
|
||||
help="expected image resolution")
|
||||
@@ -39,7 +39,7 @@ def getParameters():
|
||||
choices=['jpg', 'png'],help="target file format")
|
||||
parser.add_argument('-i', '--interval', type=int, default=20,
|
||||
help="number of frames interval")
|
||||
parser.add_argument('-b', '--blur', type=float, default=10.0,
|
||||
parser.add_argument('-b', '--blur', type=float, default=20.0,
|
||||
help="blur degree")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
# Pre-trained Models and Other Data
|
||||
|
||||
Download pre-trained models and other data. Put them in this folder.
|
||||
|
||||
1. [Pretrained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth)
|
||||
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
|
||||
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
|
||||
@@ -0,0 +1,7 @@
|
||||
# flake8: noqa
|
||||
from .archs import *
|
||||
from .data import *
|
||||
from .models import *
|
||||
from .utils import *
|
||||
|
||||
# from .version import *
|
||||
@@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
# automatically scan and import arch modules for registry
|
||||
# scan all the files that end with '_arch.py' under the archs folder
|
||||
arch_folder = osp.dirname(osp.abspath(__file__))
|
||||
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
||||
# import all the arch modules
|
||||
_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]
|
||||
@@ -0,0 +1,245 @@
|
||||
import torch.nn as nn
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
|
||||
def conv3x3(inplanes, outplanes, stride=1):
|
||||
"""A simple wrapper for 3x3 convolution with padding.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
outplanes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
"""
|
||||
return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
"""Basic residual block used in the ResNetArcFace architecture.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
planes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
downsample (nn.Module): The downsample module. Default: None.
|
||||
"""
|
||||
expansion = 1 # output channel expansion ratio
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class IRBlock(nn.Module):
|
||||
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
planes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
downsample (nn.Module): The downsample module. Default: None.
|
||||
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
|
||||
"""
|
||||
expansion = 1 # output channel expansion ratio
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
|
||||
super(IRBlock, self).__init__()
|
||||
self.bn0 = nn.BatchNorm2d(inplanes)
|
||||
self.conv1 = conv3x3(inplanes, inplanes)
|
||||
self.bn1 = nn.BatchNorm2d(inplanes)
|
||||
self.prelu = nn.PReLU()
|
||||
self.conv2 = conv3x3(inplanes, planes, stride)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
self.use_se = use_se
|
||||
if self.use_se:
|
||||
self.se = SEBlock(planes)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.bn0(x)
|
||||
out = self.conv1(out)
|
||||
out = self.bn1(out)
|
||||
out = self.prelu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
if self.use_se:
|
||||
out = self.se(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.prelu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
"""Bottleneck block used in the ResNetArcFace architecture.
|
||||
|
||||
Args:
|
||||
inplanes (int): Channel number of inputs.
|
||||
planes (int): Channel number of outputs.
|
||||
stride (int): Stride in convolution. Default: 1.
|
||||
downsample (nn.Module): The downsample module. Default: None.
|
||||
"""
|
||||
expansion = 4 # output channel expansion ratio
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SEBlock(nn.Module):
|
||||
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
|
||||
|
||||
Args:
|
||||
channel (int): Channel number of inputs.
|
||||
reduction (int): Channel reduction ration. Default: 16.
|
||||
"""
|
||||
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(SEBlock, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
|
||||
nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
y = self.avg_pool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    Args:
        block (str): Block used in the ArcFace architecture.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        # Only the string 'IRBlock' is mapped to a class here; any other value
        # is used as-is (and would fail later if it is not a block class).
        if block == 'IRBlock':
            block = IRBlock
        self.inplanes = 64  # running input-channel count consumed by _make_layer
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()

        # Stem expects a single-channel (grayscale) input image.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        # NOTE(review): 512 * 8 * 8 assumes the final feature map is 8x8, i.e.
        # a 128x128 input after four 2x downsamples -- confirm with callers.
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` blocks; only the first may change stride/width.

        Args:
            block (class): Block class to instantiate.
            planes (int): Base channel number of the blocks.
            num_blocks (int): Number of blocks stacked in this layer.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Sequential: The stacked blocks.
        """
        downsample = None
        # Project the shortcut when the spatial size or channel count changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of face images to 512-D identity embeddings."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        # Flatten spatial dims before the embedding head.
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)

        return x
|
||||
@@ -0,0 +1,312 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
|
||||
from .gfpganv1_arch import ResUpBlock
|
||||
from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
||||
StyleGAN2GeneratorBilinear)
|
||||
|
||||
|
||||
class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorBilinearSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow)
        # Only addition over the parent: whether SFT touches half the channels.
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorBilinearSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators. Consumed
                as (scale, shift) pairs: conditions[i - 1] scales the feature,
                conditions[i] shifts it.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple(Tensor, Tensor | None): The generated image and, if requested,
            the latent codes (otherwise None).
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate each style towards the truncation latent.
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        # NOTE(review): only 1 or 2 style codes are handled; more than 2 would
        # leave `latent` unbound (NameError) -- confirm callers never pass more.
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        # Walk the style convs two at a time (one resolution level per step).
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class GFPGANBilinear(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
    deployment. It can be easily converted to the clean version: GFPGANv1Clean.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANBilinear, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # Channel width per resolution (keys are the spatial sizes as strings).
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)

        # downsample: one ResBlock per halving, from out_size down to 4x4.
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels))
            in_channels = out_channels

        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)

        # upsample: mirror of the downsample path, 8x8 back up to out_size.
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels

        # to RGB: one projection per decoder level, for intermediate outputs.
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))

        if different_w:
            # One w per StyleGAN2 layer (2 * log2(out_size) - 2 layers).
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            # Scale branch is bias-initialized to 1 so SFT starts near identity.
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANBilinear.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.

        Returns:
            tuple(Tensor, list[Tensor]): The restored image and the list of
            intermediate rgb images (empty when return_rgb is False).
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder: downsample and remember each level (reversed) for U-Net skips.
        feat = self.conv_body_first(x)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)

        feat = self.final_conv(feat)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers (appended as a pair).
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
|
||||
@@ -0,0 +1,439 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
|
||||
StyleGAN2Generator)
|
||||
from basicsr.ops.fused_act import FusedLeakyReLU
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
            applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 resample_kernel=(1, 3, 3, 1),
                 lr_mlp=0.01,
                 narrow=1,
                 sft_half=False):
        super(StyleGAN2GeneratorSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            resample_kernel=resample_kernel,
            lr_mlp=lr_mlp,
            narrow=narrow)
        # Only addition over the parent: whether SFT touches half the channels.
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators. Consumed
                as (scale, shift) pairs: conditions[i - 1] scales the feature,
                conditions[i] shifts it.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple(Tensor, Tensor | None): The generated image and, if requested,
            the latent codes (otherwise None).
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate each style towards the truncation latent.
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        # NOTE(review): only 1 or 2 style codes are handled; more than 2 would
        # leave `latent` unbound (NameError) -- confirm callers never pass more.
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        # Walk the style convs two at a time (one resolution level per step).
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
||||
|
||||
|
||||
class ConvUpLayer(nn.Module):
    """Convolutional upsampling layer. It uses bilinear upsampler + Conv.

    Doubles the spatial resolution with bilinear interpolation and then runs a
    convolution whose weights are rescaled at run time (equalized learning
    rate), optionally followed by a leaky-ReLU style activation.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1
        padding (int): Zero-padding added to both sides of the input. Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        activate (bool): Whether use activateion. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bias=True,
                 bias_init_val=0,
                 activate=True):
        super(ConvUpLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Equalized-lr factor applied to the raw weights on every forward pass.
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))

        # A standalone bias is only kept when there is no activation;
        # FusedLeakyReLU carries its own bias term.
        if bias and not activate:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

        # Activation choice mirrors the bias handling above.
        if not activate:
            self.activation = None
        elif bias:
            self.activation = FusedLeakyReLU(out_channels)
        else:
            self.activation = ScaledLeakyReLU(0.2)

    def forward(self, x):
        """Upsample ``x`` by 2x, convolve, then optionally activate."""
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        out = F.conv2d(
            upsampled,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )
        return out if self.activation is None else self.activation(out)
|
||||
|
||||
|
||||
class ResUpBlock(nn.Module):
    """Residual block with upsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
    """

    def __init__(self, in_channels, out_channels):
        super(ResUpBlock, self).__init__()

        # Main branch: 3x3 conv at the input resolution, then an upsampling conv.
        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
        # Skip branch: 1x1 upsampling conv without activation.
        self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)

    def forward(self, x):
        """Upsample ``x`` by 2x through both branches and merge them."""
        main = self.conv2(self.conv1(x))
        shortcut = self.skip(x)
        # Divide by sqrt(2) so the variance of the sum stays roughly constant.
        return (main + shortcut) / math.sqrt(2)
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class GFPGANv1(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
            applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            resample_kernel=(1, 3, 3, 1),
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            lr_mlp=0.01,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANv1, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # Channel width per resolution (keys are the spatial sizes as strings).
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)

        # downsample: one ResBlock per halving, from out_size down to 4x4.
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
            in_channels = out_channels

        self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)

        # upsample: mirror of the downsample path, 8x8 back up to out_size.
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
            in_channels = out_channels

        # to RGB: one projection per decoder level, for intermediate outputs.
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))

        if different_w:
            # One w per StyleGAN2 layer (2 * log2(out_size) - 2 layers).
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = EqualLinear(
            channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            resample_kernel=resample_kernel,
            lr_mlp=lr_mlp,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            # Scale branch is bias-initialized to 1 so SFT starts near identity.
            self.condition_scale.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
            self.condition_shift.append(
                nn.Sequential(
                    EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
                    ScaledLeakyReLU(0.2),
                    EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANv1.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.

        Returns:
            tuple(Tensor, list[Tensor]): The restored image and the list of
            intermediate rgb images (empty when return_rgb is False).
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder: downsample and remember each level (reversed) for U-Net skips.
        feat = self.conv_body_first(x)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)

        feat = self.final_conv(feat)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers (appended as a pair).
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class FacialComponentDiscriminator(nn.Module):
    """Facial component (eyes, mouth, noise) discriminator used in GFPGAN.

    A small VGG-style stack with a fixed model size: two downsampling stages,
    each followed by a same-resolution conv, and a final 1-channel projection.
    """

    def __init__(self):
        super(FacialComponentDiscriminator, self).__init__()
        # It now uses a VGG-style architectrue with fixed model size
        self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
        self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)

    def forward(self, x, return_feats=False):
        """Forward function for FacialComponentDiscriminator.

        Args:
            x (Tensor): Input images.
            return_feats (bool): Whether to return intermediate features. Default: False.

        Returns:
            tuple(Tensor, list[Tensor] | None): Discriminator scores and, if
            requested, the two intermediate feature maps (otherwise None).
        """
        feat = self.conv3(self.conv2(self.conv1(x)))
        captured = []
        if return_feats:
            captured.append(feat.clone())
        feat = self.conv5(self.conv4(feat))
        if return_feats:
            captured.append(feat.clone())
        score = self.final_conv(feat)

        return (score, captured) if return_feats else (score, None)
|
||||
@@ -0,0 +1,324 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .stylegan2_clean_arch import StyleGAN2GeneratorClean
|
||||
|
||||
|
||||
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super(StyleGAN2GeneratorCSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)
        # Only addition over the parent class: whether SFT modulates half the channels.
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorCSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            conditions (list[Tensor]): SFT conditions to generators, ordered as
                interleaved (scale, shift) pairs per resolution level.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate each style towards the truncation latent
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # Walk the conv stack in (upsample-conv, conv, to_rgb) triples; `i` tracks
        # the latent index, which advances by 2 per resolution level.
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions: conditions[i - 1] is the
                # scale and conditions[i] the shift for this level.
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block with bilinear upsampling/downsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.

    Raises:
        ValueError: If ``mode`` is neither 'down' nor 'up'.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        # 1x1 conv on the identity path so channel counts match for the residual add.
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        if mode == 'down':
            self.scale_factor = 0.5
        elif mode == 'up':
            self.scale_factor = 2
        else:
            # Fail fast: previously an unknown mode silently left `scale_factor`
            # unset and only surfaced later as an AttributeError in forward().
            raise ValueError(f"Wrong mode in ResBlock: {mode}. Supported ones are: ['down', 'up'].")

    def forward(self, x):
        """Run conv -> resample -> conv on the main path, resample -> 1x1 conv on the skip path."""
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip path is resampled the same way so shapes line up
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class GFPGANv1Clean(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # channel counts per spatial resolution, keyed by resolution as a string
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        self.log_size = int(math.log(out_size, 2))
        # round out_size down to the nearest power of two for the first conv
        first_out_size = 2**(int(math.log(out_size, 2)))

        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

        # downsample: out_size -> 4, one ResBlock per halving
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        # upsample: 4 -> out_size, mirrors the downsample path
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # to RGB: one 1x1 projection per decoder level (for intermediate supervision)
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))

        # one latent per styled layer when different_w, otherwise a single shared w
        if different_w:
            linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        # maps the 4x4 bottleneck feature to the style code(s)
        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift), one pair of heads per decoder level
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            # when sft_half, the decoder only modulates half its channels
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANv1Clean.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Whether to return style latents. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.

        Returns:
            tuple: (restored image, list of intermediate rgb images).
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder: collect skip features from coarse to fine (inserted at front)
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            unet_skips.insert(0, feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code from the 4x4 bottleneck
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            # reshape flat code into one latent per styled layer
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
|
||||
@@ -0,0 +1,613 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class NormStyleCode(nn.Module):
    """Pixel-norm normalization applied to style codes before the style MLP."""

    def forward(self, x):
        """Scale each style code to approximately unit RMS.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor with the same shape as ``x``.
        """
        # epsilon keeps rsqrt finite for an all-zero code
        mean_square = torch.mean(x**2, dim=1, keepdim=True)
        return x * torch.rsqrt(mean_square + 1e-8)
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
    """Equalized Linear layer as used in StyleGAN2.

    The weight is stored divided by ``lr_mul`` and rescaled at runtime, which
    equalizes the learning rate across layers.

    Args:
        in_channels (int): Size of each sample.
        out_channels (int): Size of each output sample.
        bias (bool): If set to ``False``, the layer will not learn an additive
            bias. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        lr_mul (float): Learning rate multiplier. Default: 1.
        activation (None | str): The activation after ``linear`` operation.
            Supported: 'fused_lrelu', None. Default: None.
    """

    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
        super(EqualLinear, self).__init__()
        # validate before allocating anything
        if activation not in ['fused_lrelu', None]:
            raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
                             "Supported ones are: ['fused_lrelu', None].")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lr_mul = lr_mul
        self.activation = activation
        # runtime scale combines He-style init with the lr multiplier
        self.scale = (1 / math.sqrt(in_channels)) * lr_mul

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Apply the equalized linear map, then the optional fused activation."""
        bias = None if self.bias is None else self.bias * self.lr_mul
        if self.activation == 'fused_lrelu':
            # bias is folded into the fused activation rather than the linear op
            return fused_leaky_relu(F.linear(x, self.weight * self.scale), bias)
        return F.linear(x, self.weight * self.scale, bias=bias)

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, bias={self.bias is not None})')
|
||||
|
||||
|
||||
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
        interpolation_mode (str): Interpolation mode for the optional
            resampling. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8,
                 interpolation_mode='bilinear'):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps
        self.interpolation_mode = interpolation_mode
        # 'nearest' interpolation does not accept align_corners in F.interpolate
        if self.interpolation_mode == 'nearest':
            self.align_corners = None
        else:
            self.align_corners = False

        # equalized-lr scale for the conv weight
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
        # modulation inside each modulated conv: maps style vector to per-input-channel gains
        self.modulation = EqualLinear(
            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)

        # leading dim of 1 broadcasts against the per-sample style below
        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            # normalize each output filter per sample to unit norm
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        # fold the batch dim into the output channels for a single grouped conv
        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)

        # re-read the shape: h/w may have changed after resampling
        b, c, h, w = x.shape
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b — each sample is convolved with
        # its own per-sample modulated weight in one batched conv2d call
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
||||
|
||||
|
||||
class StyleConv(nn.Module):
    """Style-modulated convolution: ModulatedConv2d + noise injection + fused activation.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        interpolation_mode (str): Interpolation mode for resampling. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 interpolation_mode='bilinear'):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
            interpolation_mode=interpolation_mode)
        # learnable scalar strength for the injected noise
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.activate = FusedLeakyReLU(out_channels)

    def forward(self, x, style, noise=None):
        """Modulated conv, then noise injection, then (biased) fused activation."""
        out = self.modulated_conv(x, style)
        if noise is None:
            # fresh per-sample single-channel noise matching the spatial size
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        # activation carries the bias, so only noise is added here
        return self.activate(out + self.weight * noise)
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
    """Project features to RGB, optionally accumulating an upsampled skip image.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample the skip image. Default: True.
        interpolation_mode (str): Interpolation mode for upsampling. Default: 'bilinear'.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        self.interpolation_mode = interpolation_mode
        # 'nearest' interpolation does not accept align_corners
        self.align_corners = None if interpolation_mode == 'nearest' else False
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            3,
            kernel_size=1,
            num_style_feat=num_style_feat,
            demodulate=False,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb
        if self.upsample:
            skip = F.interpolate(skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
        return rgb + skip
|
||||
|
||||
|
||||
class ConstantInput(nn.Module):
    """Learned constant tensor that seeds StyleGAN2 synthesis.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        # single learned map, tiled per sample in forward()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        """Return the learned constant repeated ``batch`` times along dim 0."""
        return self.weight.repeat(batch, 1, 1, 1)
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorBilinear(nn.Module):
    """StyleGAN2 Generator using bilinear resampling instead of compiled upfirdn ops.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): Narrow ratio for channels. Default: 1.0.
        interpolation_mode (str): Interpolation mode for resampling. Default: 'bilinear'.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 lr_mlp=0.01,
                 narrow=1,
                 interpolation_mode='bilinear'):
        super(StyleGAN2GeneratorBilinear, self).__init__()
        # Style MLP layers: normalization followed by num_mlp equalized linears
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.append(
                EqualLinear(
                    num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
                    activation='fused_lrelu'))
        self.style_mlp = nn.Sequential(*style_mlp_layers)

        # channel counts per spatial resolution, keyed by resolution as a string
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
            interpolation_mode=interpolation_mode)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)

        self.log_size = int(math.log(out_size, 2))
        # one conv at 4x4 plus two per higher resolution
        self.num_layers = (self.log_size - 2) * 2 + 1
        # number of latent vectors consumed across convs and to_rgbs
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise: one fixed buffer per styled conv layer (used when randomize_noise is False)
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs: per resolution, an upsampling conv, a plain conv and a ToRGB
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample',
                    interpolation_mode=interpolation_mode))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                    interpolation_mode=interpolation_mode))
            self.to_rgbs.append(
                ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection: one tensor per styled conv layer."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        """Map style codes ``x`` through the style MLP to latent space."""
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        """Estimate the mean latent by averaging ``num_latent`` random samples."""
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2Generator.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                False. Default: True.
            truncation (float): Truncation ratio towards the mean latent. Default: 1.
            truncation_latent (Tensor | None): Latent used as the truncation
                target (typically the mean latent). Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate each style towards the truncation latent
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # walk the conv stack in (upsample-conv, conv, to_rgb) triples;
        # `i` is the latent index and advances by 2 per resolution level
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
||||
|
||||
|
||||
class ScaledLeakyReLU(nn.Module):
    """LeakyReLU followed by a sqrt(2) gain (StyleGAN2 convention).

    Args:
        negative_slope (float): Negative slope. Default: 0.2.
    """

    def __init__(self, negative_slope=0.2):
        super(ScaledLeakyReLU, self).__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        """Apply leaky ReLU, then rescale to preserve signal magnitude."""
        return math.sqrt(2) * F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
    """Equalized 2D convolution as used in StyleGAN2.

    The weight is drawn from a unit normal and rescaled at runtime by
    ``1/sqrt(fan_in)``, equalizing the learning rate across layers.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1
        padding (int): Zero-padding added to both sides of the input.
            Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output.
            Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
        super(EqualConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # equalized-lr runtime scale: 1 / sqrt(fan_in)
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Convolve with the runtime-scaled weight."""
        scaled_weight = self.weight * self.scale
        return F.conv2d(x, scaled_weight, bias=self.bias, stride=self.stride, padding=self.padding)

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size},'
                f' stride={self.stride}, padding={self.padding}, '
                f'bias={self.bias is not None})')
|
||||
|
||||
|
||||
class ConvLayer(nn.Sequential):
    """Conv Layer used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Kernel size.
        downsample (bool): Whether downsample by a factor of 2.
            Default: False.
        bias (bool): Whether with bias. Default: True.
        activate (bool): Whether use activation. Default: True.
        interpolation_mode (str): Mode passed to ``torch.nn.Upsample`` when
            downsampling. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 bias=True,
                 activate=True,
                 interpolation_mode='bilinear'):
        modules = []
        self.interpolation_mode = interpolation_mode
        # Optional downsampling via interpolation (scale_factor=0.5).
        if downsample:
            # 'nearest' interpolation does not accept align_corners.
            self.align_corners = None if self.interpolation_mode == 'nearest' else False
            modules.append(
                torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
        stride = 1
        self.padding = kernel_size // 2
        # The conv drops its own bias when a biased activation
        # (FusedLeakyReLU) follows, since that layer supplies the bias.
        modules.append(
            EqualConv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding,
                bias=bias and not activate))
        if activate:
            modules.append(FusedLeakyReLU(out_channels) if bias else ScaledLeakyReLU(0.2))

        super(ConvLayer, self).__init__(*modules)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        interpolation_mode (str): Interpolation mode for the downsampling
            convolutions. Default: 'bilinear'.
    """

    def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
        super(ResBlock, self).__init__()

        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvLayer(
            in_channels,
            out_channels,
            3,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=True,
            activate=True)
        # 1x1 shortcut without bias/activation, downsampled to match conv2.
        self.skip = ConvLayer(
            in_channels,
            out_channels,
            1,
            downsample=True,
            interpolation_mode=interpolation_mode,
            bias=False,
            activate=False)

    def forward(self, x):
        identity = self.skip(x)
        out = self.conv2(self.conv1(x))
        # Scale the sum by 1/sqrt(2) to keep the output variance comparable.
        return (out + identity) / math.sqrt(2)
|
||||
@@ -0,0 +1,368 @@
|
||||
import math
|
||||
import random
|
||||
import torch
|
||||
from basicsr.archs.arch_util import default_init_weights
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class NormStyleCode(nn.Module):

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        # Pixel-norm: divide by the RMS over the channel dimension;
        # 1e-8 guards against division by zero.
        rms_inv = torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
        return x * rms_inv
|
||||
|
||||
|
||||
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # modulation inside each modulated conv: maps the style vector to a
        # per-input-channel scale factor
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        # initialization (bias_fill=1 so modulation starts near identity scaling)
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        # weight pre-scaled by 1/sqrt(fan_in); leading dim of 1 broadcasts
        # against the per-sample style factors in forward()
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
            math.sqrt(in_channels * kernel_size**2))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            # rescale each output filter so the output variance stays ~1
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        # upsample or downsample if necessary
        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        b, c, h, w = x.shape
        # fold the batch into the channel dim and use grouped convolution so
        # each sample is convolved with its own modulated weight
        x = x.view(1, b * c, h, w)
        # weight: (b*c_out, c_in, k, k), groups=b
        out = F.conv2d(x, weight, padding=self.padding, groups=b)
        out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
||||
|
||||
|
||||
class StyleConv(nn.Module):
    """Style conv used in StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
        # Learnable scalar controlling the strength of the injected noise.
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        # modulated convolution, rescaled by sqrt(2) (for conversion)
        feat = self.modulated_conv(x, style) * 2**0.5
        # noise injection: draw per-sample single-channel noise if not given
        if noise is None:
            b, _, h, w = feat.shape
            noise = feat.new_empty(b, 1, h, w).normal_()
        feat = feat + self.weight * noise
        # add bias, then activate
        feat = feat + self.bias
        return self.activate(feat)
|
||||
|
||||
|
||||
class ToRGB(nn.Module):
    """To RGB (image space) from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample the skip branch. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super(ToRGB, self).__init__()
        self.upsample = upsample
        # 1x1 modulated conv projecting features to 3 RGB channels,
        # without demodulation.
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is not None:
            # Bring the accumulated skip image to the current resolution
            # before adding.
            if self.upsample:
                skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
            rgb = rgb + skip
        return rgb
|
||||
|
||||
|
||||
class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        # One learned feature map, shared by every sample in a batch.
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        # Replicate the learned map along the batch dimension.
        return self.weight.repeat(batch, 1, 1, 1)
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers: map style codes z to latent codes w
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [nn.Linear(num_style_feat, num_style_feat, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        # channel list: feature width per resolution, scaled by `narrow`
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        # learned constant 4x4 input and the first conv/to-rgb pair
        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)

        # log2 of output size; derived layer/latent counts
        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise: one fixed buffer per style conv layer (used when
        # randomize_noise=False in forward)
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs: one upsampling conv + one plain conv +
        # one ToRGB per resolution from 8x8 up to out_size
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        # two noise maps per resolution, matching the two style convs
        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        # Map a style code z to the latent space w via the style MLP.
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        # Average latent of `num_latent` random samples (used for truncation).
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple(Tensor, Tensor | None): Generated image and, if
            `return_latents`, the per-layer latent codes (else None).
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: interpolate towards the mean (truncation) latent
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation: start from the constant input, then alternate
        # (upsample conv, plain conv, to_rgb) per resolution
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
||||
@@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
# automatically scan and import dataset modules for registry
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules (importing registers each dataset class,
# e.g. via its DATASET_REGISTRY decorator, as a side effect)
_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
|
||||
@@ -0,0 +1,230 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
from basicsr.data import degradations as degradations
|
||||
from basicsr.data.data_util import paths_from_folder
|
||||
from basicsr.data.transforms import augment
|
||||
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
||||
from basicsr.utils.registry import DATASET_REGISTRY
|
||||
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
|
||||
normalize)
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
    """FFHQ dataset for GFPGAN.

    It reads high resolution images, and then generate low-quality (LQ) images on-the-fly.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(FFHQDegradationDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__ so it is
        # instantiated per worker process
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']
        self.out_size = opt['out_size']

        self.crop_components = opt.get('crop_components', False)  # facial components
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # whether enlarge eye regions

        if self.crop_components:
            # load component list from a pre-process pth files
            self.components_list = torch.load(opt.get('component_path'))

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            # lmdb keys are the file stems listed in meta_info.txt
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend: scan file list from a folder
            self.paths = paths_from_folder(self.gt_folder)

        # degradation configurations
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.downsample_range = opt['downsample_range']
        self.noise_range = opt['noise_range']
        self.jpeg_range = opt['jpeg_range']

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob')
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        # to gray
        self.gray_prob = opt.get('gray_prob')

        logger = get_root_logger()
        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        # shift is configured in [0, 255] units; images here are in [0, 1]
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
        # apply the four adjustments in a random order
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_coordinates(self, index, status):
        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
        components_bbox = self.components_list[f'{index:08d}']
        if status[0]:  # hflip
            # exchange right and left eye
            tmp = components_bbox['left_eye']
            components_bbox['left_eye'] = components_bbox['right_eye']
            components_bbox['right_eye'] = tmp
            # modify the width coordinate (mirror around the image width)
            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]

        # get coordinates: convert (center, half_len) to a (x1, y1, x2, y2) box
        locations = []
        for part in ['left_eye', 'right_eye', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations.append(loc)
        return locations

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
        h, w, _ = img_gt.shape

        # get facial component coordinates
        if self.crop_components:
            locations = self.get_component_coordinates(index, status)
            loc_left_eye, loc_right_eye, loc_mouth = locations

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = degradations.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)

        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
            if self.opt.get('gt_gray'):  # whether convert GT to gray images
                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)

        # round and clip, simulating 8-bit quantization of the LQ image
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        normalize(img_lq, self.mean, self.std, inplace=True)

        if self.crop_components:
            return_dict = {
                'lq': img_lq,
                'gt': img_gt,
                'gt_path': gt_path,
                'loc_left_eye': loc_left_eye,
                'loc_right_eye': loc_right_eye,
                'loc_mouth': loc_mouth
            }
            return return_dict
        else:
            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
|
||||
@@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
from basicsr.utils import scandir
|
||||
from os import path as osp
|
||||
|
||||
# automatically scan and import model modules for registry
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules (importing registers each model class,
# e.g. via its MODEL_REGISTRY decorator, as a side effect)
_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
|
||||
@@ -0,0 +1,579 @@
|
||||
import math
|
||||
import os.path as osp
|
||||
import torch
|
||||
from basicsr.archs import build_network
|
||||
from basicsr.losses import build_loss
|
||||
from basicsr.losses.losses import r1_penalty
|
||||
from basicsr.metrics import calculate_metric
|
||||
from basicsr.models.base_model import BaseModel
|
||||
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
||||
from basicsr.utils.registry import MODEL_REGISTRY
|
||||
from collections import OrderedDict
|
||||
from torch.nn import functional as F
|
||||
from torchvision.ops import roi_align
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@MODEL_REGISTRY.register()
|
||||
class GFPGANModel(BaseModel):
|
||||
"""The GFPGAN model for Towards real-world blind face restoratin with generative facial prior"""
|
||||
|
||||
def __init__(self, opt):
|
||||
super(GFPGANModel, self).__init__(opt)
|
||||
self.idx = 0 # it is used for saving data for check
|
||||
|
||||
# define network
|
||||
self.net_g = build_network(opt['network_g'])
|
||||
self.net_g = self.model_to_device(self.net_g)
|
||||
self.print_network(self.net_g)
|
||||
|
||||
# load pretrained model
|
||||
load_path = self.opt['path'].get('pretrain_network_g', None)
|
||||
if load_path is not None:
|
||||
param_key = self.opt['path'].get('param_key_g', 'params')
|
||||
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
|
||||
|
||||
self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
|
||||
|
||||
if self.is_train:
|
||||
self.init_training_settings()
|
||||
|
||||
def init_training_settings(self):
    """Set up all training components.

    Builds the global discriminator, the EMA copy of the generator, the
    optional facial-component discriminators, all losses, regularization
    weights, and finally optimizers and schedulers.
    """
    train_opt = self.opt['train']

    # ----------- define net_d ----------- #
    self.net_d = build_network(self.opt['network_d'])
    self.net_d = self.model_to_device(self.net_d)
    self.print_network(self.net_d)
    # load pretrained model
    load_path = self.opt['path'].get('pretrain_network_d', None)
    if load_path is not None:
        self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))

    # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
    # net_g_ema only used for testing on one GPU and saving. There is no need to wrap with DistributedDataParallel
    self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
    # load pretrained model
    load_path = self.opt['path'].get('pretrain_network_g', None)
    if load_path is not None:
        self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
    else:
        self.model_ema(0)  # copy net_g weight

    self.net_g.train()
    self.net_d.train()
    self.net_g_ema.eval()

    # ----------- facial component networks ----------- #
    # enabled only when all three component discriminators are configured
    if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
        self.use_facial_disc = True
    else:
        self.use_facial_disc = False

    if self.use_facial_disc:
        # left eye
        self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
        self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
        self.print_network(self.net_d_left_eye)
        load_path = self.opt['path'].get('pretrain_network_d_left_eye')
        if load_path is not None:
            self.load_network(self.net_d_left_eye, load_path, True, 'params')
        # right eye
        self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
        self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
        self.print_network(self.net_d_right_eye)
        load_path = self.opt['path'].get('pretrain_network_d_right_eye')
        if load_path is not None:
            self.load_network(self.net_d_right_eye, load_path, True, 'params')
        # mouth
        self.net_d_mouth = build_network(self.opt['network_d_mouth'])
        self.net_d_mouth = self.model_to_device(self.net_d_mouth)
        self.print_network(self.net_d_mouth)
        load_path = self.opt['path'].get('pretrain_network_d_mouth')
        if load_path is not None:
            self.load_network(self.net_d_mouth, load_path, True, 'params')

        self.net_d_left_eye.train()
        self.net_d_right_eye.train()
        self.net_d_mouth.train()

        # ----------- define facial component gan loss ----------- #
        self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)

    # ----------- define losses ----------- #
    # pixel loss
    if train_opt.get('pixel_opt'):
        self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
    else:
        self.cri_pix = None

    # perceptual loss
    if train_opt.get('perceptual_opt'):
        self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
    else:
        self.cri_perceptual = None

    # L1 loss is used in pyramid loss, component style loss and identity loss
    self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)

    # gan loss (wgan)
    self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

    # ----------- define identity loss ----------- #
    if 'network_identity' in self.opt:
        self.use_identity = True
    else:
        self.use_identity = False

    if self.use_identity:
        # define identity network (kept frozen; used only as a feature extractor)
        self.network_identity = build_network(self.opt['network_identity'])
        self.network_identity = self.model_to_device(self.network_identity)
        self.print_network(self.network_identity)
        load_path = self.opt['path'].get('pretrain_network_identity')
        if load_path is not None:
            self.load_network(self.network_identity, load_path, True, None)
        self.network_identity.eval()
        for param in self.network_identity.parameters():
            param.requires_grad = False

    # regularization weights
    self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
    self.net_d_iters = train_opt.get('net_d_iters', 1)
    self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
    self.net_d_reg_every = train_opt['net_d_reg_every']

    # set up optimizers and schedulers
    self.setup_optimizers()
    self.setup_schedulers()
|
||||
|
||||
def setup_optimizers(self):
    """Create optimizers for net_g, net_d and (optionally) the facial
    component discriminators, and register them in `self.optimizers`.

    NOTE(review): `train_opt['optim_*'].pop('type')` mutates the config
    dict in place, so this method is not safely re-callable on the same
    opt object — confirm this is only invoked once per training run.
    """
    train_opt = self.opt['train']

    # ----------- optimizer g ----------- #
    net_g_reg_ratio = 1  # no lr/beta rescaling for the generator
    normal_params = []
    for _, param in self.net_g.named_parameters():
        normal_params.append(param)
    optim_params_g = [{  # add normal params first
        'params': normal_params,
        'lr': train_opt['optim_g']['lr']
    }]
    optim_type = train_opt['optim_g'].pop('type')
    lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
    # with ratio 1 this is betas = (0, 0.99)
    betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
    self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
    self.optimizers.append(self.optimizer_g)

    # ----------- optimizer d ----------- #
    # rescale lr/betas because the R1 penalty is only applied every
    # `net_d_reg_every` iterations (see optimize_parameters)
    net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
    normal_params = []
    for _, param in self.net_d.named_parameters():
        normal_params.append(param)
    optim_params_d = [{  # add normal params first
        'params': normal_params,
        'lr': train_opt['optim_d']['lr']
    }]
    optim_type = train_opt['optim_d'].pop('type')
    lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
    betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
    self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
    self.optimizers.append(self.optimizer_d)

    # ----------- optimizers for facial component networks ----------- #
    if self.use_facial_disc:
        # setup optimizers for facial component discriminators
        optim_type = train_opt['optim_component'].pop('type')
        lr = train_opt['optim_component']['lr']
        # left eye
        self.optimizer_d_left_eye = self.get_optimizer(
            optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
        self.optimizers.append(self.optimizer_d_left_eye)
        # right eye
        self.optimizer_d_right_eye = self.get_optimizer(
            optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
        self.optimizers.append(self.optimizer_d_right_eye)
        # mouth
        self.optimizer_d_mouth = self.get_optimizer(
            optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
        self.optimizers.append(self.optimizer_d_mouth)
|
||||
|
||||
def feed_data(self, data):
    """Store a batch on the model.

    Moves `lq` (and `gt` when present) onto `self.device`. Facial
    component bounding boxes, when present, are stored as-is on their
    original device.
    """
    self.lq = data['lq'].to(self.device)
    if 'gt' in data:
        self.gt = data['gt'].to(self.device)

    if 'loc_left_eye' in data:
        # facial component locations, each of shape (batch, 4)
        for attr, key in (('loc_left_eyes', 'loc_left_eye'),
                          ('loc_right_eyes', 'loc_right_eye'),
                          ('loc_mouths', 'loc_mouth')):
            setattr(self, attr, data[key])
|
||||
|
||||
def construct_img_pyramid(self):
    """Construct an image pyramid of ``self.gt`` for the intermediate
    restoration (pyramid) loss.

    Returns:
        list[Tensor]: images ordered from smallest to largest, the last
        entry being the full-resolution ground truth.
    """
    downscaled = []
    current = self.gt
    for _ in range(self.log_size - 3):
        # halve the spatial resolution at each level
        current = F.interpolate(current, scale_factor=0.5, mode='bilinear', align_corners=False)
        downscaled.append(current)
    return list(reversed(downscaled)) + [self.gt]
|
||||
|
||||
def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
    """Crop eye/mouth patches from GT and generator output via roi_align.

    Uses the per-sample component boxes stored by `feed_data` and stores
    the crops on:
        self.left_eyes_gt / self.right_eyes_gt / self.mouths_gt  (from self.gt)
        self.left_eyes / self.right_eyes / self.mouths           (from self.output)

    Args:
        eye_out_size (int): roi_align output size for eyes at 512 scale.
        mouth_out_size (int): roi_align output size for mouths at 512 scale.
    """
    # scale the crop sizes with the generator output size (boxes are
    # presumably defined for 512x512 faces — confirm against the dataset)
    face_ratio = int(self.opt['network_g']['out_size'] / 512)
    eye_out_size *= face_ratio
    mouth_out_size *= face_ratio

    rois_eyes = []
    rois_mouths = []
    for b in range(self.loc_left_eyes.size(0)):  # loop for batch size
        # left eye and right eye; each ROI row is (batch_index, box)
        img_inds = self.loc_left_eyes.new_full((2, 1), b)
        bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0)  # shape: (2, 4)
        rois = torch.cat([img_inds, bbox], dim=-1)  # shape: (2, 5)
        rois_eyes.append(rois)
        # mouth
        img_inds = self.loc_left_eyes.new_full((1, 1), b)
        rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1)  # shape: (1, 5)
        rois_mouths.append(rois)

    rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
    rois_mouths = torch.cat(rois_mouths, 0).to(self.device)

    # real images; eye rows alternate left/right per sample, hence the
    # 0::2 / 1::2 de-interleaving below
    # NOTE(review): '* face_ratio' scales pixel values, not coordinates —
    # a no-op when out_size == 512; confirm this is intended otherwise.
    all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
    self.left_eyes_gt = all_eyes[0::2, :, :, :]
    self.right_eyes_gt = all_eyes[1::2, :, :, :]
    self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
    # output
    all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
    self.left_eyes = all_eyes[0::2, :, :, :]
    self.right_eyes = all_eyes[1::2, :, :, :]
    self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
|
||||
|
||||
def _gram_mat(self, x):
    """Compute the normalized Gram matrix of a feature map.

    Args:
        x (torch.Tensor): Features with shape (n, c, h, w).

    Returns:
        torch.Tensor: Gram matrices with shape (n, c, c), normalized by
        c * h * w.
    """
    batch, channels, height, width = x.size()
    flat = x.reshape(batch, channels, height * width)
    return flat @ flat.transpose(1, 2) / (channels * height * width)
|
||||
|
||||
def gray_resize_for_identity(self, out, size=128):
    """Convert an RGB batch to grayscale and resize it for the identity
    feature extractor.

    Args:
        out (torch.Tensor): RGB images, shape (n, 3, h, w).
        size (int): Target spatial size. Default: 128.

    Returns:
        torch.Tensor: Grayscale images of shape (n, 1, size, size).
    """
    # ITU-R BT.601 luma weights
    red, green, blue = out[:, 0, :, :], out[:, 1, :, :], out[:, 2, :, :]
    gray = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    gray = gray.unsqueeze(1)
    return F.interpolate(gray, (size, size), mode='bilinear', align_corners=False)
|
||||
|
||||
def optimize_parameters(self, current_iter):
    """Run one training iteration.

    Order: generator update (pixel / pyramid / perceptual / GAN /
    facial-component / identity losses), EMA update of net_g, then the
    discriminator update(s) with lazy R1 regularization every
    `net_d_reg_every` iterations.

    Args:
        current_iter (int): Global iteration counter; gates the net_d
            warm-up, pyramid-loss removal and R1 frequency.
    """
    # optimize net_g
    for p in self.net_d.parameters():
        p.requires_grad = False
    self.optimizer_g.zero_grad()

    # do not update facial component net_d
    if self.use_facial_disc:
        for p in self.net_d_left_eye.parameters():
            p.requires_grad = False
        for p in self.net_d_right_eye.parameters():
            p.requires_grad = False
        for p in self.net_d_mouth.parameters():
            p.requires_grad = False

    # image pyramid loss weight
    pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0)
    if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')):
        pyramid_loss_weight = 1e-12  # very small weight to avoid unused param error
    if pyramid_loss_weight > 0:
        self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
        pyramid_gt = self.construct_img_pyramid()
    else:
        self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)

    # get roi-align regions
    if self.use_facial_disc:
        self.get_roi_regions(eye_out_size=80, mouth_out_size=120)

    l_g_total = 0
    loss_dict = OrderedDict()
    # generator losses are applied only after the net_d warm-up period and
    # once every `net_d_iters` iterations
    if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
        # pixel loss
        if self.cri_pix:
            l_g_pix = self.cri_pix(self.output, self.gt)
            l_g_total += l_g_pix
            loss_dict['l_g_pix'] = l_g_pix

        # image pyramid loss
        if pyramid_loss_weight > 0:
            for i in range(0, self.log_size - 2):
                l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
                l_g_total += l_pyramid
                loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid

        # perceptual loss
        if self.cri_perceptual:
            l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
            if l_g_percep is not None:
                l_g_total += l_g_percep
                loss_dict['l_g_percep'] = l_g_percep
            if l_g_style is not None:
                l_g_total += l_g_style
                loss_dict['l_g_style'] = l_g_style

        # gan loss
        fake_g_pred = self.net_d(self.output)
        l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
        l_g_total += l_g_gan
        loss_dict['l_g_gan'] = l_g_gan

        # facial component loss
        if self.use_facial_disc:
            # left eye
            fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
            l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan_left_eye'] = l_g_gan
            # right eye
            fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
            l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan_right_eye'] = l_g_gan
            # mouth
            fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
            l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan_mouth'] = l_g_gan

            if self.opt['train'].get('comp_style_weight', 0) > 0:
                # get gt feat
                _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
                _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
                _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)

                def _comp_style(feat, feat_gt, criterion):
                    # style loss on Gram matrices of the first two feature levels
                    return criterion(self._gram_mat(feat[0]), self._gram_mat(
                        feat_gt[0].detach())) * 0.5 + criterion(
                            self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))

                # facial component style loss
                comp_style_loss = 0
                comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
                comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
                comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
                comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
                l_g_total += comp_style_loss
                loss_dict['l_g_comp_style_loss'] = comp_style_loss

        # identity loss
        if self.use_identity:
            identity_weight = self.opt['train']['identity_weight']
            # get gray images and resize
            out_gray = self.gray_resize_for_identity(self.output)
            gt_gray = self.gray_resize_for_identity(self.gt)

            identity_gt = self.network_identity(gt_gray).detach()
            identity_out = self.network_identity(out_gray)
            l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
            l_g_total += l_identity
            loss_dict['l_identity'] = l_identity

        l_g_total.backward()
        self.optimizer_g.step()

    # EMA
    self.model_ema(decay=0.5**(32 / (10 * 1000)))

    # ----------- optimize net_d ----------- #
    for p in self.net_d.parameters():
        p.requires_grad = True
    self.optimizer_d.zero_grad()
    if self.use_facial_disc:
        for p in self.net_d_left_eye.parameters():
            p.requires_grad = True
        for p in self.net_d_right_eye.parameters():
            p.requires_grad = True
        for p in self.net_d_mouth.parameters():
            p.requires_grad = True
        self.optimizer_d_left_eye.zero_grad()
        self.optimizer_d_right_eye.zero_grad()
        self.optimizer_d_mouth.zero_grad()

    fake_d_pred = self.net_d(self.output.detach())
    real_d_pred = self.net_d(self.gt)
    l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
    loss_dict['l_d'] = l_d
    # In WGAN, real_score should be positive and fake_score should be negative
    loss_dict['real_score'] = real_d_pred.detach().mean()
    loss_dict['fake_score'] = fake_d_pred.detach().mean()
    l_d.backward()

    # regularization loss: R1 penalty, applied lazily every
    # `net_d_reg_every` iterations
    if current_iter % self.net_d_reg_every == 0:
        self.gt.requires_grad = True
        real_pred = self.net_d(self.gt)
        l_d_r1 = r1_penalty(real_pred, self.gt)
        # '+ 0 * real_pred[0]' keeps real_pred attached to the graph —
        # presumably to avoid unused-parameter errors (cf. the pyramid
        # weight comment above); confirm
        l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
        loss_dict['l_d_r1'] = l_d_r1.detach().mean()
        l_d_r1.backward()

    self.optimizer_d.step()

    # optimize facial component discriminators
    if self.use_facial_disc:
        # NOTE(review): in the three component losses below the real term
        # uses cri_component while the fake term uses cri_gan — confirm
        # this asymmetry is intended.
        # left eye
        fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
        real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
        l_d_left_eye = self.cri_component(
            real_d_pred, True, is_disc=True) + self.cri_gan(
                fake_d_pred, False, is_disc=True)
        loss_dict['l_d_left_eye'] = l_d_left_eye
        l_d_left_eye.backward()
        # right eye
        fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
        real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
        l_d_right_eye = self.cri_component(
            real_d_pred, True, is_disc=True) + self.cri_gan(
                fake_d_pred, False, is_disc=True)
        loss_dict['l_d_right_eye'] = l_d_right_eye
        l_d_right_eye.backward()
        # mouth
        fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
        real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
        l_d_mouth = self.cri_component(
            real_d_pred, True, is_disc=True) + self.cri_gan(
                fake_d_pred, False, is_disc=True)
        loss_dict['l_d_mouth'] = l_d_mouth
        l_d_mouth.backward()

        self.optimizer_d_left_eye.step()
        self.optimizer_d_right_eye.step()
        self.optimizer_d_mouth.step()

    self.log_dict = self.reduce_loss_dict(loss_dict)
|
||||
|
||||
def test(self):
    """Run inference on ``self.lq`` without gradients, storing the result
    in ``self.output``.

    Prefers the EMA generator when available; otherwise falls back to
    ``self.net_g`` and restores its train mode afterwards.
    """
    with torch.no_grad():
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            self.output, _ = self.net_g_ema(self.lq)
        else:
            logger = get_root_logger()
            logger.warning('Do not have self.net_g_ema, use self.net_g.')
            self.net_g.eval()
            self.output, _ = self.net_g(self.lq)
            self.net_g.train()
|
||||
|
||||
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Distributed validation entry point: only rank 0 runs validation."""
    if self.opt['rank'] != 0:
        return
    self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
||||
|
||||
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Validate on a single process.

    Runs inference image by image, optionally saves visualizations, and
    accumulates metrics, averaging them at the end and tracking the best
    values per dataset.

    Args:
        dataloader: Validation dataloader; ``dataloader.dataset.opt``
            must provide 'name'.
        current_iter (int): Iteration used in logs and saved file names.
        tb_logger: TensorBoard logger or None.
        save_img (bool): Whether to write restored images to disk.
    """
    dataset_name = dataloader.dataset.opt['name']
    with_metrics = self.opt['val'].get('metrics') is not None
    use_pbar = self.opt['val'].get('pbar', False)

    if with_metrics:
        if not hasattr(self, 'metric_results'):  # only execute in the first run
            self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
        # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
        self._initialize_best_metric_results(dataset_name)
        # zero self.metric_results
        self.metric_results = {metric: 0 for metric in self.metric_results}

    metric_data = dict()
    if use_pbar:
        pbar = tqdm(total=len(dataloader), unit='image')

    for idx, val_data in enumerate(dataloader):
        img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
        self.feed_data(val_data)
        self.test()

        # tensors are in [-1, 1]; convert to uint8 images
        sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
        metric_data['img'] = sr_img
        if hasattr(self, 'gt'):
            gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
            metric_data['img2'] = gt_img
            del self.gt

        # tentative for out of GPU memory
        del self.lq
        del self.output
        torch.cuda.empty_cache()

        if save_img:
            if self.opt['is_train']:
                save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                         f'{img_name}_{current_iter}.png')
            else:
                if self.opt['val']['suffix']:
                    save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                             f'{img_name}_{self.opt["val"]["suffix"]}.png')
                else:
                    save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                             f'{img_name}_{self.opt["name"]}.png')
            imwrite(sr_img, save_img_path)

        if with_metrics:
            # calculate metrics
            for name, opt_ in self.opt['val']['metrics'].items():
                self.metric_results[name] += calculate_metric(metric_data, opt_)
        if use_pbar:
            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
    if use_pbar:
        pbar.close()

    if with_metrics:
        for metric in self.metric_results.keys():
            # average over the number of processed images
            self.metric_results[metric] /= (idx + 1)
            # update the best metric result
            self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)

        self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
||||
|
||||
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
    """Log validation metrics to the root logger and, when given, to
    TensorBoard (including best-so-far values if tracked)."""
    parts = [f'Validation {dataset_name}\n']
    for metric, value in self.metric_results.items():
        parts.append(f'\t # {metric}: {value:.4f}')
        if hasattr(self, 'best_metric_results'):
            best = self.best_metric_results[dataset_name][metric]
            parts.append(f'\tBest: {best["val"]:.4f} @ {best["iter"]} iter')
        parts.append('\n')

    get_root_logger().info(''.join(parts))

    if tb_logger:
        for metric, value in self.metric_results.items():
            tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
|
||||
|
||||
def save(self, epoch, current_iter):
    """Persist the generator (raw + EMA weights), all discriminators and
    the training state for this checkpoint."""
    # generator: raw and EMA weights under separate param keys
    self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
    self.save_network(self.net_d, 'net_d', current_iter)
    if self.use_facial_disc:
        # facial component discriminators
        for net, label in ((self.net_d_left_eye, 'net_d_left_eye'),
                           (self.net_d_right_eye, 'net_d_right_eye'),
                           (self.net_d_mouth, 'net_d_mouth')):
            self.save_network(net, label, current_iter)
    # save training state
    self.save_training_state(epoch, current_iter)
|
||||
@@ -0,0 +1,11 @@
|
||||
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline

# imported for their import-time side effects (presumably registering the
# gfpgan archs/data/models with the basicsr registries — confirm)
import gfpgan.archs
import gfpgan.data
import gfpgan.models

if __name__ == '__main__':
    # repository root: two directory levels above this file
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
|
||||
@@ -0,0 +1,143 @@
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from torchvision.transforms.functional import normalize
|
||||
|
||||
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
|
||||
from gfpgan.archs.gfpganv1_arch import GFPGANv1
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class GFPGANer():
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces will be pasted back to the upsampled background image.

    Args:
        model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Option: clean | bilinear | original. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
    """

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # initialize the GFP-GAN
        # NOTE(review): an unrecognized `arch` leaves self.gfpgan unset and
        # fails later with AttributeError — consider raising early.
        if arch == 'clean':
            self.gfpgan = GFPGANv1Clean(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        elif arch == 'bilinear':
            self.gfpgan = GFPGANBilinear(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=False,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        elif arch == 'original':
            self.gfpgan = GFPGANv1(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=True,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)
        # initialize face helper (detection, alignment, paste-back)
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            device=self.device)

        # download the checkpoint on demand when a URL is given
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
        # NOTE(review): torch.load without map_location may fail on
        # CPU-only hosts for GPU-saved checkpoints — consider
        # map_location=lambda storage, loc: storage.
        loadnet = torch.load(model_path)
        # prefer EMA weights when present
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
        """Restore faces in `img`.

        Args:
            img (numpy.ndarray): Input BGR image (uint8, 0-255 range).
            has_aligned (bool): Input is already a single aligned face.
            only_center_face (bool): Restore only the center face.
            paste_back (bool): Paste restored faces back onto the
                (optionally upsampled) background.

        Returns:
            tuple: (cropped_faces, restored_faces, restored_img);
            restored_img is None when paste_back is not performed.
        """
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data: BGR uint8 -> RGB float tensor normalized to [-1, 1]
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                # best-effort: fall back to the unrestored crop on failure
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
|
||||
@@ -0,0 +1,5 @@
|
||||
# GENERATED VERSION FILE
# TIME: Wed Mar 30 13:34:44 2022
__version__ = '1.3.2'  # human-readable package version
__gitsha__ = 'unknown'  # git commit hash at generation time ('unknown' here)
version_info = (1, 3, 2)  # numeric tuple mirroring __version__
|
||||
@@ -0,0 +1,3 @@
|
||||
# Weights
|
||||
|
||||
Put the downloaded weights in this folder.
|
||||
@@ -0,0 +1,164 @@
|
||||
import argparse
|
||||
import math
|
||||
import torch
|
||||
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
|
||||
|
||||
def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
    """Convert weights of the original (bilinear) GFPGAN model into the
    key layout and scaling expected by GFPGANv1Clean.

    The "clean" architecture replaces the custom equalized-lr ops
    (EqualLinear/EqualConv2d, FusedLeakyReLU) with plain layers, so the
    runtime scaling factors (1/sqrt(fan_in) * lr_mul and the sqrt(2)
    activation gain) must be baked into the stored tensors here.

    Args:
        checkpoint_bilinear (dict): state_dict of the original model.
        checkpoint_clean (dict): state_dict of GFPGANv1Clean; converted
            entries are written into it in place.

    Returns:
        dict: ``checkpoint_clean`` with the converted entries filled in.
            Keys that match no known pattern are silently skipped.
    """
    gain = 2**0.5  # FusedLeakyReLU gain folded into the weights
    for src_key, src_val in checkpoint_bilinear.items():
        if 'stylegan_decoder' in src_key:
            if 'style_mlp' in src_key:  # style MLP layers (EqualLinear, lr_mul=0.01)
                lr_mul = 0.01
                prefix, name, idx, var = src_key.split('.')
                # activation modules interleave in the clean model, shifting indices
                idx = (int(idx) * 2) - 1
                new_key = f'{prefix}.{name}.{idx}.{var}'
                if var == 'weight':
                    _, fan_in = src_val.size()
                    scale = (1 / math.sqrt(fan_in)) * lr_mul
                    new_val = src_val * scale * gain
                else:
                    new_val = src_val * lr_mul * gain
                checkpoint_clean[new_key] = new_val
            elif 'modulation' in src_key:  # modulation FC inside StyleConv
                lr_mul = 1
                var = src_key.split('.')[-1]
                if var == 'weight':
                    _, fan_in = src_val.size()
                    scale = (1 / math.sqrt(fan_in)) * lr_mul
                    new_val = src_val * scale
                else:
                    new_val = src_val * lr_mul
                checkpoint_clean[src_key] = new_val
            elif 'style_conv' in src_key:
                # StyleConv in style_conv1 and style_convs
                if 'activate' in src_key:  # FusedLeakyReLU bias
                    # e.g. style_conv1.activate.bias
                    # e.g. style_convs.13.activate.bias
                    parts = src_key.split('.')
                    if len(parts) == 4:
                        prefix, name, _, var = parts
                        new_key = f'{prefix}.{name}.{var}'
                    elif len(parts) == 5:
                        prefix, name, idx, _, var = parts
                        new_key = f'{prefix}.{name}.{idx}.{var}'
                    new_val = src_val * gain  # sqrt(2) used in FusedLeakyReLU
                    channels = new_val.size(0)
                    checkpoint_clean[new_key] = new_val.view(1, channels, 1, 1)
                elif 'modulated_conv' in src_key:
                    # e.g. style_conv1.modulated_conv.weight
                    # e.g. style_convs.13.modulated_conv.weight
                    _, c_out, c_in, k1, k2 = src_val.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    checkpoint_clean[src_key] = src_val * scale
                elif 'weight' in src_key:
                    checkpoint_clean[src_key] = src_val * gain
            elif 'to_rgb' in src_key:  # StyleConv in to_rgb1 and to_rgbs
                if 'modulated_conv' in src_key:
                    # e.g. to_rgb1.modulated_conv.weight
                    # e.g. to_rgbs.5.modulated_conv.weight
                    _, c_out, c_in, k1, k2 = src_val.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    checkpoint_clean[src_key] = src_val * scale
                else:
                    checkpoint_clean[src_key] = src_val
            else:
                checkpoint_clean[src_key] = src_val
        # end of 'stylegan_decoder'
        elif 'conv_body_first' in src_key or 'final_conv' in src_key:
            # key name: drop the inner module index
            name, _, var = src_key.split('.')
            new_key = f'{name}.{var}'
            # weight and bias
            if var == 'weight':
                c_out, c_in, k1, k2 = src_val.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[new_key] = src_val * scale * gain
            else:
                checkpoint_clean[new_key] = src_val * gain
        elif 'conv_body' in src_key:
            if 'conv_body_up' in src_key:
                # upsampling blocks wrap conv2/skip in a Sequential in the clean model
                src_key = src_key.replace('conv2.weight', 'conv2.1.weight')
                src_key = src_key.replace('skip.weight', 'skip.1.weight')
            name1, idx1, name2, _, var = src_key.split('.')
            new_key = f'{name1}.{idx1}.{name2}.{var}'
            if name2 == 'skip':
                c_out, c_in, k1, k2 = src_val.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[new_key] = src_val * scale / gain
            else:
                if var == 'weight':
                    c_out, c_in, k1, k2 = src_val.size()
                    scale = 1 / math.sqrt(c_in * k1 * k2)
                    checkpoint_clean[new_key] = src_val * scale
                else:
                    checkpoint_clean[new_key] = src_val
                if 'conv1' in src_key:
                    checkpoint_clean[new_key] *= gain
        elif 'toRGB' in src_key:
            if 'weight' in src_key:
                c_out, c_in, k1, k2 = src_val.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[src_key] = src_val * scale
            else:
                checkpoint_clean[src_key] = src_val
        elif 'final_linear' in src_key:
            if 'weight' in src_key:
                _, fan_in = src_val.size()
                scale = 1 / math.sqrt(fan_in)
                checkpoint_clean[src_key] = src_val * scale
            else:
                checkpoint_clean[src_key] = src_val
        elif 'condition' in src_key:
            # SFT condition branches: layer 0 feeds a LeakyReLU (gain applied),
            # layer 2 does not.
            if '0.weight' in src_key:
                c_out, c_in, k1, k2 = src_val.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[src_key] = src_val * scale * gain
            elif '0.bias' in src_key:
                checkpoint_clean[src_key] = src_val * gain
            elif '2.weight' in src_key:
                c_out, c_in, k1, k2 = src_val.size()
                scale = 1 / math.sqrt(c_in * k1 * k2)
                checkpoint_clean[src_key] = src_val * scale
            elif '2.bias' in src_key:
                checkpoint_clean[src_key] = src_val

    return checkpoint_clean
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # CLI entry point: convert an original (bilinear) GFPGAN checkpoint into
    # the layout expected by GFPGANv1Clean and save it back to disk.
    parser = argparse.ArgumentParser()
    parser.add_argument('--ori_path', type=str, help='Path to the original model')
    parser.add_argument('--narrow', type=float, default=1)
    parser.add_argument('--channel_multiplier', type=float, default=2)
    parser.add_argument('--save_path', type=str)
    args = parser.parse_args()

    # Load the EMA weights of the original model.
    ori_ckpt = torch.load(args.ori_path)['params_ema']

    # Build the target "clean" network with matching hyper-parameters; its
    # state_dict supplies the destination key layout for the conversion.
    net = GFPGANv1Clean(
        512,
        num_style_feat=512,
        channel_multiplier=args.channel_multiplier,
        decoder_load_path=None,
        fix_decoder=False,
        # for stylegan decoder
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=args.narrow,
        sft_half=True)
    crt_ckpt = net.state_dict()

    # Fill the clean state_dict with converted weights and save.
    # Legacy (non-zipfile) serialization keeps the file loadable by older
    # torch versions.
    crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)
    print(f'Save to {args.save_path}.')
    torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)
|
||||
@@ -0,0 +1,85 @@
|
||||
import cv2
import json
import numpy as np
import os
import torch
from basicsr.utils import FileClient, imfrombytes
from collections import OrderedDict

# ---------------------------- This script is used to parse facial landmarks ------------------------------------- #
# For each FFHQ image it stores, per facial component (eyes and mouth), the
# component center (x, y) and a half box size, then saves the result as a
# torch .pth dict keyed by the 8-digit image index.

# Configurations
save_img = False  # if True, also dump the cropped component patches for visual inspection
scale = 0.5  # 0.5 for official FFHQ (512x512), 1 for others
enlarge_ratio = 1.4  # only for eyes
json_path = 'ffhq-dataset-v2.json'
face_path = 'datasets/ffhq/ffhq_512.lmdb'
save_path = './FFHQ_eye_mouth_landmarks_512.pth'

print('Load JSON metadata...')
# use the official json file in FFHQ dataset
with open(json_path, 'rb') as f:
    json_data = json.load(f, object_pairs_hook=OrderedDict)

print('Open LMDB file...')
# read ffhq images; meta_info.txt lists the LMDB keys (one file name per line)
file_client = FileClient('lmdb', db_paths=face_path)
with open(os.path.join(face_path, 'meta_info.txt')) as fin:
    paths = [line.split('.')[0] for line in fin]

save_dict = {}

for item_idx, item in enumerate(json_data.values()):
    print(f'\r{item_idx} / {len(json_data)}, {item["image"]["file_path"]} ', end='', flush=True)

    # parse landmarks and rescale them to the working image resolution
    lm = np.array(item['image']['face_landmarks'])
    lm = lm * scale

    item_dict = {}
    # get image (only needed when dumping crops)
    if save_img:
        img_bytes = file_client.get(paths[item_idx])
        img = imfrombytes(img_bytes, float32=True)

    # get landmarks for each component (standard 68-point landmark indices)
    map_left_eye = list(range(36, 42))
    map_right_eye = list(range(42, 48))
    map_mouth = list(range(48, 68))

    # eye_left: store center (x, y) and half box size (clamped to at least 16 px)
    mean_left_eye = np.mean(lm[map_left_eye], 0)  # (x, y)
    half_len_left_eye = np.max((np.max(np.max(lm[map_left_eye], 0) - np.min(lm[map_left_eye], 0)) / 2, 16))
    item_dict['left_eye'] = [mean_left_eye[0], mean_left_eye[1], half_len_left_eye]
    # mean_left_eye[0] = 512 - mean_left_eye[0]  # for testing flip
    half_len_left_eye *= enlarge_ratio  # enlarge only the crop, not the stored size
    loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1, mean_left_eye + half_len_left_eye)).astype(int)
    if save_img:
        eye_left_img = img[loc_left_eye[1]:loc_left_eye[3], loc_left_eye[0]:loc_left_eye[2], :]
        cv2.imwrite(f'tmp/{item_idx:08d}_eye_left.png', eye_left_img * 255)

    # eye_right: same procedure as the left eye
    mean_right_eye = np.mean(lm[map_right_eye], 0)
    half_len_right_eye = np.max((np.max(np.max(lm[map_right_eye], 0) - np.min(lm[map_right_eye], 0)) / 2, 16))
    item_dict['right_eye'] = [mean_right_eye[0], mean_right_eye[1], half_len_right_eye]
    # mean_right_eye[0] = 512 - mean_right_eye[0]  # # for testing flip
    half_len_right_eye *= enlarge_ratio
    loc_right_eye = np.hstack(
        (mean_right_eye - half_len_right_eye + 1, mean_right_eye + half_len_right_eye)).astype(int)
    if save_img:
        eye_right_img = img[loc_right_eye[1]:loc_right_eye[3], loc_right_eye[0]:loc_right_eye[2], :]
        cv2.imwrite(f'tmp/{item_idx:08d}_eye_right.png', eye_right_img * 255)

    # mouth: no enlarge_ratio applied (that is only for eyes)
    mean_mouth = np.mean(lm[map_mouth], 0)
    half_len_mouth = np.max((np.max(np.max(lm[map_mouth], 0) - np.min(lm[map_mouth], 0)) / 2, 16))
    item_dict['mouth'] = [mean_mouth[0], mean_mouth[1], half_len_mouth]
    # mean_mouth[0] = 512 - mean_mouth[0]  # for testing flip
    loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1, mean_mouth + half_len_mouth)).astype(int)
    if save_img:
        mouth_img = img[loc_mouth[1]:loc_mouth[3], loc_mouth[0]:loc_mouth[2], :]
        cv2.imwrite(f'tmp/{item_idx:08d}_mouth.png', mouth_img * 255)

    save_dict[f'{item_idx:08d}'] = item_dict

print('Save...')
torch.save(save_dict, save_path)
|
||||
@@ -0,0 +1,110 @@
|
||||
check_points/
|
||||
pretrain_models*
|
||||
test_dir_enhance_results/
|
||||
test_dir_align_results/
|
||||
test_unalign_results/
|
||||
tmp*
|
||||
local*
|
||||
*.pth
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
env/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# dotenv
|
||||
.env
|
||||
|
||||
# virtualenv
|
||||
.venv
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
@@ -0,0 +1,445 @@
|
||||
PSFR-GAN (c) by Chaofeng Chen
|
||||
|
||||
PSFR-GAN is licensed under a
|
||||
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
|
||||
|
||||
You should have received a copy of the license along with this
|
||||
work. If not, see <http://creativecommons.org/licenses/by-nc-sa/4.0/>.
|
||||
|
||||
Attribution-NonCommercial-ShareAlike 4.0 International
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
||||
does not provide legal services or legal advice. Distribution of
|
||||
Creative Commons public licenses does not create a lawyer-client or
|
||||
other relationship. Creative Commons makes its licenses and related
|
||||
information available on an "as-is" basis. Creative Commons gives no
|
||||
warranties regarding its licenses, any material licensed under their
|
||||
terms and conditions, or any related information. Creative Commons
|
||||
disclaims all liability for damages resulting from their use to the
|
||||
fullest extent possible.
|
||||
|
||||
Using Creative Commons Public Licenses
|
||||
|
||||
Creative Commons public licenses provide a standard set of terms and
|
||||
conditions that creators and other rights holders may use to share
|
||||
original works of authorship and other material subject to copyright
|
||||
and certain other rights specified in the public license below. The
|
||||
following considerations are for informational purposes only, are not
|
||||
exhaustive, and do not form part of our licenses.
|
||||
|
||||
Considerations for licensors: Our public licenses are
|
||||
intended for use by those authorized to give the public
|
||||
permission to use material in ways otherwise restricted by
|
||||
copyright and certain other rights. Our licenses are
|
||||
irrevocable. Licensors should read and understand the terms
|
||||
and conditions of the license they choose before applying it.
|
||||
Licensors should also secure all rights necessary before
|
||||
applying our licenses so that the public can reuse the
|
||||
material as expected. Licensors should clearly mark any
|
||||
material not subject to the license. This includes other CC-
|
||||
licensed material, or material used under an exception or
|
||||
limitation to copyright. More considerations for licensors:
|
||||
wiki.creativecommons.org/Considerations_for_licensors
|
||||
|
||||
Considerations for the public: By using one of our public
|
||||
licenses, a licensor grants the public permission to use the
|
||||
licensed material under specified terms and conditions. If
|
||||
the licensor's permission is not necessary for any reason--for
|
||||
example, because of any applicable exception or limitation to
|
||||
copyright--then that use is not regulated by the license. Our
|
||||
licenses grant only permissions under copyright and certain
|
||||
other rights that a licensor has authority to grant. Use of
|
||||
the licensed material may still be restricted for other
|
||||
reasons, including because others have copyright or other
|
||||
rights in the material. A licensor may make special requests,
|
||||
such as asking that all changes be marked or described.
|
||||
Although not required by our licenses, you are encouraged to
|
||||
respect those requests where reasonable. More_considerations
|
||||
for the public:
|
||||
wiki.creativecommons.org/Considerations_for_licensees
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
||||
Public License
|
||||
|
||||
By exercising the Licensed Rights (defined below), You accept and agree
|
||||
to be bound by the terms and conditions of this Creative Commons
|
||||
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
||||
("Public License"). To the extent this Public License may be
|
||||
interpreted as a contract, You are granted the Licensed Rights in
|
||||
consideration of Your acceptance of these terms and conditions, and the
|
||||
Licensor grants You such rights in consideration of benefits the
|
||||
Licensor receives from making the Licensed Material available under
|
||||
these terms and conditions.
|
||||
|
||||
|
||||
Section 1 -- Definitions.
|
||||
|
||||
a. Adapted Material means material subject to Copyright and Similar
|
||||
Rights that is derived from or based upon the Licensed Material
|
||||
and in which the Licensed Material is translated, altered,
|
||||
arranged, transformed, or otherwise modified in a manner requiring
|
||||
permission under the Copyright and Similar Rights held by the
|
||||
Licensor. For purposes of this Public License, where the Licensed
|
||||
Material is a musical work, performance, or sound recording,
|
||||
Adapted Material is always produced where the Licensed Material is
|
||||
synched in timed relation with a moving image.
|
||||
|
||||
b. Adapter's License means the license You apply to Your Copyright
|
||||
and Similar Rights in Your contributions to Adapted Material in
|
||||
accordance with the terms and conditions of this Public License.
|
||||
|
||||
c. BY-NC-SA Compatible License means a license listed at
|
||||
creativecommons.org/compatiblelicenses, approved by Creative
|
||||
Commons as essentially the equivalent of this Public License.
|
||||
|
||||
d. Copyright and Similar Rights means copyright and/or similar rights
|
||||
closely related to copyright including, without limitation,
|
||||
performance, broadcast, sound recording, and Sui Generis Database
|
||||
Rights, without regard to how the rights are labeled or
|
||||
categorized. For purposes of this Public License, the rights
|
||||
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
||||
Rights.
|
||||
|
||||
e. Effective Technological Measures means those measures that, in the
|
||||
absence of proper authority, may not be circumvented under laws
|
||||
fulfilling obligations under Article 11 of the WIPO Copyright
|
||||
Treaty adopted on December 20, 1996, and/or similar international
|
||||
agreements.
|
||||
|
||||
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
||||
any other exception or limitation to Copyright and Similar Rights
|
||||
that applies to Your use of the Licensed Material.
|
||||
|
||||
g. License Elements means the license attributes listed in the name
|
||||
of a Creative Commons Public License. The License Elements of this
|
||||
Public License are Attribution, NonCommercial, and ShareAlike.
|
||||
|
||||
h. Licensed Material means the artistic or literary work, database,
|
||||
or other material to which the Licensor applied this Public
|
||||
License.
|
||||
|
||||
i. Licensed Rights means the rights granted to You subject to the
|
||||
terms and conditions of this Public License, which are limited to
|
||||
all Copyright and Similar Rights that apply to Your use of the
|
||||
Licensed Material and that the Licensor has authority to license.
|
||||
|
||||
j. Licensor means the individual(s) or entity(ies) granting rights
|
||||
under this Public License.
|
||||
|
||||
k. NonCommercial means not primarily intended for or directed towards
|
||||
commercial advantage or monetary compensation. For purposes of
|
||||
this Public License, the exchange of the Licensed Material for
|
||||
other material subject to Copyright and Similar Rights by digital
|
||||
file-sharing or similar means is NonCommercial provided there is
|
||||
no payment of monetary compensation in connection with the
|
||||
exchange.
|
||||
|
||||
l. Share means to provide material to the public by any means or
|
||||
process that requires permission under the Licensed Rights, such
|
||||
as reproduction, public display, public performance, distribution,
|
||||
dissemination, communication, or importation, and to make material
|
||||
available to the public including in ways that members of the
|
||||
public may access the material from a place and at a time
|
||||
individually chosen by them.
|
||||
|
||||
m. Sui Generis Database Rights means rights other than copyright
|
||||
resulting from Directive 96/9/EC of the European Parliament and of
|
||||
the Council of 11 March 1996 on the legal protection of databases,
|
||||
as amended and/or succeeded, as well as other essentially
|
||||
equivalent rights anywhere in the world.
|
||||
|
||||
n. You means the individual or entity exercising the Licensed Rights
|
||||
under this Public License. Your has a corresponding meaning.
|
||||
|
||||
|
||||
Section 2 -- Scope.
|
||||
|
||||
a. License grant.
|
||||
|
||||
1. Subject to the terms and conditions of this Public License,
|
||||
the Licensor hereby grants You a worldwide, royalty-free,
|
||||
non-sublicensable, non-exclusive, irrevocable license to
|
||||
exercise the Licensed Rights in the Licensed Material to:
|
||||
|
||||
a. reproduce and Share the Licensed Material, in whole or
|
||||
in part, for NonCommercial purposes only; and
|
||||
|
||||
b. produce, reproduce, and Share Adapted Material for
|
||||
NonCommercial purposes only.
|
||||
|
||||
2. Exceptions and Limitations. For the avoidance of doubt, where
|
||||
Exceptions and Limitations apply to Your use, this Public
|
||||
License does not apply, and You do not need to comply with
|
||||
its terms and conditions.
|
||||
|
||||
3. Term. The term of this Public License is specified in Section
|
||||
6(a).
|
||||
|
||||
4. Media and formats; technical modifications allowed. The
|
||||
Licensor authorizes You to exercise the Licensed Rights in
|
||||
all media and formats whether now known or hereafter created,
|
||||
and to make technical modifications necessary to do so. The
|
||||
Licensor waives and/or agrees not to assert any right or
|
||||
authority to forbid You from making technical modifications
|
||||
necessary to exercise the Licensed Rights, including
|
||||
technical modifications necessary to circumvent Effective
|
||||
Technological Measures. For purposes of this Public License,
|
||||
simply making modifications authorized by this Section 2(a)
|
||||
(4) never produces Adapted Material.
|
||||
|
||||
5. Downstream recipients.
|
||||
|
||||
a. Offer from the Licensor -- Licensed Material. Every
|
||||
recipient of the Licensed Material automatically
|
||||
receives an offer from the Licensor to exercise the
|
||||
Licensed Rights under the terms and conditions of this
|
||||
Public License.
|
||||
|
||||
b. Additional offer from the Licensor -- Adapted Material.
|
||||
Every recipient of Adapted Material from You
|
||||
automatically receives an offer from the Licensor to
|
||||
exercise the Licensed Rights in the Adapted Material
|
||||
under the conditions of the Adapter's License You apply.
|
||||
|
||||
c. No downstream restrictions. You may not offer or impose
|
||||
any additional or different terms or conditions on, or
|
||||
apply any Effective Technological Measures to, the
|
||||
Licensed Material if doing so restricts exercise of the
|
||||
Licensed Rights by any recipient of the Licensed
|
||||
Material.
|
||||
|
||||
6. No endorsement. Nothing in this Public License constitutes or
|
||||
may be construed as permission to assert or imply that You
|
||||
are, or that Your use of the Licensed Material is, connected
|
||||
with, or sponsored, endorsed, or granted official status by,
|
||||
the Licensor or others designated to receive attribution as
|
||||
provided in Section 3(a)(1)(A)(i).
|
||||
|
||||
b. Other rights.
|
||||
|
||||
1. Moral rights, such as the right of integrity, are not
|
||||
licensed under this Public License, nor are publicity,
|
||||
privacy, and/or other similar personality rights; however, to
|
||||
the extent possible, the Licensor waives and/or agrees not to
|
||||
assert any such rights held by the Licensor to the limited
|
||||
extent necessary to allow You to exercise the Licensed
|
||||
Rights, but not otherwise.
|
||||
|
||||
2. Patent and trademark rights are not licensed under this
|
||||
Public License.
|
||||
|
||||
3. To the extent possible, the Licensor waives any right to
|
||||
collect royalties from You for the exercise of the Licensed
|
||||
Rights, whether directly or through a collecting society
|
||||
under any voluntary or waivable statutory or compulsory
|
||||
licensing scheme. In all other cases the Licensor expressly
|
||||
reserves any right to collect such royalties, including when
|
||||
the Licensed Material is used other than for NonCommercial
|
||||
purposes.
|
||||
|
||||
|
||||
Section 3 -- License Conditions.
|
||||
|
||||
Your exercise of the Licensed Rights is expressly made subject to the
|
||||
following conditions.
|
||||
|
||||
a. Attribution.
|
||||
|
||||
1. If You Share the Licensed Material (including in modified
|
||||
form), You must:
|
||||
|
||||
a. retain the following if it is supplied by the Licensor
|
||||
with the Licensed Material:
|
||||
|
||||
i. identification of the creator(s) of the Licensed
|
||||
Material and any others designated to receive
|
||||
attribution, in any reasonable manner requested by
|
||||
the Licensor (including by pseudonym if
|
||||
designated);
|
||||
|
||||
ii. a copyright notice;
|
||||
|
||||
iii. a notice that refers to this Public License;
|
||||
|
||||
iv. a notice that refers to the disclaimer of
|
||||
warranties;
|
||||
|
||||
v. a URI or hyperlink to the Licensed Material to the
|
||||
extent reasonably practicable;
|
||||
|
||||
b. indicate if You modified the Licensed Material and
|
||||
retain an indication of any previous modifications; and
|
||||
|
||||
c. indicate the Licensed Material is licensed under this
|
||||
Public License, and include the text of, or the URI or
|
||||
hyperlink to, this Public License.
|
||||
|
||||
2. You may satisfy the conditions in Section 3(a)(1) in any
|
||||
reasonable manner based on the medium, means, and context in
|
||||
which You Share the Licensed Material. For example, it may be
|
||||
reasonable to satisfy the conditions by providing a URI or
|
||||
hyperlink to a resource that includes the required
|
||||
information.
|
||||
3. If requested by the Licensor, You must remove any of the
|
||||
information required by Section 3(a)(1)(A) to the extent
|
||||
reasonably practicable.
|
||||
|
||||
b. ShareAlike.
|
||||
|
||||
In addition to the conditions in Section 3(a), if You Share
|
||||
Adapted Material You produce, the following conditions also apply.
|
||||
|
||||
1. The Adapter's License You apply must be a Creative Commons
|
||||
license with the same License Elements, this version or
|
||||
later, or a BY-NC-SA Compatible License.
|
||||
|
||||
2. You must include the text of, or the URI or hyperlink to, the
|
||||
Adapter's License You apply. You may satisfy this condition
|
||||
in any reasonable manner based on the medium, means, and
|
||||
context in which You Share Adapted Material.
|
||||
|
||||
3. You may not offer or impose any additional or different terms
|
||||
or conditions on, or apply any Effective Technological
|
||||
Measures to, Adapted Material that restrict exercise of the
|
||||
rights granted under the Adapter's License You apply.
|
||||
|
||||
|
||||
Section 4 -- Sui Generis Database Rights.
|
||||
|
||||
Where the Licensed Rights include Sui Generis Database Rights that
|
||||
apply to Your use of the Licensed Material:
|
||||
|
||||
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
||||
to extract, reuse, reproduce, and Share all or a substantial
|
||||
portion of the contents of the database for NonCommercial purposes
|
||||
only;
|
||||
|
||||
b. if You include all or a substantial portion of the database
|
||||
contents in a database in which You have Sui Generis Database
|
||||
Rights, then the database in which You have Sui Generis Database
|
||||
Rights (but not its individual contents) is Adapted Material,
|
||||
including for purposes of Section 3(b); and
|
||||
|
||||
c. You must comply with the conditions in Section 3(a) if You Share
|
||||
all or a substantial portion of the contents of the database.
|
||||
|
||||
For the avoidance of doubt, this Section 4 supplements and does not
|
||||
replace Your obligations under this Public License where the Licensed
|
||||
Rights include other Copyright and Similar Rights.
|
||||
|
||||
|
||||
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
||||
|
||||
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
||||
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
||||
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
||||
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
||||
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
||||
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
||||
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
||||
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
||||
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
||||
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
||||
|
||||
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
||||
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
||||
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
||||
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
||||
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
||||
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
||||
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
||||
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
||||
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
||||
|
||||
c. The disclaimer of warranties and limitation of liability provided
|
||||
above shall be interpreted in a manner that, to the extent
|
||||
possible, most closely approximates an absolute disclaimer and
|
||||
waiver of all liability.
|
||||
|
||||
|
||||
Section 6 -- Term and Termination.
|
||||
|
||||
a. This Public License applies for the term of the Copyright and
|
||||
Similar Rights licensed here. However, if You fail to comply with
|
||||
this Public License, then Your rights under this Public License
|
||||
terminate automatically.
|
||||
|
||||
b. Where Your right to use the Licensed Material has terminated under
|
||||
Section 6(a), it reinstates:
|
||||
|
||||
1. automatically as of the date the violation is cured, provided
|
||||
it is cured within 30 days of Your discovery of the
|
||||
violation; or
|
||||
|
||||
2. upon express reinstatement by the Licensor.
|
||||
|
||||
For the avoidance of doubt, this Section 6(b) does not affect any
|
||||
right the Licensor may have to seek remedies for Your violations
|
||||
of this Public License.
|
||||
|
||||
c. For the avoidance of doubt, the Licensor may also offer the
|
||||
Licensed Material under separate terms or conditions or stop
|
||||
distributing the Licensed Material at any time; however, doing so
|
||||
will not terminate this Public License.
|
||||
|
||||
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
||||
License.
|
||||
|
||||
|
||||
Section 7 -- Other Terms and Conditions.
|
||||
|
||||
a. The Licensor shall not be bound by any additional or different
|
||||
terms or conditions communicated by You unless expressly agreed.
|
||||
|
||||
b. Any arrangements, understandings, or agreements regarding the
|
||||
Licensed Material not stated herein are separate from and
|
||||
independent of the terms and conditions of this Public License.
|
||||
|
||||
|
||||
Section 8 -- Interpretation.
|
||||
|
||||
a. For the avoidance of doubt, this Public License does not, and
|
||||
shall not be interpreted to, reduce, limit, restrict, or impose
|
||||
conditions on any use of the Licensed Material that could lawfully
|
||||
be made without permission under this Public License.
|
||||
|
||||
b. To the extent possible, if any provision of this Public License is
|
||||
deemed unenforceable, it shall be automatically reformed to the
|
||||
minimum extent necessary to make it enforceable. If the provision
|
||||
cannot be reformed, it shall be severed from this Public License
|
||||
without affecting the enforceability of the remaining terms and
|
||||
conditions.
|
||||
|
||||
c. No term or condition of this Public License will be waived and no
|
||||
failure to comply consented to unless expressly agreed to by the
|
||||
Licensor.
|
||||
|
||||
d. Nothing in this Public License constitutes or may be interpreted
|
||||
as a limitation upon, or waiver of, any privileges and immunities
|
||||
that apply to the Licensor or You, including from the legal
|
||||
processes of any jurisdiction or authority.
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons is not a party to its public
|
||||
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
||||
its public licenses to material it publishes and in those instances
|
||||
will be considered the “Licensor.” The text of the Creative Commons
|
||||
public licenses is dedicated to the public domain under the CC0 Public
|
||||
Domain Dedication. Except for the limited purpose of indicating that
|
||||
material is shared under a Creative Commons public license or as
|
||||
otherwise permitted by the Creative Commons policies published at
|
||||
creativecommons.org/policies, Creative Commons does not authorize the
|
||||
use of the trademark "Creative Commons" or any other trademark or logo
|
||||
of Creative Commons without its prior written consent including,
|
||||
without limitation, in connection with any unauthorized modifications
|
||||
to any of its public licenses or any other arrangements,
|
||||
understandings, or agreements concerning use of licensed material. For
|
||||
the avoidance of doubt, this paragraph does not form part of the
|
||||
public licenses.
|
||||
|
||||
Creative Commons may be contacted at creativecommons.org.
|
||||
@@ -0,0 +1,123 @@
|
||||
# PSFR-GAN in PyTorch
|
||||
|
||||
[Progressive Semantic-Aware Style Transformation for Blind Face Restoration](https://arxiv.org/abs/2009.08709)
|
||||
[Chaofeng Chen](https://chaofengc.github.io), [Xiaoming Li](https://csxmli2016.github.io/), [Lingbo Yang](https://lotayou.github.io), [Xianhui Lin](https://dblp.org/pid/147/7708.html), [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang/), [Kwan-Yee K. Wong](https://i.cs.hku.hk/~kykwong/)
|
||||
|
||||

|
||||

|
||||
|
||||
### Changelog
|
||||
- **2021.04.26**: Add pytorch vgg19 model to GoogleDrive and remove `--distributed` option which causes training error.
|
||||
- **2021.03.22**: Update new model at 15 epoch (52.5k iterations).
|
||||
- **2021.03.19**: Add train codes for PSFRGAN and FPN.
|
||||
|
||||
## Prerequisites and Installation
|
||||
- Ubuntu 18.04
|
||||
- CUDA 10.1
|
||||
- Clone this repository
|
||||
```
|
||||
git clone https://github.com/chaofengc/PSFR-GAN.git
|
||||
cd PSFR-GAN
|
||||
```
|
||||
- Python 3.7, install required packages by `pip3 install -r requirements.txt`
|
||||
|
||||
## Quick Test
|
||||
|
||||
### Download Pretrain Models and Dataset
|
||||
Download the pretrained models from the following link and put them to `./pretrain_models`
|
||||
- [GoogleDrive](https://drive.google.com/drive/folders/1Ubejhxd2xd4fxGc_M_LWl3Ux6CgQd9rP?usp=sharing)
|
||||
- [BaiduNetDisk](https://pan.baidu.com/s/1cru3uUASEfGX6p6L0_7gWQ), extract code: `gj2r`
|
||||
|
||||
### Test single image
|
||||
Run the following script to enhance face(s) in single input
|
||||
```
|
||||
python test_enhance_single_unalign.py --test_img_path ./test_dir/test_hzgg.jpg --results_dir test_hzgg_results --gpus 1
|
||||
```
|
||||
|
||||
This script does the following things:
|
||||
- Crop and align all the faces from input image, stored at `results_dir/LQ_faces`
|
||||
- Parse these faces and then enhance them, results stored at `results_dir/ParseMaps` and `results_dir/HQ`
|
||||
- Paste the enhanced faces back to the original image `results_dir/hq_final.jpg`
|
||||
- You can use `--gpus` to specify how many GPUs to use, `<=0` means running on CPU. The program will use GPU with the most available memory. Set `CUDA_VISIBLE_DEVICE` to specify the GPU if you do not want automatic GPU selection.
|
||||
|
||||
### Test image folder
|
||||
To test multiple images, we first crop out all the faces and align them use the following script.
|
||||
```
|
||||
python align_and_crop_dir.py --src_dir test_dir --results_dir test_dir_align_results
|
||||
```
|
||||
|
||||
For images (*e.g.* `multiface_test.jpg`) contain multiple faces, the aligned faces will be stored as `multiface_test_{face_index}.jpg`
|
||||
And then parse the aligned faces and enhance them with
|
||||
```
|
||||
python test_enhance_dir_align.py --src_dir test_dir_align_results --results_dir test_dir_enhance_results
|
||||
```
|
||||
Results will be saved to three folders respectively: `results_dir/lq`, `results_dir/parse`, `results_dir/hq`.
|
||||
|
||||
### Additional test script
|
||||
|
||||
For your convenience, we also provide script to test multiple unaligned images and paste the enhance results back. **Note the paste back operation could be quite slow for large size images containing many faces (dlib takes time to detect faces in large image).**
|
||||
```
|
||||
python test_enhance_dir_unalign.py --src_dir test_dir --results_dir test_unalign_results
|
||||
```
|
||||
This script basically does the same thing as `test_enhance_single_unalign.py` for each image in `src_dir`
|
||||
|
||||
## Train the Model
|
||||
|
||||
### Data Preparation
|
||||
|
||||
- Download [FFHQ](https://github.com/NVlabs/ffhq-dataset) and put the images to `../datasets/FFHQ/imgs1024`
|
||||
- Download parsing masks (`512x512`) [HERE](https://drive.google.com/file/d/1eQwO8hKcaluyCnxuZAp0eJVOdgMi30uA/view?usp=sharing) generated by the pretrained FPN and put them to `../datasets/FFHQ/masks512`.
|
||||
|
||||
*Note: you may change `../datasets/FFHQ` to your own path. But images and masks must be stored under `your_own_path/imgs1024` and `your_own_path/masks512` respectively.*
|
||||
|
||||
### Train Script for PSFRGAN
|
||||
|
||||
Here is an example train script for PSFRGAN:
|
||||
|
||||
```
|
||||
python train.py --gpus 2 --model enhance --name PSFRGAN_v001 \
|
||||
--g_lr 0.0001 --d_lr 0.0004 --beta1 0.5 \
|
||||
--gan_mode 'hinge' --lambda_pix 10 --lambda_fm 10 --lambda_ss 1000 \
|
||||
--Dinput_nc 22 --D_num 3 --n_layers_D 4 \
|
||||
--batch_size 2 --dataset ffhq --dataroot ../datasets/FFHQ \
|
||||
--visual_freq 100 --print_freq 10 #--continue_train
|
||||
```
|
||||
- Please change the `--name` option for different experiments. Tensorboard records with the same name will be moved to `check_points/log_archive`, and the weight directory will only store weight history of latest experiment with the same name.
|
||||
- `--gpus` specify number of GPUs used to train. The script will use GPUs with more available memory first. To specify the GPU index, use `export CUDA_VISIBLE_DEVICES=your_gpu_ids` before the script.
|
||||
- Uncomment `--continue_train` to resume training. *Current codes do not resume the optimizer state.*
|
||||
- It needs at least **8GB** memory to train with **batch_size=1**.
|
||||
|
||||
### Scripts for FPN
|
||||
|
||||
You may also train your own FPN and generate masks for the HQ images by yourself with the following steps:
|
||||
|
||||
- Download [CelebAHQ-Mask](https://github.com/switchablenorms/CelebAMask-HQ) dataset. Generate `CelebAMask-HQ-mask` and `CelebAMask-HQ-mask-color` with the provided scripts in `CelebAMask-HQ/face_parsing/Data_preprocessing/`.
|
||||
- Train FPN with the following command
|
||||
```
|
||||
python train.py --gpus 1 --model parse --name FPN_v001 \
|
||||
--lr 0.0002 --batch_size 8 \
|
||||
--dataset celebahqmask --dataroot ../datasets/CelebAMask-HQ \
|
||||
--visual_freq 100 --print_freq 10 #--continue_train
|
||||
```
|
||||
- Generate parsing masks with your own FPN using the following command:
|
||||
```
|
||||
python generate_masks.py --save_masks_dir ../datasets/FFHQ/masks512 --batch_size 8 --parse_net_weight path/to/your/own/FPN
|
||||
```
|
||||
|
||||
## Citation
|
||||
```
|
||||
@inproceedings{ChenPSFRGAN,
|
||||
author = {Chen, Chaofeng and Li, Xiaoming and Lingbo, Yang and Lin, Xianhui and Zhang, Lei and Wong, KKY},
|
||||
title = {Progressive Semantic-Aware Style Transformation for Blind Face Restoration},
|
||||
Journal = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
year = {2021}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
<a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
This work is inspired by [SPADE](https://github.com/NVlabs/SPADE), and closed related to [DFDNet](https://github.com/csxmli2016/DFDNet) and [HiFaceGAN](https://github.com/Lotayou/Face-Renovation). Our codes largely benefit from [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
|
||||
@@ -0,0 +1,86 @@
|
||||
import dlib
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from skimage import transform as trans
|
||||
from skimage import io
|
||||
import argparse
|
||||
|
||||
|
||||
def get_points(img, detector, shape_predictor, size_threshold=999):
    """Detect faces in *img* and return their 5-point landmarks.

    Returns a list of (5, 2) numpy int arrays, one per accepted face, or
    None when nothing is detected or no face passes the size filter.
    """
    detections = detector(img, 1)
    if not detections:
        return None

    landmarks = []
    for detection in detections:
        # The CNN detector wraps its rectangle in an mmod_rectangle.
        if isinstance(detector, dlib.cnn_face_detection_model_v1):
            rect = detection.rect
        else:
            rect = detection
        # NOTE(review): one oversized face aborts the whole loop (break, not
        # continue) — preserved exactly from the original behavior.
        if rect.width() > size_threshold or rect.height() > size_threshold:
            break
        shape = shape_predictor(img, rect)
        points = [[shape.part(k).x, shape.part(k).y] for k in range(5)]
        landmarks.append(np.array(points))

    return landmarks if landmarks else None
|
||||
|
||||
def align_and_save(img, save_path, src_points, template_path, template_scale=1):
    """Warp each detected face in *img* onto the landmark template and save it.

    When several faces are present, an ``_{index}`` suffix is inserted before
    the extension so every face gets its own output file.
    """
    output_size = (512, 512)
    reference = np.load(template_path) / template_scale

    root, suffix = os.path.splitext(save_path)
    multi_face = len(src_points) > 1
    for face_idx, landmark in enumerate(src_points):
        similarity = trans.SimilarityTransform()
        similarity.estimate(landmark, reference)
        warp_matrix = similarity.params[:2, :]

        aligned = cv2.warpAffine(img, warp_matrix, output_size)
        target = '{}_{}{}'.format(root, face_idx, suffix) if multi_face else save_path
        dlib.save_image(aligned.astype(np.uint8), target)
        print('Saving image', target)
|
||||
|
||||
def align_and_save_dir(src_dir, save_dir, template_path='./pretrain_models/FFHQ_template.npy', template_scale=2, use_cnn_detector=True):
    """Crop and align every face found in the images under *src_dir*.

    Parameters:
        src_dir (str)           -- directory of input images
        save_dir (str)          -- directory where aligned faces are written
        template_path (str)     -- .npy file with the 5-point landmark template
        template_scale (float)  -- divisor applied to the template coordinates
        use_cnn_detector (bool) -- use dlib's CNN face detector instead of the
                                   faster but less robust HOG detector
    """
    # Fix: removed the unused local `out_size` — the output size is decided
    # inside align_and_save, not here.
    if use_cnn_detector:
        detector = dlib.cnn_face_detection_model_v1('./pretrain_models/mmod_human_face_detector.dat')
    else:
        detector = dlib.get_frontal_face_detector()
    sp = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat')

    for name in os.listdir(src_dir):
        img_path = os.path.join(src_dir, name)
        img = dlib.load_rgb_image(img_path)

        points = get_points(img, detector, sp)
        if points is not None:
            save_path = os.path.join(save_dir, name)
            align_and_save(img, save_path, points, template_path, template_scale)
        else:
            print('No face detected in', img_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: crop/align faces from --src_dir into --results_dir.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--src_dir', type=str, help='source directory containing images to crop and align.')
    arg_parser.add_argument('--results_dir', type=str, help='results directory to save the aligned faces.')
    arg_parser.add_argument('--not_use_cnn_detector', action='store_true', help='do not use cnn face detector in dlib.')
    cli = arg_parser.parse_args()

    source = cli.src_dir
    assert os.path.isdir(source), 'Source path should be a directory containing images'
    destination = cli.results_dir
    if not os.path.exists(destination):
        os.makedirs(destination, exist_ok=True)
    align_and_save_dir(source, destination, use_cnn_detector=not cli.not_use_cnn_detector)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
"""This package includes all the modules related to data loading and preprocessing
|
||||
|
||||
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
|
||||
You need to implement four functions:
|
||||
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
||||
-- <__len__>: return the size of dataset.
|
||||
-- <__getitem__>: get a data point from data loader.
|
||||
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
||||
|
||||
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
|
||||
See our template dataset class 'template_dataset.py' for more details.
|
||||
"""
|
||||
import importlib
|
||||
import torch.utils.data
|
||||
from data.base_dataset import BaseDataset
|
||||
|
||||
|
||||
def find_dataset_using_name(dataset_name):
    """Import "data/[dataset_name]_dataset.py" and return its dataset class.

    The module must define a subclass of BaseDataset whose name equals
    dataset_name + "dataset" (case-insensitive, underscores stripped).
    """
    module_name = "data." + dataset_name + "_dataset"
    module = importlib.import_module(module_name)

    target = dataset_name.replace('_', '') + 'dataset'
    found = None
    for attr_name, attr in module.__dict__.items():
        # Short-circuit keeps issubclass from seeing non-class attributes.
        if attr_name.lower() == target.lower() and issubclass(attr, BaseDataset):
            found = attr

    if found is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (module_name, target))

    return found
|
||||
|
||||
|
||||
def get_option_setter(dataset_name):
    """Look up the dataset class and return its <modify_commandline_options> hook."""
    dataset_cls = find_dataset_using_name(dataset_name)
    return dataset_cls.modify_commandline_options
|
||||
|
||||
|
||||
def create_dataset(opt):
    """Build the multi-threaded data loader described by *opt*.

    Wraps CustomDatasetDataLoader; this is the main interface between this
    package and 'train.py'/'test.py'.

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    loader = CustomDatasetDataLoader(opt)
    return loader.load_data()
|
||||
|
||||
|
||||
class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading."""

    def __init__(self, opt):
        """Create the dataset named by opt.dataset_name and wrap it in a DataLoader.

        Parameters:
            opt -- option object; must provide dataset_name, batch_size,
                   serial_batches, num_threads, isTrain and max_dataset_size.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_name)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        # Fix: `True if opt.isTrain else False` replaced with bool(opt.isTrain).
        # The ragged final batch is dropped only during training so batch
        # statistics stay consistent.
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads), drop_last=bool(opt.isTrain))

    def load_data(self):
        """Return self; kept for API compatibility with callers expecting a loader."""
        return self

    def __len__(self):
        """Return the dataset size, capped at opt.max_dataset_size."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches, stopping once max_dataset_size items have been served."""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
|
||||
@@ -0,0 +1,162 @@
|
||||
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
|
||||
|
||||
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
|
||||
"""
|
||||
import random
|
||||
import numpy as np
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import imgaug as ia
|
||||
import imgaug.augmenters as iaa
|
||||
|
||||
class BaseDataset(data.Dataset, ABC):
    """Abstract base class (ABC) for all datasets in this package.

    Subclasses must implement:
        -- <__init__>: initialize the class; call BaseDataset.__init__(self, opt) first.
        -- <__len__>: return the size of the dataset.
        -- <__getitem__>: return one data point.
        -- <modify_commandline_options>: (optionally) add dataset-specific options.
    """

    def __init__(self, opt):
        """Store the experiment options and the dataset root directory.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be
                a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add dataset-specific options or rewrite defaults for existing ones.

        Parameters:
            parser -- original option parser
            is_train (bool) -- True in the training phase; lets subclasses
                register train-only or test-only options.

        Returns:
            the (possibly modified) parser; the base implementation changes nothing.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- integer index into the dataset

        Returns:
            a dictionary of named values; usually the sample itself plus its
            metadata.
        """
        pass
|
||||
|
||||
|
||||
def get_params(opt, size):
    """Draw a random crop position and flip flag for an image of *size* (w, h).

    The dimensions are first mapped through the preprocessing mode so the
    crop offset is sampled from the post-resize image size.
    """
    width, height = size
    if opt.preprocess == 'resize_and_crop':
        width = height = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        # Height scales with the aspect ratio before width is overwritten.
        height = opt.load_size * height // width
        width = opt.load_size

    max_x = np.maximum(0, width - opt.crop_size)
    max_y = np.maximum(0, height - opt.crop_size)
    crop_pos = (random.randint(0, max_x), random.randint(0, max_y))

    return {'crop_pos': crop_pos, 'flip': random.random() > 0.5}
|
||||
|
||||
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Assemble the torchvision transform pipeline described by *opt*.

    Parameters:
        opt       -- options; preprocess / load_size / crop_size / no_flip are read
        params    -- optional dict from get_params(); makes crop/flip deterministic
        grayscale -- convert to luminance first and normalize a single channel
        method    -- PIL resampling filter used by every resize step
        convert   -- append ToTensor + Normalize at the end of the pipeline
    """
    steps = []
    if grayscale:
        from util import util
        steps.append(util.RGBtoY)

    if 'resize' in opt.preprocess:
        steps.append(transforms.Resize([opt.load_size, opt.load_size], method))
    elif 'scale_width' in opt.preprocess:
        steps.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    if 'crop' in opt.preprocess:
        if params is None:
            steps.append(transforms.RandomCrop(opt.crop_size))
        elif 'crop_size' in params:
            steps.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], params['crop_size'])))
        else:
            steps.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if opt.preprocess == 'none':
        # No resize requested: still pad/shrink to a multiple of 4 so the
        # network's down/upsampling round-trips cleanly.
        steps.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            steps.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            steps.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        steps.append(transforms.ToTensor())
        if grayscale:
            steps.append(transforms.Normalize((0.5,), (0.5,)))
        else:
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)
|
||||
|
||||
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize *img* so both dimensions are the nearest multiple of *base*."""
    width, height = img.size
    new_w = int(round(width / base) * base)
    new_h = int(round(height / base) * base)
    if (new_w, new_h) == (width, height):
        return img

    __print_size_warning(width, height, new_w, new_h)
    return img.resize((new_w, new_h), method)
|
||||
|
||||
|
||||
def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize *img* so its width equals *target_width*, preserving aspect ratio."""
    width, height = img.size
    if width == target_width:
        return img
    scaled_height = int(target_width * height / width)
    return img.resize((target_width, scaled_height), method)
|
||||
|
||||
|
||||
def __crop(img, pos, size):
    """Crop a size x size square at *pos*; return img unchanged if it already fits."""
    width, height = img.size
    left, top = pos
    if width > size or height > size:
        return img.crop((left, top, left + size, top + size))
    return img
|
||||
|
||||
|
||||
def __flip(img, flip):
    """Mirror *img* horizontally when *flip* is truthy."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
|
||||
|
||||
|
||||
def __print_size_warning(ow, oh, w, h):
    """Warn, once per process, that an image was resized to a multiple of 4."""
    if getattr(__print_size_warning, 'has_printed', False):
        return
    print("The image size needs to be a multiple of 4. "
          "The loaded image size was (%d, %d), so it was adjusted to "
          "(%d, %d). This adjustment will be done to all images "
          "whose sizes are not multiples of 4" % (ow, oh, w, h))
    __print_size_warning.has_printed = True
|
||||
@@ -0,0 +1,60 @@
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import imgaug as ia
|
||||
import imgaug.augmenters as iaa
|
||||
|
||||
from data.image_folder import make_dataset
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
from data.base_dataset import BaseDataset
|
||||
from utils.utils import onehot_parse_map
|
||||
|
||||
from data.ffhq_dataset import complex_imgaug, random_gray
|
||||
|
||||
class CelebAHQMaskDataset(BaseDataset):
    """CelebAMask-HQ dataset yielding (HR image, degraded LR image, parse-mask labels).

    Expects opt.dataroot to contain 'CelebA-HQ-img' (images) and
    'CelebAMask-HQ-mask' (integer label maps) whose sorted file names pair up.
    """

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.img_size = opt.Pimg_size   # size fed to the degradation pipeline
        self.lr_size = opt.Gin_size     # generator input size
        self.hr_size = opt.Gout_size    # generator output size
        # Fix: `True if opt.isTrain else False` replaced with bool().
        self.shuffle = bool(opt.isTrain)

        # Sorting both lists keeps image/mask pairs aligned by file name.
        self.img_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'CelebA-HQ-img')))
        self.mask_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'CelebAMask-HQ-mask')))

        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(self.img_dataset)

    def __getitem__(self, idx):
        """Return dict with HR/LR tensors, the HR path and the integer mask labels."""
        # Fix: removed the dead local `sample = {}` that was never used.
        img_path = self.img_dataset[idx]
        mask_path = self.mask_dataset[idx]
        hr_img = Image.open(img_path).convert('RGB')
        mask_img = Image.open(mask_path)

        hr_img = hr_img.resize((self.hr_size, self.hr_size))
        hr_img = random_gray(hr_img, p=0.3)  # occasionally train on grayscale faces
        scale_size = np.random.randint(32, 256)  # random downscale size for degradation
        lr_img = complex_imgaug(hr_img, self.img_size, scale_size)

        mask_img = mask_img.resize((self.hr_size, self.hr_size))
        mask_label = torch.tensor(np.array(mask_img)).long()

        hr_tensor = self.to_tensor(hr_img)
        lr_tensor = self.to_tensor(lr_img)

        return {'HR': hr_tensor, 'LR': lr_tensor, 'HR_paths': img_path, 'Mask': mask_label}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import imgaug as ia
|
||||
import imgaug.augmenters as iaa
|
||||
|
||||
from data.image_folder import make_dataset
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
from data.base_dataset import BaseDataset
|
||||
from utils.utils import onehot_parse_map
|
||||
|
||||
class FFHQDataset(BaseDataset):
    """FFHQ dataset yielding (HR image, degraded LR image, one-hot parsing masks).

    Expects opt.dataroot to contain 'imgs1024' (images) and 'masks512'
    (pre-computed parsing masks) whose sorted file names pair up.
    """

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.img_size = opt.Pimg_size   # size fed to the degradation pipeline
        self.lr_size = opt.Gin_size     # generator input size
        self.hr_size = opt.Gout_size    # generator output size
        # Fix: `True if opt.isTrain else False` replaced with bool().
        self.shuffle = bool(opt.isTrain)

        # Sorting both lists keeps image/mask pairs aligned by file name.
        self.img_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'imgs1024')))
        self.mask_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'masks512')))

        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # NOTE(review): random_crop is created but never used in __getitem__;
        # kept so any external caller relying on the attribute still works.
        self.random_crop = transforms.RandomCrop(self.hr_size)

    def __len__(self):
        return len(self.img_dataset)

    def __getitem__(self, idx):
        """Return dict with HR/LR tensors, the HR path and the one-hot mask tensor."""
        # Fix: removed the dead local `sample = {}` that was never used.
        img_path = self.img_dataset[idx]
        mask_path = self.mask_dataset[idx]
        hr_img = Image.open(img_path).convert('RGB')
        mask_img = Image.open(mask_path).convert('RGB')

        hr_img = hr_img.resize((self.hr_size, self.hr_size))
        hr_img = random_gray(hr_img, p=0.3)  # occasionally train on grayscale faces
        scale_size = np.random.randint(32, 256)  # random downscale size for degradation
        lr_img = complex_imgaug(hr_img, self.img_size, scale_size)

        mask_img = mask_img.resize((self.hr_size, self.hr_size))
        mask_label = onehot_parse_map(mask_img)
        mask_label = torch.tensor(mask_label).float()

        hr_tensor = self.to_tensor(hr_img)
        lr_tensor = self.to_tensor(lr_img)

        return {'HR': hr_tensor, 'LR': lr_tensor, 'HR_paths': img_path, 'Mask': mask_label}
|
||||
|
||||
|
||||
def complex_imgaug(x, org_size, scale_size):
    """Degrade a single RGB PIL image: random blur, downscale, noise, JPEG, upscale."""
    batch = np.array(x)[np.newaxis, ...]
    degrade = iaa.Sequential([
        # Random blur of a random kind, half of the time.
        iaa.Sometimes(0.5, iaa.OneOf([
            iaa.GaussianBlur((3, 15)),
            iaa.AverageBlur(k=(3, 15)),
            iaa.MedianBlur(k=(3, 15)),
            iaa.MotionBlur((5, 25))
        ])),
        # Downscale with a random interpolation, corrupt, then go back to org_size.
        iaa.Resize(scale_size, interpolation=ia.ALL),
        iaa.Sometimes(0.2, iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.1*255), per_channel=0.5)),
        iaa.Sometimes(0.7, iaa.JpegCompression(compression=(10, 65))),
        iaa.Resize(org_size),
    ])

    degraded = degrade(images=batch)
    return degraded[0]
|
||||
|
||||
|
||||
def random_gray(x, p=0.5):
    """With probability *p*, convert the single RGB PIL image *x* to grayscale."""
    batch = np.array(x)[np.newaxis, ...]
    maybe_gray = iaa.Sometimes(p, iaa.Grayscale(alpha=1.0))
    return maybe_gray(images=batch)[0]
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
"""A modified image folder class
|
||||
|
||||
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
|
||||
so that this class can load images from both current directory and its subdirectories.
|
||||
"""
|
||||
|
||||
import torch.utils.data as data
|
||||
|
||||
from PIL import Image
|
||||
import os
|
||||
import os.path
|
||||
|
||||
# Recognized image file extensions (case variants listed explicitly because
# matching is exact, not case-folded).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    """Return True when *filename* ends with a recognized image extension."""
    # str.endswith accepts a tuple of suffixes — replaces the per-extension any() scan.
    return filename.endswith(tuple(IMG_EXTENSIONS))


def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively collect image paths under *dir*.

    Parameters:
        dir (str)              -- root directory to walk
        max_dataset_size (num) -- cap on the number of returned paths
                                  (default: unlimited)

    Returns:
        list of image file paths in sorted-walk order, truncated to the cap.
    """
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    return images[:min(max_dataset_size, len(images))]
|
||||
|
||||
|
||||
def default_loader(path):
    """Open *path* with PIL, forcing a 3-channel RGB image."""
    return Image.open(path).convert('RGB')
|
||||
|
||||
|
||||
class ImageFolder(data.Dataset):
    """Dataset over every image found (recursively) under *root*.

    Unlike torchvision's ImageFolder, images in subdirectories are included
    and no class labels are derived from the directory layout.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            # Fix: `raise(RuntimeError(...))` — raise is a statement, not a call.
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform        # optional callable applied to each image
        self.return_paths = return_paths  # when True, __getitem__ also yields the path
        self.loader = loader              # callable mapping path -> image

    def __getitem__(self, index):
        """Return the (transformed) image at *index*, optionally with its path."""
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        return img

    def __len__(self):
        """Return the number of images found."""
        return len(self.imgs)
|
||||
@@ -0,0 +1,43 @@
|
||||
from data.base_dataset import BaseDataset, get_transform
|
||||
from data.image_folder import make_dataset
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
class SingleDataset(BaseDataset):
    """This dataset class can load a set of images specified by the path --dataroot /path/to/data.

    It can be used for generating CycleGAN results only for one side with the model option '-model test'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.A_paths = sorted(make_dataset(opt.src_dir, opt.max_dataset_size))
        # NOTE(review): channel count is read from output_nc — presumably the
        # generator consumes what the other side outputs; confirm upstream.
        input_nc = self.opt.output_nc
        self.transform = get_transform(opt, grayscale=(input_nc == 1))
        # Fix: removed the redundant trailing `self.opt = opt` —
        # BaseDataset.__init__ already stored it (and it was read above).

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- integer index for data lookup

        Returns a dictionary that contains LR and LR_paths
            LR (tensor)   -- the loaded, resized and transformed image
            LR_paths (str) -- the path of the image
        """
        A_path = self.A_paths[index]
        A_img = Image.open(A_path).convert('RGB')
        A_img = A_img.resize((512, 512), Image.BICUBIC)

        A = self.transform(A_img)
        return {'LR': A, 'LR_paths': A_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)
|
||||
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
from options.test_options import TestOptions
|
||||
from data import create_dataset
|
||||
from models import create_model
|
||||
from utils import utils
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
import cv2
|
||||
import glob
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
if __name__ == '__main__':
    # Parse test-time options; force deterministic, single-threaded loading.
    opt = TestOptions()
    opt = opt.parse()  # get test options
    opt.num_threads = 0  # test code only supports num_threads = 1
    opt.serial_batches = True  # disable data shuffling
    opt.no_flip = True

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.load_pretrain_models()

    netP = model.netP  # pretrained face-parsing network
    model.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # NOTE(review): hard-coded Windows paths -- consider exposing as options.
    image_dir = "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan/"
    output_dir = "G:/VGGFace2-HQ/VGGface2_HQ_original_aligned_mask"

    # Gather every identity sub-directory and its jpg files.
    temp_path = os.path.join(image_dir, '*/')
    pathes = glob.glob(temp_path)
    dataset = []  # NOTE: shadows the dataloader created above (kept as-is)
    for dir_item in pathes:
        join_path = glob.glob(os.path.join(dir_item, '*.jpg'))
        print("processing %s" % dir_item, end='\r')
        dataset.append(list(join_path))

    # ------------------------ restore ------------------------
    for i_dir in dataset:
        # Bug fix: a sub-directory with no jpg files produced an empty list,
        # and i_dir[0] raised IndexError. Skip such directories.
        if not i_dir:
            continue
        path = os.path.dirname(i_dir[0])
        dir_name = os.path.join(output_dir, os.path.basename(path))
        os.makedirs(dir_name, exist_ok=True)

        for img_path in i_dir:
            hr_img = Image.open(img_path).convert('RGB')
            inp = to_tensor(hr_img).unsqueeze(0)
            with torch.no_grad():
                parse_map, _ = netP(inp)
                # One-hot the per-pixel argmax over the class dimension.
                parse_map_sm = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float()
            ref_parse_img = utils.color_parse_map(parse_map_sm)
            img_name = os.path.basename(img_path)
            basename, ext = os.path.splitext(img_name)
            save_face_name = f'{basename}.png'
            save_path = os.path.join(dir_name, save_face_name)
            # Save the colorized parse map as a grayscale PNG.
            img = cv2.cvtColor(ref_parse_img[0], cv2.COLOR_RGB2GRAY)
            cv2.imwrite(save_path, img)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
||||
|
||||
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
||||
You need to implement the following five functions:
|
||||
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
||||
-- <set_input>: unpack data from dataset and apply preprocessing.
|
||||
-- <forward>: produce intermediate results.
|
||||
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
||||
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
||||
|
||||
In the function <__init__>, you need to define four lists:
|
||||
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
||||
-- self.model_names (str list): define networks used in our training.
|
||||
-- self.visual_names (str list): specify the images that you want to display and save.
|
||||
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
|
||||
|
||||
Now you can use the model class by specifying flag '--model dummy'.
|
||||
See our template model class 'template_model.py' for more details.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from models.base_model import BaseModel
|
||||
|
||||
|
||||
def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py" and return its class.

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.

    Parameters:
        model_name (str) -- e.g. 'enhance' maps to models/enhance_model.py
            and class EnhanceModel.

    Returns:
        the model class (a subclass of BaseModel); exits the process if none
        is found.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        # Guard with isinstance(cls, type): module attributes may be plain
        # functions or sub-modules, which would make issubclass() raise.
        if name.lower() == target_model_name.lower() \
           and isinstance(cls, type) and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        # Bug fix: this is an error path -- exit(0) would report success to
        # the calling shell; use a non-zero status instead.
        exit(1)

    return model
|
||||
|
||||
|
||||
def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    return find_model_using_name(model_name).modify_commandline_options
|
||||
|
||||
|
||||
def create_model(opt):
    """Create a model given the option.

    This function warps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model_cls = find_model_using_name(opt.model)
    instance = model_cls(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
|
||||
@@ -0,0 +1,248 @@
|
||||
import os
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from . import networks
|
||||
|
||||
class BaseModel(ABC):
|
||||
"""This class is an abstract base class (ABC) for models.
|
||||
To create a subclass, you need to implement the following five functions:
|
||||
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
||||
-- <set_input>: unpack data from dataset and apply preprocessing.
|
||||
-- <forward>: produce intermediate results.
|
||||
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
|
||||
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize the BaseModel class.
|
||||
|
||||
Parameters:
|
||||
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
|
||||
When creating your custom class, you need to implement your own initialization.
|
||||
In this fucntion, you should first call <BaseModel.__init__(self, opt)>
|
||||
Then, you need to define four lists:
|
||||
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
||||
-- self.model_names (str list): specify the images that you want to display and save.
|
||||
-- self.visual_names (str list): define networks used in our training.
|
||||
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
||||
"""
|
||||
self.opt = opt
|
||||
self.gpu_ids = opt.gpu_ids
|
||||
self.isTrain = opt.isTrain
|
||||
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
|
||||
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
||||
|
||||
self.loss_names = []
|
||||
self.model_names = []
|
||||
self.visual_names = []
|
||||
self.optimizers = []
|
||||
self.image_paths = []
|
||||
self.metric = 0 # used for learning rate policy 'plateau'
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
"""Add new model-specific options, and rewrite default values for existing options.
|
||||
|
||||
Parameters:
|
||||
parser -- original option parser
|
||||
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
||||
|
||||
Returns:
|
||||
the modified parser.
|
||||
"""
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def set_input(self, input):
|
||||
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||
|
||||
Parameters:
|
||||
input (dict): includes the data itself and its metadata information.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def optimize_parameters(self):
|
||||
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
||||
pass
|
||||
|
||||
def setup(self, opt):
|
||||
"""Load and print networks; create schedulers
|
||||
|
||||
Parameters:
|
||||
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
if self.isTrain:
|
||||
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
|
||||
if not self.isTrain or opt.continue_train:
|
||||
load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
|
||||
self.load_networks(load_suffix)
|
||||
self.print_networks(opt.verbose)
|
||||
|
||||
def eval(self):
|
||||
"""Make models eval mode during test time"""
|
||||
for name in self.model_names:
|
||||
if isinstance(name, str):
|
||||
net = getattr(self, 'net' + name)
|
||||
net.eval()
|
||||
|
||||
def test(self):
|
||||
"""Forward function used in test time.
|
||||
|
||||
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
|
||||
It also calls <compute_visuals> to produce additional visualization results
|
||||
"""
|
||||
with torch.no_grad():
|
||||
self.forward()
|
||||
self.compute_visuals()
|
||||
|
||||
def compute_visuals(self):
|
||||
"""Calculate additional output images for visdom and HTML visualization"""
|
||||
pass
|
||||
|
||||
def get_image_paths(self):
|
||||
""" Return image paths that are used to load current data"""
|
||||
return self.image_paths
|
||||
|
||||
def update_learning_rate(self):
|
||||
"""Update learning rates for all the networks; called at the end of every epoch"""
|
||||
for scheduler in self.schedulers:
|
||||
if self.opt.lr_policy == 'plateau':
|
||||
scheduler.step(self.metric)
|
||||
else:
|
||||
scheduler.step()
|
||||
|
||||
lr = self.optimizers[0].param_groups[0]['lr']
|
||||
print('learning rate = %.7f' % lr)
|
||||
|
||||
def get_lr(self,):
|
||||
lrs = {}
|
||||
for idx, p in enumerate(self.optimizers):
|
||||
lrs['LR{}'.format(idx)] = p.param_groups[0]['lr']
|
||||
return lrs
|
||||
|
||||
def get_current_visuals(self):
|
||||
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
|
||||
visual_ret = OrderedDict()
|
||||
for name in self.visual_names:
|
||||
if isinstance(name, str):
|
||||
visual_ret[name] = getattr(self, name)
|
||||
return visual_ret
|
||||
|
||||
def get_current_losses(self):
|
||||
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
|
||||
errors_ret = OrderedDict()
|
||||
for name in self.loss_names:
|
||||
if isinstance(name, str):
|
||||
errors_ret['Loss_' + name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
|
||||
return errors_ret
|
||||
|
||||
def save_networks(self, epoch, info=None):
|
||||
"""Save all the networks to the disk.
|
||||
|
||||
Parameters:
|
||||
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
||||
"""
|
||||
for name in self.model_names:
|
||||
if isinstance(name, str):
|
||||
save_filename = '%s_net_%s.pth' % (epoch, name)
|
||||
save_path = os.path.join(self.save_dir, save_filename)
|
||||
net = getattr(self, 'net' + name)
|
||||
|
||||
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
||||
torch.save(net.module.cpu().state_dict(), save_path)
|
||||
print('Model saved in:', save_filename)
|
||||
net.cuda(self.gpu_ids[0])
|
||||
else:
|
||||
torch.save(net.cpu().state_dict(), save_path)
|
||||
if info is not None:
|
||||
torch.save(info, os.path.join(self.save_dir, '%s.info' % epoch))
|
||||
|
||||
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
|
||||
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
|
||||
key = keys[i]
|
||||
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
|
||||
if module.__class__.__name__.startswith('InstanceNorm') and \
|
||||
(key == 'running_mean' or key == 'running_var'):
|
||||
if getattr(module, key) is None:
|
||||
state_dict.pop('.'.join(keys))
|
||||
if module.__class__.__name__.startswith('InstanceNorm') and \
|
||||
(key == 'num_batches_tracked'):
|
||||
state_dict.pop('.'.join(keys))
|
||||
else:
|
||||
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
|
||||
|
||||
def load_networks(self, epoch):
|
||||
"""Load all the networks from the disk.
|
||||
|
||||
Parameters:
|
||||
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
||||
"""
|
||||
for name in self.load_model_names:
|
||||
if isinstance(name, str):
|
||||
load_filename = '%s_net_%s.pth' % (epoch, name)
|
||||
load_path = os.path.join(self.save_dir, load_filename)
|
||||
net = getattr(self, 'net' + name)
|
||||
if isinstance(net, torch.nn.DataParallel) or isinstance(net, torch.nn.parallel.DistributedDataParallel):
|
||||
net = net.module
|
||||
print('loading the model from %s' % load_path)
|
||||
# if you are using PyTorch newer than 0.4 (e.g., built from
|
||||
# GitHub source), you can remove str() on self.device
|
||||
map_location = str(self.device)
|
||||
|
||||
state_dict = torch.load(load_path, map_location=map_location)
|
||||
|
||||
# patch InstanceNorm checkpoints prior to 0.4
|
||||
# for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
||||
# self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
|
||||
# net.load_state_dict(state_dict)
|
||||
if not self.opt.no_strict_load:
|
||||
net.load_state_dict(state_dict)
|
||||
# Load partial weights
|
||||
else:
|
||||
model_dict = net.state_dict()
|
||||
pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
net.load_state_dict(model_dict, strict=False)
|
||||
|
||||
info_path = os.path.join(self.save_dir, '%s.info' % epoch)
|
||||
if os.path.exists(info_path):
|
||||
info_dict = torch.load(info_path)
|
||||
for k, v in info_dict.items():
|
||||
setattr(self.opt, k, v)
|
||||
|
||||
def print_networks(self, verbose):
|
||||
"""Print the total number of parameters in the network and (if verbose) network architecture
|
||||
|
||||
Parameters:
|
||||
verbose (bool) -- if verbose: print the network architecture
|
||||
"""
|
||||
print('---------- Networks initialized -------------')
|
||||
for name in self.model_names:
|
||||
if isinstance(name, str):
|
||||
net = getattr(self, 'net' + name)
|
||||
num_params = 0
|
||||
for param in net.parameters():
|
||||
num_params += param.numel()
|
||||
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
|
||||
print('-----------------------------------------------')
|
||||
|
||||
def set_requires_grad(self, nets, requires_grad=False):
|
||||
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
||||
Parameters:
|
||||
nets (network list) -- a list of networks
|
||||
requires_grad (bool) -- whether the networks require gradients or not
|
||||
"""
|
||||
if not isinstance(nets, list):
|
||||
nets = [nets]
|
||||
for net in nets:
|
||||
if net is not None:
|
||||
for param in net.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
@@ -0,0 +1,131 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class NormLayer(nn.Module):
    """Normalization Layers.
    ------------
    # Arguments
        - channels: input channels, for batch norm and instance norm.
        - normalize_shape: input shape without batch size, for layer norm.
        - norm_type: one of 'bn', 'in', 'gn', 'pixel', 'layer', 'none'.

    Raises:
        ValueError: if `norm_type` is not one of the supported types.
    """
    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
        super(NormLayer, self).__init__()
        norm_type = norm_type.lower()
        self.norm_type = norm_type
        self.channels = channels
        if norm_type == 'bn':
            self.norm = nn.BatchNorm2d(channels, affine=True)
        elif norm_type == 'in':
            self.norm = nn.InstanceNorm2d(channels, affine=False)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(32, channels, affine=True)
        elif norm_type == 'pixel':
            # L2-normalize each pixel's channel vector
            self.norm = lambda x: F.normalize(x, p=2, dim=1)
        elif norm_type == 'layer':
            self.norm = nn.LayerNorm(normalize_shape)
        elif norm_type == 'none':
            self.norm = lambda x: x * 1.0
        else:
            # Bug fix: `assert 1 == 0` is stripped under `python -O`, which
            # would silently leave self.norm undefined; raise explicitly.
            raise ValueError('Norm type {} not support.'.format(norm_type))

    def forward(self, x, ref=None):
        # `ref` is accepted for interface compatibility but unused here.
        return self.norm(x)
|
||||
|
||||
|
||||
class ReluLayer(nn.Module):
    """Relu Layer.
    ------------
    # Arguments
        - channels: used only by 'prelu' (per-channel slopes).
        - relu_type: type of relu layer, candidates are
            - ReLU
            - LeakyReLU: default relu slope 0.2
            - PRelu
            - SELU
            - none: direct pass

    Raises:
        ValueError: if `relu_type` is not one of the supported types.
    """
    def __init__(self, channels, relu_type='relu'):
        super(ReluLayer, self).__init__()
        relu_type = relu_type.lower()
        if relu_type == 'relu':
            self.func = nn.ReLU(True)
        elif relu_type == 'leakyrelu':
            self.func = nn.LeakyReLU(0.2, inplace=True)
        elif relu_type == 'prelu':
            self.func = nn.PReLU(channels)
        elif relu_type == 'selu':
            self.func = nn.SELU(True)
        elif relu_type == 'none':
            self.func = lambda x: x * 1.0
        else:
            # Bug fix: `assert 1 == 0` is stripped under `python -O`, which
            # would silently leave self.func undefined; raise explicitly.
            raise ValueError('Relu type {} not support.'.format(relu_type))

    def forward(self, x):
        return self.func(x)
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
    """Conv block: optional up/down scaling, reflection pad, conv, norm, relu.

    # Arguments
        - scale: 'none' | 'up' (2x nearest upsample before conv) |
          'down' (stride-2 conv) | 'down_avg' (avg-pool after conv).
        - norm_type / relu_type: forwarded to NormLayer / ReluLayer.
        - use_pad: reflection-pad to keep spatial size under stride 1.

    Submodule attribute names (conv2d, norm, relu, ...) are kept so existing
    checkpoints still load.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, scale='none', norm_type='none', relu_type='none', use_pad=True, bias=True):
        super(ConvLayer, self).__init__()
        self.use_pad = use_pad
        self.norm_type = norm_type
        self.in_channels = in_channels
        # BatchNorm has its own affine shift, so the conv bias is redundant.
        if norm_type in ['bn']:
            bias = False

        stride = 2 if scale == 'down' else 1
        self.scale = scale

        # 'up' doubles spatial size before the conv; every other mode is identity.
        if scale == 'up':
            self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        else:
            self.scale_func = lambda x: x

        pad = int(np.ceil((kernel_size - 1.) / 2))
        self.reflection_pad = nn.ReflectionPad2d(pad)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)

        self.avgpool = nn.AvgPool2d(2, 2)
        self.relu = ReluLayer(out_channels, relu_type)
        self.norm = NormLayer(out_channels, norm_type=norm_type)

    def forward(self, x):
        out = self.scale_func(x)
        if self.use_pad:
            out = self.reflection_pad(out)
        out = self.conv2d(out)
        if self.scale == 'down_avg':
            out = self.avgpool(out)
        out = self.norm(out)
        return self.relu(out)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """
    Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html

    Shortcut is the identity when shape is unchanged, otherwise a 3x3 ConvLayer
    that matches channels/scale. Attribute names (shortcut_func, conv1, conv2)
    are kept so existing checkpoints still load.
    """
    def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
        super(ResidualBlock, self).__init__()

        # identity shortcut only when neither channels nor scale change
        if scale == 'none' and c_in == c_out:
            self.shortcut_func = lambda x: x
        else:
            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)

        # map the block-level scale onto the two inner convs
        scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
        first_scale, second_scale = scale_config_dict[scale]

        self.conv1 = ConvLayer(c_in, c_out, 3, first_scale, norm_type=norm_type, relu_type=relu_type)
        self.conv2 = ConvLayer(c_out, c_out, 3, second_scale, norm_type=norm_type, relu_type='none')

    def forward(self, x):
        identity = self.shortcut_func(x)
        res = self.conv2(self.conv1(x))
        return identity + res
|
||||
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import collections
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from models import loss
|
||||
from models import networks
|
||||
from .base_model import BaseModel
|
||||
from utils import utils
|
||||
|
||||
class EnhanceModel(BaseModel):
    """Parsing-guided face enhancement model (PSFR-GAN style).

    Wraps a face-parsing network (netP), a generator (netG) and, at train
    time, a multi-scale discriminator (netD) plus a VGG feature extractor for
    perceptual / region-style losses.
    """

    # NOTE(review): BaseModel declares this @staticmethod; it is accessed via
    # the class in get_option_setter, so the plain function also works.
    def modify_commandline_options(parser, is_train):
        """Add training-only weight-path and loss-weight options."""
        if is_train:
            parser.add_argument('--parse_net_weight', type=str, default='./pretrain_models/parse_multi_iter_90000.pth', help='parse model path')
            parser.add_argument('--lambda_pix', type=float, default=10.0, help='weight for parsing map')
            parser.add_argument('--lambda_pcp', type=float, default=0.0, help='weight for vgg perceptual loss')
            parser.add_argument('--lambda_fm', type=float, default=10.0, help='weight for sr')
            parser.add_argument('--lambda_g', type=float, default=1.0, help='weight for sr')
            parser.add_argument('--lambda_ss', type=float, default=1000., help='weight for global style')
        return parser

    def __init__(self, opt):
        """Build networks, criteria and optimizers.

        Parameters:
            opt (Option class) -- experiment flags; a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)

        self.netP = networks.define_P(opt, weight_path=opt.parse_net_weight)
        self.netG = networks.define_G(opt, use_norm='spectral_norm')

        if self.isTrain:
            self.netD = networks.define_D(opt, opt.Dinput_nc, use_norm='spectral_norm')
            self.vgg_model = loss.PCPFeat(weight_path='./pretrain_models/vgg19-dcbb9e9d.pth').to(opt.device)
            if len(opt.gpu_ids) > 0:
                self.vgg_model = torch.nn.DataParallel(self.vgg_model, opt.gpu_ids, output_device=opt.device)

        self.model_names = ['G']
        # Generator loss, fm loss, parsing loss, discriminator loss
        self.loss_names = ['Pix', 'PCP', 'G', 'FM', 'D', 'SS']
        self.visual_names = ['img_LR', 'img_HR', 'img_SR', 'ref_Parse', 'hr_mask']
        # all-ones feature-matching weights, one per discriminator scale
        self.fm_weights = [1 ** x for x in range(opt.D_num)]

        if self.isTrain:
            self.model_names = ['G', 'D']
            self.load_model_names = ['G', 'D']

            self.criterionParse = torch.nn.CrossEntropyLoss().to(opt.device)
            self.criterionFM = loss.FMLoss().to(opt.device)
            self.criterionGAN = loss.GANLoss(opt.gan_mode).to(opt.device)
            self.criterionPCP = loss.PCPLoss(opt)
            self.criterionPix = nn.L1Loss()
            self.criterionRS = loss.RegionStyleLoss()

            self.optimizer_G = optim.Adam([p for p in self.netG.parameters() if p.requires_grad], lr=opt.g_lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = optim.Adam([p for p in self.netD.parameters() if p.requires_grad], lr=opt.d_lr, betas=(opt.beta1, 0.999))
            self.optimizers = [self.optimizer_G, self.optimizer_D]

    def eval(self):
        """Switch generator and parsing network to eval mode."""
        self.netG.eval()
        self.netP.eval()

    def load_pretrain_models(self,):
        """Load pretrained weights for the parsing network and the generator."""
        self.netP.eval()
        print('Loading pretrained LQ face parsing network from', self.opt.parse_net_weight)
        if len(self.opt.gpu_ids) > 0:
            self.netP.module.load_state_dict(torch.load(self.opt.parse_net_weight))
        else:
            self.netP.load_state_dict(torch.load(self.opt.parse_net_weight))

        self.netG.eval()
        print('Loading pretrained PSFRGAN from', self.opt.psfr_net_weight)
        if len(self.opt.gpu_ids) > 0:
            self.netG.module.load_state_dict(torch.load(self.opt.psfr_net_weight), strict=False)
        else:
            self.netG.load_state_dict(torch.load(self.opt.psfr_net_weight), strict=False)

    def set_input(self, input, cur_iters=None):
        """Move LR/HR images and the HR parsing mask onto the model device."""
        self.cur_iters = cur_iters
        self.img_LR = input['LR'].to(self.opt.device)
        self.img_HR = input['HR'].to(self.opt.device)
        self.hr_mask = input['Mask'].to(self.opt.device)
        if self.opt.debug:
            print('SRNet input shape:', self.img_LR.shape, self.img_HR.shape)

    def forward(self):
        """Run parsing + generator, then the D and VGG passes for the losses."""
        with torch.no_grad():
            ref_mask, _ = self.netP(self.img_LR)
            # one-hot the per-pixel argmax of the predicted parsing map
            self.ref_mask_onehot = (ref_mask == ref_mask.max(dim=1, keepdim=True)[0]).float().detach()

        if self.opt.debug:
            print('SRNet reference mask shape:', self.ref_mask_onehot.shape)
        self.img_SR = self.netG(self.img_LR, self.ref_mask_onehot)

        # SR is detached for the D update but kept differentiable for G
        self.real_D_results = self.netD(torch.cat((self.img_HR, self.hr_mask), dim=1), return_feat=True)
        self.fake_D_results = self.netD(torch.cat((self.img_SR.detach(), self.hr_mask), dim=1), return_feat=False)
        self.fake_G_results = self.netD(torch.cat((self.img_SR, self.hr_mask), dim=1), return_feat=True)

        self.img_SR_feats = self.vgg_model(self.img_SR)
        self.img_HR_feats = self.vgg_model(self.img_HR)

    def backward_G(self):
        """Accumulate all generator losses and backpropagate."""
        # Pix Loss
        self.loss_Pix = self.criterionPix(self.img_SR, self.img_HR) * self.opt.lambda_pix

        # semantic style loss
        self.loss_SS = self.criterionRS(self.img_SR_feats, self.img_HR_feats, self.hr_mask) * self.opt.lambda_ss

        # perceptual loss
        self.loss_PCP = self.criterionPCP(self.img_SR_feats, self.img_HR_feats) * self.opt.lambda_pcp

        # feature matching loss over every discriminator scale
        fm_total = 0
        for i, w in zip(range(self.opt.D_num), self.fm_weights):
            fm_total = fm_total + self.criterionFM(self.fake_G_results[i][1], self.real_D_results[i][1]) * w
        self.loss_FM = fm_total * self.opt.lambda_fm / self.opt.D_num

        # generator adversarial loss
        g_total = 0
        for i in range(self.opt.D_num):
            g_total = g_total + self.criterionGAN(self.fake_G_results[i][0], True, for_discriminator=False)
        self.loss_G = g_total * self.opt.lambda_g / self.opt.D_num

        total_loss = self.loss_Pix + self.loss_PCP + self.loss_FM + self.loss_G + self.loss_SS
        total_loss.backward()

    def backward_D(self, ):
        """Average real/fake discriminator terms over all scales, backprop."""
        self.loss_D = 0
        for i in range(self.opt.D_num):
            self.loss_D += 0.5 * (self.criterionGAN(self.fake_D_results[i], False) + self.criterionGAN(self.real_D_results[i][0], True))
        self.loss_D /= self.opt.D_num
        self.loss_D.backward()

    def optimize_parameters(self, ):
        """One training step: update G first, then D."""
        # ---- Update G ------------
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # ---- Update D ------------
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

    def get_current_visuals(self, size=512):
        """Return LR/SR/HR images plus colorized parse maps, resized to `size`."""
        out = []
        visual_imgs = []
        out.append(utils.tensor_to_numpy(self.img_LR))
        out.append(utils.tensor_to_numpy(self.img_SR))
        out.append(utils.tensor_to_numpy(self.img_HR))

        out_imgs = [utils.batch_numpy_to_image(x, size) for x in out]

        visual_imgs += out_imgs
        visual_imgs.append(utils.color_parse_map(self.ref_mask_onehot, size))
        visual_imgs.append(utils.color_parse_map(self.hr_mask, size))

        return visual_imgs
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
import torch
|
||||
from torchvision import models
|
||||
from utils import utils
|
||||
from torch import nn
|
||||
|
||||
|
||||
def tv_loss(x):
    """
    Total Variation Loss.

    Sums the absolute differences between horizontally and vertically
    adjacent pixels of a (B, C, H, W) tensor.
    """
    horiz = torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
    vert = torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
    return horiz + vert
|
||||
|
||||
|
||||
class PCPFeat(torch.nn.Module):
    """
    Feature extractor for perceptual losses (VGG19 or ResNet50 backbone).
    Input: (B, C, H, W), RGB, values in [-1, 1] (mapped to [0, 1] internally).

    NOTE(review): a `model` value other than 'vgg'/'resnet' leaves self.model
    undefined and fails at load_state_dict -- confirm callers only pass these.
    """
    def __init__(self, weight_path, model='vgg'):
        super(PCPFeat, self).__init__()
        if model == 'vgg':
            self.model = models.vgg19(pretrained=False)
            self.build_vgg_layers()
        elif model == 'resnet':
            self.model = models.resnet50(pretrained=False)
            self.build_resnet_layers()

        # load local weights and freeze the backbone
        self.model.load_state_dict(torch.load(weight_path))
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def build_resnet_layers(self):
        """Group ResNet50 into four feature stages."""
        self.layer1 = torch.nn.Sequential(
            self.model.conv1,
            self.model.bn1,
            self.model.relu,
            self.model.maxpool,
            self.model.layer1
        )
        self.layer2 = self.model.layer2
        self.layer3 = self.model.layer3
        self.layer4 = self.model.layer4
        self.features = torch.nn.ModuleList(
            [self.layer1, self.layer2, self.layer3, self.layer4]
        )

    def build_vgg_layers(self):
        """Slice VGG19 features at relu1_1 ... relu5_1 boundaries."""
        vgg_pretrained_features = self.model.features
        stages = []
        feature_layers = [0, 3, 8, 17, 26, 35]
        for start, stop in zip(feature_layers[:-1], feature_layers[1:]):
            stage = torch.nn.Sequential()
            for j in range(start, stop):
                stage.add_module(str(j), vgg_pretrained_features[j])
            stages.append(stage)
        self.features = torch.nn.ModuleList(stages)

    def preprocess(self, x):
        """Map [-1, 1] input to ImageNet-normalized space; upsample if < 224."""
        x = (x + 1) / 2
        mean = torch.Tensor([0.485, 0.456, 0.406]).to(x)
        std = torch.Tensor([0.229, 0.224, 0.225]).to(x)
        x = (x - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
        if x.shape[3] < 224:
            x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        return x

    def forward(self, x):
        """Return the list of per-stage feature maps for `x`."""
        x = self.preprocess(x)

        feats = []
        for stage in self.features:
            x = stage(x)
            feats.append(x)
        return feats
|
||||
|
||||
|
||||
class PCPLoss(torch.nn.Module):
    """Perceptual Loss.

    Compares two lists of feature maps with MSE, weighting deeper layers
    more heavily. `opt`, `layer` and `model` are accepted for interface
    compatibility but are not used.
    """
    def __init__(self,
            opt,
            layer=5,
            model='vgg',
            ):
        super(PCPLoss, self).__init__()

        self.mse = torch.nn.MSELoss()
        # per-stage weights, shallow -> deep
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x_feats, y_feats):
        """Return the weighted MSE over paired feature lists (y detached)."""
        total = 0
        for xf, yf, w in zip(x_feats, y_feats, self.weights):
            total = total + self.mse(xf, yf.detach()) * w
        return total
|
||||
|
||||
|
||||
class FMLoss(nn.Module):
    """Feature-matching loss: unweighted MSE over paired feature lists."""

    def __init__(self):
        super().__init__()
        self.mse = torch.nn.MSELoss()

    def forward(self, x_feats, y_feats):
        """Return the summed MSE between corresponding features (y detached)."""
        total = 0
        for xf, yf in zip(x_feats, y_feats):
            total = total + self.mse(xf, yf.detach())
        return total
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
    """Adversarial loss supporting the lsgan / vanilla / hinge / wgangp objectives."""

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, hinge and wgangp.
            target_real_label (float) - - label for a real image
            target_fake_label (float) - - label for a fake image

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # Buffers follow the module across .to(device) but are not parameters.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            # hinge loss is computed inline in __call__, no criterion needed
            pass
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Return an all-real or all-fake label tensor shaped like `prediction`."""
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real, for_discriminator=True):
        """Calculate loss given Discriminator's output and ground-truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images
            for_discriminator (bool) - - hinge mode only: discriminator form vs generator form

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'hinge':
            if for_discriminator:
                if target_is_real:
                    loss = nn.ReLU()(1 - prediction).mean()
                else:
                    loss = nn.ReLU()(1 + prediction).mean()
            else:
                assert target_is_real, "The generator's hinge loss must be aiming for real"
                loss = - prediction.mean()
            return loss

        elif self.gan_mode == 'wgangp':
            # Wasserstein critic: maximize real scores, minimize fake scores.
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss
|
||||
|
||||
|
||||
class RegionStyleLoss(nn.Module):
    """Region-aware style (Gram matrix) loss.

    For each of `reg_num` semantic regions, a Gram matrix is computed from
    the mask-restricted feature maps; the loss is the MSE between the
    per-region Gram matrices of two feature pyramids (the deepest levels
    only, `feats[2:]`).
    """

    def __init__(self, reg_num=19, eps=1e-8):
        super().__init__()
        self.reg_num = reg_num  # number of parsing-map regions
        self.eps = eps          # avoids division by zero for empty regions
        self.mse = nn.MSELoss()

    def __masked_gram_matrix(self, x, m):
        """Gram matrix of features x restricted to (soft) mask m.

        Normalized by channel count and by the number of masked elements,
        so the scale is comparable across regions of different sizes.
        """
        b, c, h, w = x.shape
        m = m.view(b, -1, h*w)
        x = x.view(b, -1, h*w)
        total_elements = m.sum(2) + self.eps

        x = x * m  # zero out features outside the region
        G = torch.bmm(x, x.transpose(1, 2))
        return G / (c * total_elements.view(b, 1, 1))

    def __layer_gram_matrix(self, x, mask):
        """Stack the masked Gram matrices of all regions: (b, reg_num, c, c)."""
        b, c, h, w = x.shape
        all_gm = []
        for i in range(self.reg_num):
            sub_mask = mask[:, i].unsqueeze(1)
            gram_matrix = self.__masked_gram_matrix(x, sub_mask)
            all_gm.append(gram_matrix)
        return torch.stack(all_gm, dim=1)

    def forward(self, x_feats, y_feats, mask):
        """Sum of per-region Gram-matrix MSEs over the deeper feature levels.

        `mask` is resized to each feature resolution (default interpolate
        mode, i.e. nearest); `y_feats` is treated as a constant target.
        """
        loss = 0
        for xf, yf in zip(x_feats[2:], y_feats[2:]):
            tmp_mask = torch.nn.functional.interpolate(mask, xf.shape[2:])
            xf_gm = self.__layer_gram_matrix(xf, tmp_mask)
            yf_gm = self.__layer_gram_matrix(yf, tmp_mask)
            tmp_loss = self.mse(xf_gm, yf_gm.detach())
            loss = loss + tmp_loss
        return loss
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: a 5-level feature pyramid compared against itself with a
    # random 19-region mask should give a RegionStyleLoss of exactly 0.
    x = [
        torch.randn(2, 64, 512, 512),
        torch.randn(2, 128, 256, 256),
        torch.randn(2, 256, 128, 128),
        torch.randn(2, 512, 64, 64),
        torch.randn(2, 512, 32, 32),
    ]

    y = torch.randint(10, (2, 19, 512, 512)).float()
    loss = RegionStyleLoss()
    print(loss(x, x, y))
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
from models.blocks import *
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.optim import lr_scheduler
|
||||
from utils import utils
|
||||
import numpy as np
|
||||
|
||||
from models import psfrnet
|
||||
import torch.nn.utils as tutils
|
||||
from models.loss import PCPFeat
|
||||
|
||||
|
||||
def apply_norm(net, weight_norm_type):
    """Wrap every Conv2d in `net` with the requested weight reparametrization.

    Parameters:
        net              -- module tree whose Conv2d layers are modified in place
        weight_norm_type -- 'spectral_norm', 'weight_norm', or anything else
                            for a no-op (case-insensitive)
    """
    norm_kind = weight_norm_type.lower()
    conv_layers = (m for m in net.modules() if isinstance(m, nn.Conv2d))
    for conv in conv_layers:
        if norm_kind == 'spectral_norm':
            tutils.spectral_norm(conv)
        elif norm_kind == 'weight_norm':
            tutils.weight_norm(conv)
|
||||
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        # dispatch on class name so subclasses (e.g. Conv2d, ConvTranspose2d) match too
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
|
||||
|
||||
|
||||
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    # NOTE(review): mutable default `gpu_ids=[]` is shared across calls;
    # harmless here because it is never mutated, but worth confirming.
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
|
||||
|
||||
|
||||
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError -- if opt.lr_policy is not one of the supported policies.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # 1.0 for the first n_epochs, then linear decay to 0 over n_epochs_decay
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        # Bug fix: the original *returned* the exception object (callers silently
        # received it as a "scheduler") and passed the policy name as a second
        # argument instead of interpolating it into the message.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
|
||||
|
||||
|
||||
def define_P(opt, in_size=512, out_size=512, min_feat_size=32, relu_type='LeakyReLU', isTrain=True, weight_path=None):
    """Build the face parsing network (ParseNet).

    Optionally loads pretrained weights from `weight_path`, switches to eval
    mode when not training, and wraps the net in DataParallel when GPUs are
    configured in `opt`. Returns the (possibly wrapped) network.
    """
    net = ParseNet(in_size, out_size, min_feat_size, 64, 19, norm_type=opt.Pnorm, relu_type=relu_type, ch_range=[32, 256])
    if not isTrain:
        net.eval()
    if weight_path is not None:
        net.load_state_dict(torch.load(weight_path))
    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(opt.device)
        net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.device)
    return net
|
||||
|
||||
|
||||
def define_G(opt, isTrain=True, use_norm='none', relu_type='LeakyReLU'):
    """Build the PSFR generator.

    Applies the requested conv weight normalization (`use_norm`), switches to
    eval mode when not training, and wraps the net in DataParallel when GPUs
    are configured in `opt`. Returns the (possibly wrapped) network.
    """
    net = psfrnet.PSFRGenerator(3, 3, in_size=opt.Gin_size, out_size=opt.Gout_size, relu_type=relu_type, parse_ch=19, norm_type=opt.Gnorm)
    apply_norm(net, use_norm)
    if not isTrain:
        net.eval()
    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(opt.device)
        net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.device)
    # init_weights(net, init_type='normal', init_gain=0.02)
    return net
|
||||
|
||||
|
||||
def define_D(opt, in_channel=3, isTrain=True, use_norm='none'):
    """Build the multi-scale discriminator.

    Applies the requested conv weight normalization, switches to eval mode
    when not training, wraps in DataParallel when GPUs are configured in
    `opt`, and (unlike define_G) always applies normal weight init.
    """
    net = MultiScaleDiscriminator(in_channel, opt.ndf, opt.n_layers_D, opt.Dnorm, num_D=opt.D_num)
    apply_norm(net, use_norm)
    if not isTrain:
        net.eval()
    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(opt.device)
        net = torch.nn.DataParallel(net, opt.gpu_ids, output_device=opt.device)
    init_weights(net, init_type='normal', init_gain=0.02)
    return net
|
||||
|
||||
|
||||
class ParseNet(nn.Module):
    """Encoder-body-decoder network that predicts a face parsing map and a
    coarse super-resolved image from a low-quality face image.

    The encoder downsamples to `min_feat_size`, the body is a stack of
    resolution-preserving residual blocks, and the decoder upsamples to
    `out_size`; two 1-branch heads emit the image and the parsing logits.
    """

    def __init__(self,
                 in_size=128,
                 out_size=128,
                 min_feat_size=32,
                 base_ch=64,
                 parsing_ch=19,
                 res_depth=10,
                 relu_type='prelu',
                 norm_type='bn',
                 ch_range=[32, 512],
                 ):
        super().__init__()
        self.res_depth = res_depth
        act_args = {'norm_type': norm_type, 'relu_type': relu_type}
        min_ch, max_ch = ch_range

        # clamp channel counts into [min_ch, max_ch]
        ch_clip = lambda x: max(min_ch, min(x, max_ch))
        min_feat_size = min(in_size, min_feat_size)

        # number of x2 down / up stages (sizes are assumed powers of two)
        down_steps = int(np.log2(in_size//min_feat_size))
        up_steps = int(np.log2(out_size//min_feat_size))

        # =============== define encoder-body-decoder ====================
        self.encoder = []
        self.encoder.append(ConvLayer(3, base_ch, 3, 1))
        head_ch = base_ch
        for i in range(down_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
            self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
            head_ch = head_ch * 2

        self.body = []
        for i in range(res_depth):
            self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))

        self.decoder = []
        for i in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
            head_ch = head_ch // 2

        self.encoder = nn.Sequential(*self.encoder)
        self.body = nn.Sequential(*self.body)
        self.decoder = nn.Sequential(*self.decoder)
        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)

    def forward(self, x):
        """Return (parsing logits, coarse SR image) for input image x."""
        feat = self.encoder(x)
        x = feat + self.body(feat)  # global residual connection around the body
        x = self.decoder(x)
        out_img = self.out_img_conv(x)
        out_mask = self.out_mask_conv(x)
        return out_mask, out_img
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(nn.Module):
    """Pool of `num_D` NLayerDiscriminators applied to progressively
    downsampled copies of the input; returns the list of their outputs."""

    def __init__(self, input_ch, base_ch=64, n_layers=3, norm_type='none', relu_type='LeakyReLU', num_D=4):
        super().__init__()

        self.D_pool = nn.ModuleList()
        for i in range(num_D):
            netD = NLayerDiscriminator(input_ch, base_ch, depth=n_layers, norm_type=norm_type, relu_type=relu_type)
            self.D_pool.append(netD)

        # 3x3 average pooling with stride 2 halves resolution between scales
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, input, return_feat=False):
        """Return one output per scale; each is a score map, or a
        (score map, feature list) pair when return_feat is True."""
        results = []
        for netd in self.D_pool:
            output = netd(input, return_feat)
            results.append(output)
            # Downsample input
            input = self.downsample(input)
        return results
|
||||
|
||||
|
||||
class NLayerDiscriminator(nn.Module):
    """PatchGAN-style discriminator: a stem conv, `depth` downsampling convs
    with channels doubling up to `max_ch`, and a 1-channel score conv."""

    def __init__(self,
                 input_ch = 3,
                 base_ch = 64,
                 max_ch = 1024,
                 depth = 4,
                 norm_type = 'none',
                 relu_type = 'LeakyReLU',
                 ):
        super().__init__()

        nargs = {'norm_type': norm_type, 'relu_type': relu_type}
        self.norm_type = norm_type
        self.input_ch = input_ch

        self.model = []
        self.model.append(ConvLayer(input_ch, base_ch, norm_type='none', relu_type=relu_type))
        # Bug fix: `cout` was only bound inside the loop, so depth=0 raised a
        # NameError at the score conv below; the stem output width is base_ch.
        cout = base_ch
        for i in range(depth):
            cin = min(base_ch * 2**(i), max_ch)        # channels double per level...
            cout = min(base_ch * 2**(i+1), max_ch)     # ...capped at max_ch
            self.model.append(ConvLayer(cin, cout, scale='down_avg', **nargs))
        self.model = nn.Sequential(*self.model)
        self.score_out = ConvLayer(cout, 1, use_pad=False)

    def forward(self, x, return_feat=False):
        """Return the patch score map; with return_feat=True also return the
        per-layer intermediate activations (for feature-matching loss)."""
        ret_feats = []
        for idx, m in enumerate(self.model):
            x = m(x)
            ret_feats.append(x)
        x = self.score_out(x)
        if return_feat:
            return x, ret_feats
        else:
            return x
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
from utils import utils
|
||||
|
||||
class ParseModel(BaseModel):
    """Trains ParseNet to jointly predict a face parsing map (cross-entropy)
    and a coarse super-resolved image (L1) from a low-resolution face."""

    def modify_commandline_options(parser, is_train):
        # NOTE(review): in the pix2pix-style framework this hook is usually a
        # @staticmethod; as written it only works when called on the class
        # object itself (not an instance) -- confirm against models.get_option_setter.
        if is_train:
            parser.add_argument('--parse_map', type=float, default=1.0, help='weight for parsing map')
            parser.add_argument('--parse_sr', type=float, default=1.0, help='weight for sr')
        return parser

    def __init__(self, opt):
        """Initialize this model class.

        Parameters:
            opt -- training/test options

        A few things can be done here.
        - (required) call the initialization function of BaseModel
        - define loss function, visualization images, model names, and optimizers
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
        self.loss_names = ['P', 'SR']
        self.visual_names = ['img_LR', 'img_HR', 'gt_Parse', 'img_SR', 'pred_Parse']

        self.model_names = ['P']
        self.netP = networks.define_P(opt)

        if self.isTrain:  # only defined during training time
            self.criterionParse = torch.nn.CrossEntropyLoss()
            self.criterionSR = torch.nn.L1Loss()
            self.optimizer = torch.optim.Adam(self.netP.parameters(), lr=opt.lr, betas=(0.9, 0.999))
            self.optimizers = [self.optimizer]

    def set_input(self, input, cur_iters=None):
        """Unpack one data batch (LR, HR, parsing mask) onto the device."""
        self.img_LR = input['LR'].to(self.opt.device)
        self.img_HR = input['HR'].to(self.opt.device)
        self.gt_Parse = input['Mask'].to(self.opt.device)
        if self.opt.debug:
            print('ParseNet input shape:', self.img_LR.shape, self.img_HR.shape, self.gt_Parse.shape)

    def load_pretrain_models(self,):
        """Load pretrained parsing-net weights and switch to eval mode."""
        self.netP.eval()
        print('Loading pretrained LQ face parsing network from', self.opt.parse_net_weight)
        self.netP.load_state_dict(torch.load(self.opt.parse_net_weight))

    def forward(self):
        """Run ParseNet on the LR input, caching predictions on self."""
        self.pred_Parse, self.img_SR = self.netP(self.img_LR)
        if self.opt.debug:
            print('ParseNet output shape', self.pred_Parse.shape, self.img_SR.shape)

    def backward(self):
        """Compute the weighted parse + SR losses and backpropagate."""
        self.loss_P = self.criterionParse(self.pred_Parse, self.gt_Parse) * self.opt.parse_map
        self.loss_SR = self.criterionSR(self.img_SR, self.img_HR) * self.opt.parse_sr

        loss = self.loss_P + self.loss_SR
        loss.backward()

    def optimize_parameters(self):
        """One optimization step (forward() is assumed to have been run)."""
        self.optimizer.zero_grad()  # clear network G's existing gradients
        self.backward()             # calculate gradients for network G
        self.optimizer.step()

    def get_current_visuals(self, size=512):
        """Return [LR, SR, predicted parse, GT parse, HR] as display images."""
        out = []
        visual_imgs = []
        out.append(utils.tensor_to_numpy(self.img_LR))
        out.append(utils.tensor_to_numpy(self.img_SR))
        out.append(utils.tensor_to_numpy(self.img_HR))
        out_imgs = [utils.batch_numpy_to_image(x, size) for x in out]

        visual_imgs.append(out_imgs[0])
        visual_imgs.append(out_imgs[1])
        visual_imgs.append(utils.color_parse_map(self.pred_Parse))
        visual_imgs.append(utils.color_parse_map(self.gt_Parse.unsqueeze(1)))
        visual_imgs.append(out_imgs[2])

        return visual_imgs
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
import numpy as np
|
||||
from models.blocks import *
|
||||
|
||||
|
||||
class SPADENorm(nn.Module):
    """Spatially-adaptive normalization (SPADE).

    Instance-normalizes the input, then (for norm_type='spade') modulates it
    with per-pixel gamma/beta predicted by small convs from a reference map
    `ref`; for norm_type='in' it is plain instance norm.
    """

    def __init__(self, norm_nc, ref_nc, norm_type='spade', ksz=3):
        super().__init__()

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        mid_c = 64  # hidden width of the modulation branch

        self.norm_type = norm_type
        if norm_type == 'spade':
            self.conv1 = nn.Sequential(
                nn.Conv2d(ref_nc, mid_c, ksz, 1, ksz//2),
                nn.LeakyReLU(0.2, True),
            )
            self.gamma_conv = nn.Conv2d(mid_c, norm_nc, ksz, 1, ksz//2)
            self.beta_conv = nn.Conv2d(mid_c, norm_nc, ksz, 1, ksz//2)

    def get_gamma_beta(self, x, conv, gamma_conv, beta_conv):
        """Predict the spatial modulation maps (gamma, beta) from reference x."""
        act = conv(x)
        gamma = gamma_conv(act)
        beta = beta_conv(act)
        return gamma, beta

    def forward(self, x, ref):
        normalized_input = self.param_free_norm(x)
        # bring the reference map to the feature resolution when needed
        if x.shape[-1] != ref.shape[-1]:
            ref = nn.functional.interpolate(ref, x.shape[2:], mode='bicubic', align_corners=False)
        if self.norm_type == 'spade':
            gamma, beta = self.get_gamma_beta(ref, self.conv1, self.gamma_conv, self.beta_conv)
            return normalized_input * gamma + beta
        elif self.norm_type == 'in':
            return normalized_input
        # NOTE(review): any other norm_type falls through and returns None -- confirm intended.
|
||||
|
||||
|
||||
class SPADEResBlock(nn.Module):
    """Residual block whose two convs are each preceded by SPADE normalization
    conditioned on a reference map.

    NOTE(review): the skip `out = x + res` and the channel widths of the norm
    layers assume fin == fout (true for every call site in this file) --
    confirm before using with fin != fout.
    """

    def __init__(self, fin, fout, ref_nc, relu_type, norm_type='spade'):
        super().__init__()

        fmiddle = min(fin, fout)
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)

        # define normalization layers
        self.norm_0 = SPADENorm(fmiddle, ref_nc, norm_type)
        self.norm_1 = SPADENorm(fmiddle, ref_nc, norm_type)
        self.relu = ReluLayer(fmiddle, relu_type)

    def forward(self, x, ref):
        res = self.conv_0(self.relu(self.norm_0(x, ref)))
        res = self.conv_1(self.relu(self.norm_1(res, ref)))
        out = x + res

        return out
|
||||
|
||||
|
||||
class PSFRGenerator(nn.Module):
    """Progressive SPADE generator.

    Decodes a learned constant tensor up to `out_size`, modulating every
    stage with SPADE conditioned on a reference map built from the degraded
    image concatenated with its parsing map (3 + 19 channels).
    """

    def __init__(self, input_nc, output_nc, in_size=512, out_size=512, min_feat_size=16, ngf=64, n_blocks=9, parse_ch=19, relu_type='relu',
                 ch_range=[32, 1024], norm_type='spade'):
        super().__init__()

        min_ch, max_ch = ch_range
        ch_clip = lambda x: max(min_ch, min(x, max_ch))  # clamp channel counts
        get_ch = lambda size: ch_clip(1024*16//size)     # wider at coarser scales

        # learned constant start tensor (StyleGAN-style), repeated per batch item
        self.const_input = nn.Parameter(torch.randn(1, get_ch(min_feat_size), min_feat_size, min_feat_size))
        up_steps = int(np.log2(out_size//min_feat_size))
        self.up_steps = up_steps

        ref_ch = 19+3  # parsing map + RGB image channels

        head_ch = get_ch(min_feat_size)
        head = [
            nn.Conv2d(head_ch, head_ch, kernel_size=3, padding=1),
            SPADEResBlock(head_ch, head_ch, ref_ch, relu_type, norm_type),
        ]

        # one x2 upsample stage per step, halving channels (clipped)
        body = []
        for i in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            body += [
                nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(cin, cout, kernel_size=3, padding=1),
                    SPADEResBlock(cout, cout, ref_ch, relu_type, norm_type)
                )
            ]
            head_ch = head_ch // 2

        self.img_out = nn.Conv2d(ch_clip(head_ch), output_nc, kernel_size=3, padding=1)

        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)
        self.upsample = nn.Upsample(scale_factor=2)

    def forward_spade(self, net, x, ref):
        """Run x through the children of `net`, routing `ref` into SPADE modules."""
        for m in net:
            x = self.forward_spade_m(m, x, ref)
        return x

    def forward_spade_m(self, m, x, ref):
        """Call m(x, ref) for SPADE-aware modules, m(x) otherwise."""
        if isinstance(m, SPADENorm) or isinstance(m, SPADEResBlock):
            x = m(x, ref)
        else:
            x = m(x)
        return x

    def forward(self, x, ref):
        """Generate an image from degraded input x (B,3,H,W) and parsing map
        ref (B,19,H,W); the two are concatenated as the SPADE reference."""
        b, c, h, w = x.shape
        const_input = self.const_input.repeat(b, 1, 1, 1)
        ref_input = torch.cat((x, ref), dim=1)  # (B, 22, H, W) conditioning map

        feat = self.forward_spade(self.head, const_input, ref_input)

        for idx, m in enumerate(self.body):
            feat = self.forward_spade(m, feat, ref_input)

        out_img = self.img_out(feat)

        return out_img
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Bug fix: the original called `nearest_interpolate`, which is not defined
    # anywhere in this module (guaranteed NameError). Run a small
    # PSFRGenerator smoke test instead.
    net = PSFRGenerator(3, 3, in_size=64, out_size=64, min_feat_size=16)
    lq = torch.randn(2, 3, 64, 64)          # degraded input images
    parse_map = torch.randn(2, 19, 64, 64)  # parsing-map conditioning
    out = net(lq, parse_map)
    print(out.shape)
|
||||
@@ -0,0 +1 @@
|
||||
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
|
||||
@@ -0,0 +1,165 @@
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import random
|
||||
from utils import utils
|
||||
import torch
|
||||
import models
|
||||
import data
|
||||
from utils import utils
|
||||
|
||||
|
||||
class BaseOptions():
    """This class defines options used during both training and test time.

    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized"""
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataroot', required=False, help='path to images')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpus', type=int, default=1, help='how many gpus to use')
        parser.add_argument('--seed', type=int, default=123, help='Random seed for training')
        parser.add_argument('--checkpoints_dir', type=str, default='./check_points', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='enhance', help='chooses which model to train [parse|enhance]')
        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--Dinput_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--n_layers_D', type=int, default=4, help='downsampling layers in discriminator')
        parser.add_argument('--D_num', type=int, default=3, help='numbers of discriminators')

        parser.add_argument('--Pnorm', type=str, default='bn', help='parsing net norm [in | bn| none]')
        parser.add_argument('--Gnorm', type=str, default='spade', help='generator norm [in | bn | none]')
        parser.add_argument('--Dnorm', type=str, default='in', help='discriminator norm [in | bn | none]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        # dataset parameters
        parser.add_argument('--dataset_name', type=str, default='single', help='dataset name')
        parser.add_argument('--Pimg_size', type=int, default='512', help='image size for face parse net')
        parser.add_argument('--Gin_size', type=int, default='512', help='image size for face parse net')
        parser.add_argument('--Gout_size', type=int, default='512', help='image size for face parse net')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=8, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
        parser.add_argument('--load_size', type=int, default=512, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')

        parser.add_argument('--debug', action='store_true', help='if specified, set to debug mode')
        self.initialized = True
        return parser

    def gather_options(self):
        """Initialize our parser with basic options(only once).
        Add additional model-specific and dataset-specific options.
        These options are defined in the <modify_commandline_options> function
        in model and dataset classes.
        """
        # NOTE(review): if self.initialized were already True, `parser` below
        # would be unbound (NameError) -- confirm gather_options is only
        # ever called once per options object.
        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        dataset_name = opt.dataset_name
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        return parser.parse_args()

    def print_options(self, opt):
        """Print and save options

        It will print both current options and default values(if different).
        It will save options into a text file / [checkpoints_dir] / opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        opt.expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        utils.mkdirs(opt.expr_dir)
        file_name = os.path.join(opt.expr_dir, '{}_opt.txt'.format(opt.phase))
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')

        # create log directories alongside the checkpoints
        opt.log_dir = os.path.join(opt.checkpoints_dir, 'log_dir')
        utils.mkdirs(opt.log_dir)
        opt.log_archive = os.path.join(opt.checkpoints_dir, 'log_archive')
        utils.mkdirs(opt.log_archive)

    def parse(self):
        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        if opt.debug:
            # in debug mode save/print every iteration under a fixed name
            opt.name = 'debug'
            opt.save_iter_freq = 1
            opt.save_latest_freq = 1
            opt.visual_freq = 1
            opt.print_freq = 1

        # Find available GPUs automatically
        if opt.gpus > 0:
            opt.gpu_ids = utils.get_gpu_memory_map()[1][:opt.gpus]
            if not isinstance(opt.gpu_ids, list):
                opt.gpu_ids = [opt.gpu_ids]
            torch.cuda.set_device(opt.gpu_ids[0])
            opt.device = torch.device('cuda:{}'.format(opt.gpu_ids[0 % opt.gpus]))
            opt.data_device = torch.device('cuda:{}'.format(opt.gpu_ids[1 % opt.gpus]))
        else:
            opt.gpu_ids = []
            opt.device = torch.device('cpu')

        # set random seeds to ensure reproducibility
        np.random.seed(opt.seed)
        random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed_all(opt.seed)

        # process opt.suffix
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix

        self.print_options(opt)

        self.opt = opt
        return self.opt
|
||||
@@ -0,0 +1,30 @@
|
||||
from .base_options import BaseOptions
|
||||
|
||||
|
||||
class TestOptions(BaseOptions):
    """This class includes test options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--src_dir', type=str, default='G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan/n000002', help='source directory containing test images')
        parser.add_argument('--save_masks_dir', type=str, default='../datasets/FFHQ/masks512', help='path to save parsing masks for FFHQ')
        parser.add_argument('--test_img_path', type=str, default='', help='path for single image test')
        parser.add_argument('--test_upscale', type=float, default=1, help='upsample scale for single image test')
        parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and Batchnorm has different behavioir during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
        parser.add_argument('--parse_net_weight', type=str, default='./pretrain_models/parse_multi_iter_90000.pth', help='parse model path')
        parser.add_argument('--psfr_net_weight', type=str, default='./pretrain_models/psfrgan_epoch15_net_G.pth', help='parse model path')
        # rewrite devalue values
        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))
        self.isTrain = False  # TestOptions always runs in test mode

        return parser
|
||||
@@ -0,0 +1,41 @@
|
||||
from .base_options import BaseOptions
|
||||
|
||||
class TrainOptions(BaseOptions):
    """This class includes training options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        """Register training-specific command line options on *parser* and return it."""
        parser = BaseOptions.initialize(self, parser)
        add = parser.add_argument  # shorthand for readability below

        # visdom and HTML visualization parameters
        add('--visual_freq', type=int, default=400, help='frequency of show training images in tensorboard')
        add('--print_freq', type=int, default=100, help='frequency of showing training results on console')

        # network saving and loading parameters
        add('--save_iter_freq', type=int, default=5000, help='frequency of saving the models')
        add('--save_latest_freq', type=int, default=500, help='save latest freq')
        add('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        add('--save_by_iter', action='store_true', help='whether saves model by iteration')
        add('--continue_train', action='store_true', help='continue training: load the latest model')
        add('--no_strict_load', action='store_true', help='set strict load to false')
        add('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        add('--phase', type=str, default='train', help='train, val, test, etc')

        # training parameters
        add('--resume_epoch', type=int, default=0, help='training resume epoch')
        add('--resume_iter', type=int, default=0, help='training resume iter')
        add('--total_epochs', type=int, default=50, help='# of epochs to train')
        add('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
        add('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
        add('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        add('--beta1', type=float, default=0.5, help='momentum term of adam')
        add('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        add('--g_lr', type=float, default=0.0001, help='generator learning rate')
        add('--d_lr', type=float, default=0.0004, help='discriminator learning rate')
        add('--gan_mode', type=str, default='hinge', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        add('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
        add('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
        add('--lr_decay_gamma', type=float, default=1, help='multiply by a gamma every lr_decay_iters iterations')

        self.isTrain = True

        return parser
|
||||
@@ -0,0 +1,11 @@
|
||||
torch==1.5.1
|
||||
torchvision==0.6.1
|
||||
tensorflow>=1.15.4
|
||||
tensorboard==1.15.0
|
||||
tensorboardX==2.1
|
||||
opencv-python
|
||||
dlib
|
||||
scikit-image==0.17.2
|
||||
scipy==1.4.1
|
||||
tqdm
|
||||
imgaug
|
||||
@@ -0,0 +1,62 @@
|
||||
import os
|
||||
from options.test_options import TestOptions
|
||||
from data import create_dataset
|
||||
from models import create_model
|
||||
from utils import utils
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    # Fixed test-time settings: single worker, deterministic order, no augmentation.
    opt.num_threads = 0
    opt.batch_size = 4
    opt.serial_batches = True    # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.load_pretrain_models()

    save_dir = opt.results_dir
    os.makedirs(save_dir, exist_ok=True)
    print('creating result directory', save_dir)

    netP = model.netP  # parsing network
    netG = model.netG  # generator
    model.eval()
    max_size = 9999    # cap on the number of processed batches

    # NOTE(review): 'sr' is created but never written to (outputs go to
    # lq/hq/parse); kept for backward compatibility. Directories are created
    # once up front instead of on every image.
    for sub in ('sr', 'lq', 'hq', 'parse'):
        os.makedirs(os.path.join(save_dir, sub), exist_ok=True)

    for batch_idx, data in tqdm(enumerate(dataset), total=len(dataset)//opt.batch_size):
        inp = data['LR']
        with torch.no_grad():
            parse_map, _ = netP(inp)
            # one-hot parsing map from the channel-wise argmax
            parse_map_sm = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float()
            output_SR = netG(inp, parse_map_sm)
        img_path = data['LR_paths']  # get image paths

        # Convert the whole batch once (the original redid all three
        # conversions for every image in the inner loop).
        inp_img = utils.batch_tensor_to_img(inp)
        output_sr_img = utils.batch_tensor_to_img(output_SR)
        ref_parse_img = utils.color_parse_map(parse_map_sm)

        for j in tqdm(range(len(img_path))):
            base_name = os.path.basename(img_path[j])
            Image.fromarray(inp_img[j]).save(os.path.join(save_dir, 'lq', base_name))
            Image.fromarray(output_sr_img[j]).save(os.path.join(save_dir, 'hq', base_name))
            Image.fromarray(ref_parse_img[j]).save(os.path.join(save_dir, 'parse', base_name))

        # bug fix: the original inner loop reused `i`, so this test compared
        # the per-image index instead of the batch counter.
        if batch_idx > max_size:
            break
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
'''
|
||||
This script enhance images with unaligned faces in a folder and paste it back to the original place.
|
||||
'''
|
||||
import dlib
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from skimage import transform as trans
|
||||
from skimage import io
|
||||
|
||||
import torch
|
||||
from utils import utils
|
||||
from options.test_options import TestOptions
|
||||
from models import create_model
|
||||
|
||||
from test_enhance_single_unalign import *
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Batch driver: enhance every image in opt.src_dir, writing the aligned
    # faces, parsing maps, enhanced faces, and the final pasted-back image
    # into one sub-directory of opt.results_dir per input image.
    opt = TestOptions().parse()
    # face_detector = dlib.get_frontal_face_detector()
    face_detector = dlib.cnn_face_detection_model_v1('./pretrain_models/mmod_human_face_detector.dat')
    lmk_predictor = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat')
    template_path = './pretrain_models/FFHQ_template.npy'
    enhance_model = def_models(opt)

    for img_name in os.listdir(opt.src_dir):
        img_path = os.path.join(opt.src_dir, img_name)
        # one result directory per source image, named after the image stem
        save_current_dir = os.path.join(opt.results_dir, os.path.splitext(img_name)[0])
        os.makedirs(save_current_dir, exist_ok=True)
        print('======> Loading image', img_path)
        img = dlib.load_rgb_image(img_path)
        aligned_faces, tform_params = detect_and_align_faces(img, face_detector, lmk_predictor, template_path)
        # Save aligned LQ faces
        save_lq_dir = os.path.join(save_current_dir, 'LQ_faces')
        os.makedirs(save_lq_dir, exist_ok=True)
        print('======> Saving aligned LQ faces to', save_lq_dir)
        save_imgs(aligned_faces, save_lq_dir)

        hq_faces, lq_parse_maps = enhance_faces(aligned_faces, enhance_model)
        # Save LQ parsing maps and enhanced faces
        save_parse_dir = os.path.join(save_current_dir, 'ParseMaps')
        save_hq_dir = os.path.join(save_current_dir, 'HQ')
        os.makedirs(save_parse_dir, exist_ok=True)
        os.makedirs(save_hq_dir, exist_ok=True)
        print('======> Save parsing map and the enhanced faces.')
        save_imgs(lq_parse_maps, save_parse_dir)
        save_imgs(hq_faces, save_hq_dir)

        print('======> Paste the enhanced faces back to the original image.')
        hq_img = past_faces_back(img, hq_faces, tform_params, upscale=opt.test_upscale)
        final_save_path = os.path.join(save_current_dir, 'hq_final.jpg')
        print('======> Save final result to', final_save_path)
        io.imsave(final_save_path, hq_img)
|
||||
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
'''
|
||||
This script enhance all faces in one image with PSFR-GAN and paste it back to the original place.
|
||||
'''
|
||||
import dlib
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from skimage import transform as trans
|
||||
from skimage import io
|
||||
|
||||
import torch
|
||||
from utils import utils
|
||||
from options.test_options import TestOptions
|
||||
from models import create_model
|
||||
|
||||
|
||||
def detect_and_align_faces(img, face_detector, lmk_predictor, template_path, template_scale=2, size_threshold=999):
    """Detect faces in *img*, align each to a 512x512 FFHQ template crop.

    Returns [aligned_faces, tform_params]: a list of (512, 512, 3) float arrays
    in [0, 255] and the matching list of SimilarityTransform objects used to
    warp each face (needed later to paste results back).
    """
    align_out_size = (512, 512)
    # 5-point reference landmarks, stored at template_scale x the target size
    ref_points = np.load(template_path) / template_scale

    # Detect landmark points
    face_dets = face_detector(img, 1)
    assert len(face_dets) > 0, 'No faces detected'

    aligned_faces = []
    tform_params = []
    for det in face_dets:
        if isinstance(face_detector, dlib.cnn_face_detection_model_v1):
            rec = det.rect # for cnn detector
        else:
            rec = det
        # NOTE(review): `break` aborts the whole loop, so any faces after an
        # oversized one are skipped too — confirm `continue` wasn't intended.
        if rec.width() > size_threshold or rec.height() > size_threshold:
            print('Face is too large')
            break
        landmark_points = lmk_predictor(img, rec)
        single_points = []
        for i in range(5):
            single_points.append([landmark_points.part(i).x, landmark_points.part(i).y])
        single_points = np.array(single_points)
        # Similarity transform mapping detected landmarks onto the template
        tform = trans.SimilarityTransform()
        tform.estimate(single_points, ref_points)
        # warp returns floats in [0, 1]; scale back to [0, 255]
        tmp_face = trans.warp(img, tform.inverse, output_shape=align_out_size, order=3)
        aligned_faces.append(tmp_face*255)
        tform_params.append(tform)
    return [aligned_faces, tform_params]
|
||||
|
||||
|
||||
def def_models(opt):
    """Create the enhancement model, load its pretrained weights, and move
    both sub-networks (parser netP, generator netG) to ``opt.device``."""
    model = create_model(opt)
    model.load_pretrain_models()
    for net in (model.netP, model.netG):
        net.to(opt.device)
    return model
|
||||
|
||||
|
||||
def enhance_faces(LQ_faces, model):
    """Run PSFR-GAN on each aligned low-quality face.

    LQ_faces: list of (H, W, 3) arrays in [0, 255].
    Returns (hq_faces, lq_parse_maps): enhanced uint8 images and colorized
    parsing maps, one per input face.
    """
    hq_faces = []
    lq_parse_maps = []
    for lq_face in tqdm(LQ_faces):
        with torch.no_grad():
            # HWC [0, 255] -> CHW tensor in [-1, 1]
            lq_tensor = torch.tensor(lq_face.transpose(2, 0, 1)) / 255. * 2 - 1
            lq_tensor = lq_tensor.unsqueeze(0).float().to(model.device)
            parse_map, _ = model.netP(lq_tensor)
            # one-hot parsing map from the channel-wise argmax
            parse_map_onehot = (parse_map == parse_map.max(dim=1, keepdim=True)[0]).float()
            output_SR = model.netG(lq_tensor, parse_map_onehot)
        hq_faces.append(utils.tensor_to_img(output_SR))
        lq_parse_maps.append(utils.color_parse_map(parse_map_onehot)[0])
    return hq_faces, lq_parse_maps
|
||||
|
||||
|
||||
def past_faces_back(img, hq_faces, tform_params, upscale=1):
    """Paste enhanced faces back into (an upscaled copy of) the original image.

    WARNING: mutates each transform in *tform_params* in place (divides its
    scale block by *upscale*), so the transforms cannot be reused afterwards.
    Returns the composited image as uint8.
    """
    h, w = img.shape[:2]
    img = cv2.resize(img, (int(w*upscale), int(h*upscale)), interpolation=cv2.INTER_CUBIC)
    for hq_img, tform in tqdm(zip(hq_faces, tform_params), total=len(hq_faces)):
        # adjust the similarity transform's scale to the upscaled canvas
        tform.params[0:2,0:2] /= upscale
        back_img = trans.warp(hq_img/255., tform, output_shape=[int(h*upscale), int(w*upscale)], order=3) * 255

        # blur mask to avoid border artifacts
        # mask is True where the warped face contributed nothing; blurring then
        # re-thresholding slightly grows it so seams land on original pixels
        mask = (back_img == 0)
        mask = cv2.blur(mask.astype(np.float32), (5,5))
        mask = (mask > 0)
        # keep original pixels under the mask, warped face elsewhere
        img = img * mask + (1 - mask) * back_img
    return img.astype(np.uint8)
|
||||
|
||||
|
||||
def save_imgs(img_list, save_dir):
    """Write each image in *img_list* to *save_dir* as 000.jpg, 001.jpg, ..."""
    for index, image in enumerate(img_list):
        target = os.path.join(save_dir, '{:03d}.jpg'.format(index))
        io.imsave(target, image.astype(np.uint8))
|
||||
|
||||
if __name__ == '__main__':
    # Single-image driver: detect/align all faces in opt.test_img_path,
    # enhance them with PSFR-GAN, and paste them back into the original image.
    opt = TestOptions().parse()
    # face_detector = dlib.get_frontal_face_detector()
    face_detector = dlib.cnn_face_detection_model_v1('./pretrain_models/mmod_human_face_detector.dat')
    lmk_predictor = dlib.shape_predictor('./pretrain_models/shape_predictor_5_face_landmarks.dat')
    template_path = './pretrain_models/FFHQ_template.npy'

    print('======> Loading images, crop and align faces.')
    img_path = opt.test_img_path
    img = dlib.load_rgb_image(img_path)
    aligned_faces, tform_params = detect_and_align_faces(img, face_detector, lmk_predictor, template_path)
    # Save aligned LQ faces
    save_lq_dir = os.path.join(opt.results_dir, 'LQ_faces')
    os.makedirs(save_lq_dir, exist_ok=True)
    print('======> Saving aligned LQ faces to', save_lq_dir)
    save_imgs(aligned_faces, save_lq_dir)

    enhance_model = def_models(opt)
    hq_faces, lq_parse_maps = enhance_faces(aligned_faces, enhance_model)
    # Save LQ parsing maps and enhanced faces
    save_parse_dir = os.path.join(opt.results_dir, 'ParseMaps')
    save_hq_dir = os.path.join(opt.results_dir, 'HQ')
    os.makedirs(save_parse_dir, exist_ok=True)
    os.makedirs(save_hq_dir, exist_ok=True)
    print('======> Save parsing map and the enhanced faces.')
    save_imgs(lq_parse_maps, save_parse_dir)
    save_imgs(hq_faces, save_hq_dir)

    print('======> Paste the enhanced faces back to the original image.')
    hq_img = past_faces_back(img, hq_faces, tform_params, upscale=opt.test_upscale)
    final_save_path = os.path.join(opt.results_dir, 'hq_final.jpg')
    print('======> Save final result to', final_save_path)
    io.imsave(final_save_path, hq_img)
|
||||
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
from utils.timer import Timer
|
||||
from utils.logger import Logger
|
||||
from utils import utils
|
||||
|
||||
from options.train_options import TrainOptions
|
||||
from data import create_dataset
|
||||
from models import create_model
|
||||
|
||||
import torch
|
||||
import os
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
def train(opt):
    """Main training loop.

    Iterates epochs/batches, records losses and images through Logger, and
    checkpoints the model at the intervals given by opt.save_iter_freq /
    opt.save_latest_freq. Supports resuming from (opt.resume_epoch,
    opt.resume_iter).
    """

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)    # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    model = create_model(opt)
    model.setup(opt)

    logger = Logger(opt)
    timer = Timer()

    single_epoch_iters = (dataset_size // opt.batch_size)
    total_iters = opt.total_epochs * single_epoch_iters
    # global iteration counter, continued from the resume point
    cur_iters = opt.resume_iter + opt.resume_epoch * single_epoch_iters
    # skip already-trained iterations of the resumed epoch only
    start_iter = opt.resume_iter
    print('Start training from epoch: {:05d}; iter: {:07d}'.format(opt.resume_epoch, opt.resume_iter))
    for epoch in range(opt.resume_epoch, opt.total_epochs + 1):
        for i, data in enumerate(dataset, start=start_iter):
            cur_iters += 1
            logger.set_current_iter(cur_iters)
            # =================== load data ===============
            model.set_input(data, cur_iters)
            timer.update_time('DataTime')

            # =================== model train ===============
            model.forward(), timer.update_time('Forward')
            model.optimize_parameters()
            loss = model.get_current_losses()
            loss.update(model.get_lr())  # fold current learning rates into the log
            logger.record_losses(loss)
            timer.update_time('Backward')

            # =================== save model and visualize ===============
            if cur_iters % opt.print_freq == 0:
                print('Model log directory: {}'.format(opt.expr_dir))
                epoch_progress = '{:03d}|{:05d}/{:05d}'.format(epoch, i, single_epoch_iters)
                logger.printIterSummary(epoch_progress, cur_iters, total_iters, timer)

            if cur_iters % opt.visual_freq == 0:
                visual_imgs = model.get_current_visuals()
                logger.record_images(visual_imgs)

            if cur_iters % opt.save_iter_freq == 0:
                print('saving current model (epoch %d, iters %d)' % (epoch, cur_iters))
                save_suffix = 'iter_%d' % cur_iters
                # info lets a later run resume exactly where this one stopped
                info = {'resume_epoch': epoch, 'resume_iter': i+1}
                model.save_networks(save_suffix, info)

            if cur_iters % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, iters %d)' % (epoch, cur_iters))
                info = {'resume_epoch': epoch, 'resume_iter': i+1}
                model.save_networks('latest', info)

            # end the epoch once the nominal iteration count is reached;
            # subsequent epochs start again from iteration 0
            if i >= single_epoch_iters - 1:
                start_iter = 0
                break

        # model.update_learning_rate()
        # debug mode: stop after a single epoch
        if opt.debug: break
        if opt.debug and epoch >= 0: break
    logger.close()
|
||||
|
||||
if __name__ == '__main__':
    # Parse training options and hand off to the training loop.
    train(TrainOptions().parse())
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
from .utils import mkdirs
|
||||
from tensorboardX import SummaryWriter
|
||||
from datetime import datetime
|
||||
import socket
|
||||
import shutil
|
||||
|
||||
class Logger():
    """TensorboardX-backed experiment logger.

    Writes scalar losses, image grids, and plain-text iteration summaries into
    a time-stamped directory under ``opts.log_dir``. Any previous log whose
    name contains the current experiment name is moved to ``opts.log_archive``
    first.
    """

    def __init__(self, opts):
        # timestamp suffix keeps runs of the same experiment from colliding
        time_stamp = '_{}'.format(datetime.now().strftime('%Y-%m-%d_%H:%M'))
        self.opts = opts
        self.log_dir = os.path.join(opts.log_dir, opts.name+time_stamp)
        self.phase_keys = ['train', 'val', 'test']
        self.iter_log = []              # per-iteration loss dicts, in order
        self.epoch_log = OrderedDict()  # per-phase epoch records
        self.set_mode(opts.phase)

        # check if exist previous log belong to the same experiment name
        # NOTE(review): substring match, so 'exp1' also matches 'exp10', and
        # only the last match found is archived — confirm this is intended.
        exist_log = None
        for log_name in os.listdir(opts.log_dir):
            if opts.name in log_name:
                exist_log = log_name
        if exist_log is not None:
            old_dir = os.path.join(opts.log_dir, exist_log)
            archive_dir = os.path.join(opts.log_archive, exist_log)
            # NOTE(review): shutil.move raises if archive_dir already exists.
            shutil.move(old_dir, archive_dir)

        self.mk_log_file()

        self.writer = SummaryWriter(self.log_dir)

    def mk_log_file(self):
        """Create the log directory and one text-log path per phase."""
        mkdirs(self.log_dir)
        self.txt_files = OrderedDict()
        for i in self.phase_keys:
            self.txt_files[i] = os.path.join(self.log_dir, 'log_{}'.format(i))

    def set_mode(self, mode):
        """Switch the active phase ('train'/'val'/'test') and open its epoch log."""
        self.mode = mode
        self.epoch_log[mode] = []

    def set_current_iter(self, cur_iter):
        # global iteration counter, used as the x-axis for all records
        self.cur_iter = cur_iter

    def record_losses(self, items):
        """
        iteration log: [iter][{key: value}]

        Appends *items* to the in-memory iteration log; every entry whose key
        contains 'loss' is also written to tensorboard.
        """
        self.iter_log.append(items)
        for k, v in items.items():
            if 'loss' in k.lower():
                self.writer.add_scalar('loss/{}'.format(k), v, self.cur_iter)

    def record_scalar(self, items):
        """
        Add scalar records. item, {key: value}
        """
        for i in items.keys():
            self.writer.add_scalar('{}'.format(i), items[i], self.cur_iter)

    def record_images(self, visuals, nrow=6, tag='ckpt_image'):
        """Tile up to *nrow* samples (one row per sample, one column per
        visual) into a single HWC uint8 image and log it under *tag*."""
        imgs = []
        max_len = visuals[0].shape[0]
        for i in range(nrow):
            if i >= max_len: continue
            tmp_imgs = [x[i] for x in visuals]
            imgs.append(np.hstack(tmp_imgs))
        imgs = np.vstack(imgs).astype(np.uint8)
        self.writer.add_image(tag, imgs, self.cur_iter, dataformats='HWC')

    def record_text(self, tag, text):
        """Log a free-form text record to tensorboard."""
        self.writer.add_text(tag, text)

    def printIterSummary(self, epoch, cur_iters, total_it, timer):
        """Print the most recent losses plus timing info, and append the same
        line to the current phase's text log."""
        msg = '{}\nIter: [{}]{:03d}/{:03d}\t\t'.format(
            timer.to_string(total_it - cur_iters), epoch, cur_iters, total_it)
        for k, v in self.iter_log[-1].items():
            msg += '{}: {:.6f}\t'.format(k, v)
        print(msg + '\n')
        with open(self.txt_files[self.mode], 'a+') as f:
            f.write(msg + '\n')

    def close(self):
        """Dump all scalars to JSON and release the tensorboard writer."""
        self.writer.export_scalars_to_json(os.path.join(self.log_dir, 'all_scalars.json'))
        self.writer.close()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
import time
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
|
||||
class Timer():
    """Wall-clock stopwatch tracking named phase durations in insertion order."""

    def __init__(self):
        self.reset_timer()
        self.start = time.time()  # overall start, reported as 'Elapse'

    def reset_timer(self):
        """Drop all recorded phases and restart the phase clock."""
        self.before = time.time()
        self.timer = OrderedDict()

    def restart(self):
        """Restart the phase clock without clearing recorded phases."""
        self.before = time.time()

    def update_time(self, key):
        """Record the time spent since the previous mark under *key*."""
        now = time.time()
        self.timer[key] = now - self.before
        self.before = now

    def to_string(self, iters_left, short=False):
        """Format a progress line: timestamp, elapsed time, and an ETA
        assuming *iters_left* more iterations at the current per-iteration cost."""
        iter_total = sum(self.timer.values())
        elapsed = datetime.timedelta(seconds=round(time.time() - self.start))
        remaining = datetime.timedelta(seconds=round(iter_total * iters_left))
        msg = f"{datetime.datetime.now():%Y-%m-%d %H:%M:%S}\tElapse: {elapsed}\tTimeLeft: {remaining}\t"
        names = '|'.join(self.timer.keys())
        if short:
            msg += f'{names}: {iter_total:.2f}s'
        else:
            values = ' '.join(f'{x:.2f}s' for x in self.timer.values())
            msg += f'\tIterTotal: {iter_total:.2f}s\t{names}: {values} '
        return msg
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
from skimage import io
|
||||
from PIL import Image
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
# Original 19-color parsing palette (one RGB triple per label), kept for reference:
# MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
# Binarized palette: face-region labels map to white, background/accessory labels to black.
MASK_COLORMAP = [[0, 0, 0], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [0,0, 0], [0, 0, 0], [255, 255, 255], [255, 255, 255], [255, 255, 255], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]
# NOTE(review): 18 label names vs 19 colormap entries; index 0 is presumably
# 'background' and the names are offset by one — confirm against the parser.
label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
|
||||
|
||||
def array_to_heatmap(x):
    """Min-max normalize *x* to [0, 255] and render it as a rainbow heatmap.

    Returns a BGR uint8 image from cv.applyColorMap.
    """
    x_min, x_max = x.min(), x.max()
    # bug fix: a constant input previously divided by zero (NaN image)
    scale = (x_max - x_min) or 1
    x = ((x - x_min) / scale * 255).astype(np.uint8)
    return cv.applyColorMap(x, cv.COLORMAP_RAINBOW)
|
||||
|
||||
def img_to_tensor(img_path, device, size=None, mode='rgb'):
    """
    Read image from img_path and convert it to a (1, C, H, W) float tensor
    in range [-1, 1] on *device*.

    size: optional (W, H) to resize to before conversion.
    mode: 'rgb' (default) or 'bgr' to reverse the channel order.
    """
    img = Image.open(img_path).convert('RGB')
    img = np.array(img)
    if mode=='bgr':
        img = img[..., ::-1]  # RGB -> BGR channel reversal
    if size:
        img = cv.resize(img, size)
    # [0, 255] -> [-1, 1] (also materializes a contiguous copy)
    img = img / 255 * 2 - 1
    # HWC -> CHW, add batch dimension
    img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device)
    return img_tensor.float()
|
||||
|
||||
def tensor_to_img(tensor, save_path=None, size=None, mode='RGB', normal=[-1, 1]):
    """
    mode: RGB or L (gray image)
    Input: tensor with shape (C, H, W)
    Output: uint8 numpy array (HWC for RGB); also saved to save_path if given.

    normal: [low, high] value range of the input, mapped linearly to [0, 255];
    pass an empty sequence to skip normalization.
    """
    if isinstance(size, int):
        size = (size, size)
    img_array = tensor.squeeze().data.cpu().numpy()
    if mode == 'RGB':
        img_array = img_array.transpose(1, 2, 0)  # CHW -> HWC

    if size is not None:
        img_array = cv.resize(img_array, size, interpolation=cv.INTER_LINEAR)

    if len(normal):
        # map [normal[0], normal[1]] -> [0, 255] and clamp outliers
        img_array = (img_array - normal[0]) / (normal[1] - normal[0]) * 255
        img_array = img_array.clip(0, 255)

    img_array = img_array.astype(np.uint8)
    if save_path:
        img = Image.fromarray(img_array, mode)
        img.save(save_path)

    return img_array
|
||||
|
||||
def tensor_to_numpy(tensor):
    """Detach *tensor* from the graph, move it to CPU, and return a NumPy array."""
    return tensor.detach().cpu().numpy()
|
||||
|
||||
def batch_numpy_to_image(array, size=None):
    """
    Input: numpy array (B, C, H, W) in [-1, 1]
    Output: uint8 array of HWC images in [0, 255], optionally resized to *size*
    (an int is treated as a square (size, size)).
    """
    if isinstance(size, int):
        size = (size, size)

    # [-1, 1] -> [0, 255], then BCHW -> BHWC
    scaled = np.clip((array + 1) / 2 * 255, 0, 255)
    scaled = np.transpose(scaled, (0, 2, 3, 1))

    if size is None:
        images = [frame for frame in scaled]
    else:
        images = [cv.resize(frame, size) for frame in scaled]
    return np.array(images).astype(np.uint8)
|
||||
|
||||
def batch_tensor_to_img(tensor, size=None):
    """
    Input: (B, C, H, W)
    Return: RGB image, [0, 255]
    """
    # compose the two existing helpers: tensor -> numpy -> uint8 HWC images
    return batch_numpy_to_image(tensor_to_numpy(tensor), size)
|
||||
|
||||
def color_parse_map(tensor, size=None):
    """
    input: tensor or batch tensor
    return: colorized parsing maps (list of uint8 HWC arrays, one per batch item)
    """
    # promote a single sample to a batch of one
    if len(tensor.shape) < 4:
        tensor = tensor.unsqueeze(0)
    # multi-channel (logits / one-hot) -> integer label map
    if tensor.shape[1] > 1:
        tensor = tensor.argmax(dim=1)

    # NOTE(review): after argmax the tensor is already (B, H, W); squeeze(1)
    # would also drop H when H == 1 — confirm inputs are larger than 1 pixel.
    tensor = tensor.squeeze(1).data.cpu().numpy()
    color_maps = []
    for t in tensor:
        # tensor.shape[1:] == (H, W) here, so tmp_img is (H, W, 3)
        tmp_img = np.zeros(tensor.shape[1:] + (3,))
        # paint each label index with its palette color
        for idx, color in enumerate(MASK_COLORMAP):
            tmp_img[t == idx] = color
        if size is not None:
            tmp_img = cv.resize(tmp_img, (size, size))
        color_maps.append(tmp_img.astype(np.uint8))
    return color_maps
|
||||
|
||||
def onehot_parse_map(img):
    """
    input: RGB color parse map
    output: one hot encoding of parse map, shape (len(MASK_COLORMAP), H, W)
    """
    n_label = len(MASK_COLORMAP)
    img = np.array(img, dtype=np.uint8)
    h, w = img.shape[:2]
    onehot_label = np.zeros((n_label, h, w))
    # broadcast every palette color over the full image plane
    colormap = np.array(MASK_COLORMAP).reshape(n_label, 1, 1, 3)
    colormap = np.tile(colormap, (1, h, w, 1))
    for idx, color in enumerate(MASK_COLORMAP):
        # a pixel belongs to label idx only if all three channels match
        tmp_label = colormap[idx] == img
        onehot_label[idx] = tmp_label[..., 0] * tmp_label[..., 1] * tmp_label[..., 2]
    return onehot_label
|
||||
|
||||
|
||||
def mkdirs(paths):
    """Create a directory, or each directory in an iterable of paths.

    Existing directories are left untouched. Accepts a single path string or
    any iterable of path strings (the original only handled str and list).
    """
    if isinstance(paths, str):
        paths = [paths]
    for path in paths:
        # exist_ok avoids the check-then-create race of the original
        # `if not os.path.exists(...)` pattern
        os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
def get_gpu_memory_map():
    """Get the current gpu usage within visible cuda devices.

    Returns
    -------
    Memory Map: dict
        Keys are device ids as integers.
        Values are memory usage as integers in MB.
    Device Ids: gpu ids sorted in descending order according to the available memory.
    """
    # query per-GPU used memory from nvidia-smi (one line per physical GPU)
    result = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.used',
            '--format=csv,nounits,noheader'
        ]).decode('utf-8')
    # Convert lines into a dictionary
    gpu_memory = np.array([int(x) for x in result.strip().split('\n')])
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        visible_devices = sorted([int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',')])
    else:
        visible_devices = range(len(gpu_memory))
    # restrict to visible devices, re-keyed 0..n-1 (CUDA's view of the world)
    # NOTE(review): assumes nvidia-smi lists all physical GPUs so indexing
    # gpu_memory by physical ids is valid — confirm on multi-GPU hosts.
    gpu_memory_map = dict(zip(range(len(visible_devices)), gpu_memory[visible_devices]))
    # second return value: device ids sorted by used memory, least-used first
    return gpu_memory_map, sorted(gpu_memory_map, key=gpu_memory_map.get)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Ad-hoc smoke test for this module (requires CUDA).
    hm = torch.randn(32, 68, 128, 128).cuda()
    # NOTE(review): `flip` is not defined or imported in this module — this
    # line raises NameError as written; verify which flip helper was intended.
    flip(hm, 2)
    x = torch.ones(32, 68)
    y = torch.ones(32, 68)
    print(get_gpu_memory_map())
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: filter.py
|
||||
# Created Date: Wednesday April 13th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 13th April 2022 3:49:23 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import cv2
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
class HighPass(nn.Module):
    """Depthwise 3x3 Laplacian high-pass filter, scaled by 1 / w_hpf."""

    def __init__(self, w_hpf, device):
        super().__init__()
        # classic 8-neighbour Laplacian kernel; responds to edges, zero on flat regions
        kernel = torch.tensor([[-1., -1., -1.],
                               [-1., 8., -1.],
                               [-1., -1., -1.]])
        self.filter = kernel.to(device) / w_hpf

    def forward(self, x):
        # apply the same kernel to every channel independently (depthwise conv)
        channels = x.size(1)
        weight = self.filter.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        return F.conv2d(x, weight, padding=1, groups=channels)
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc demo: run the high-pass filter over one image and save the result.
    transformer_Arcface = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)

    img = "G:/swap_data/ID/2.jpg"
    attr = cv2.imread(img)
    attr = Image.fromarray(cv2.cvtColor(attr,cv2.COLOR_BGR2RGB))
    attr = transformer_Arcface(attr).unsqueeze(0)
    # filter the ImageNet-normalized image (w_hpf=0.5 doubles the kernel)
    results = HighPass(0.5,torch.device("cpu"))(attr)

    # undo ImageNet normalization before visualizing
    results = results * imagenet_std + imagenet_mean
    results = results.cpu().permute(0,2,3,1)[0,...]
    results = results.numpy()
    results = np.clip(results,0.0,1.0) * 255
    # NOTE(review): array is still float here; cv2.imwrite truncates floats —
    # presumably an explicit .astype(np.uint8) was intended. Confirm.
    results = cv2.cvtColor(results,cv2.COLOR_RGB2BGR)
    cv2.imwrite("filter_results2.png",results)
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday February 13th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 4th March 2022 1:53:53 am
|
||||
# Last Modified: Monday, 18th April 2022 10:52:57 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -22,7 +22,7 @@ from thop import clever_format
|
||||
if __name__ == '__main__':
|
||||
#
|
||||
# script = "Generator_modulation_up"
|
||||
script = "Generator_Invobn_config3"
|
||||
script = "Generator_2mask"
|
||||
# script = "Generator_ori_modulation_config"
|
||||
# script = "Generator_ori_config"
|
||||
class_name = "Generator"
|
||||
@@ -30,12 +30,13 @@ if __name__ == '__main__':
|
||||
model_config={
|
||||
"id_dim": 512,
|
||||
"g_kernel_size": 3,
|
||||
"in_channel":16,
|
||||
"res_num": 9,
|
||||
"in_channel":64,
|
||||
"res_num": 3,
|
||||
# "up_mode": "nearest",
|
||||
"up_mode": "bilinear",
|
||||
"aggregator": "eca_invo",
|
||||
"res_mode": "conv"
|
||||
"res_mode": "conv",
|
||||
"norm": "bn"
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: id_cos.py
|
||||
# Created Date: Friday March 25th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 29th March 2022 11:58:30 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from insightface_func.face_detect_crop_single import Face_detect_crop
|
||||
|
||||
from arcface_torch.backbones.iresnet import iresnet100
|
||||
|
||||
if __name__ == "__main__":
    # Compare the ArcFace identity embeddings of two face images and report
    # their cosine distance (1 - cosine similarity; 0 means identical ID).
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3, 1, 1)
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3, 1, 1)
    transformer_Arcface = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Load the ArcFace identity encoder from a full checkpoint (the model
    # object itself is pickled inside the .tar file).
    arcface_ckpt = "./arcface_ckpt/arcface_checkpoint.tar"
    arcface1 = torch.load(arcface_ckpt, map_location=torch.device("cpu"))
    arcface = arcface1['model'].module
    arcface.eval()

    root1 = "G:/VGGFace2-HQ/VGGface2_ffhq_align_256_9_28_512_bygfpgan/n000002/"
    root2 = "G:/VGGFace2-HQ/VGGface2_None_norm_512_true_bygfpgan/n000002/"
    id1 = root2 + "0003_01.jpg"
    id2 = root2 + "0036_01.jpg"

    mode = "none"
    cos_loss = torch.nn.CosineSimilarity()

    def _embed(path):
        """Load an image file and return its L2-normalized ArcFace embedding."""
        img = cv2.imread(path)
        pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        tensor = transformer_Arcface(pil).unsqueeze(0)
        # ArcFace expects 112x112 inputs.
        tensor = F.interpolate(tensor, size=(112, 112), mode='bicubic')
        latent = arcface(tensor)
        return F.normalize(latent, p=2, dim=1)

    # No gradients needed for inference-only comparison.
    with torch.no_grad():
        latend_id = _embed(id1)
        latend_id2 = _embed(id2)

    cos_dis = 1 - cos_loss(latend_id, latend_id2)
    # BUG FIX: the printed value is a cosine *distance*, not a similarity;
    # the old label "cosine similarity:" was misleading.
    print("cosine distance:", cos_dis.item())
|
||||
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
|
||||
"Alias-Free Generative Adversarial Networks"."""
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.fft
|
||||
from torch_utils.ops import upfirdn2d
|
||||
from . import metric_utils
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Utilities.
|
||||
|
||||
def sinc(x):
    """Normalized sinc: sin(pi*x) / (pi*x), with the limit sinc(0) = 1."""
    arg = (x * np.pi).abs()
    # Clamp the denominator away from zero, then patch the near-zero entries
    # with the analytic limit of 1.
    ratio = torch.sin(arg) / arg.clamp(1e-30, float('inf'))
    return torch.where(arg < 1e-30, torch.ones_like(x), ratio)
|
||||
|
||||
def lanczos_window(x, a):
    """Lanczos window of size a: sinc(|x|/a) inside |x| < a, zero outside."""
    scaled = x.abs() / a
    inside = scaled < 1
    return torch.where(inside, sinc(scaled), torch.zeros_like(scaled))
|
||||
|
||||
def rotation_matrix(angle):
    """Return a 3x3 homogeneous 2D rotation matrix for the given angle."""
    angle = torch.as_tensor(angle).to(torch.float32)
    c, s = angle.cos(), angle.sin()
    mat = torch.eye(3, device=angle.device)
    mat[0, 0] = c
    mat[0, 1] = s
    mat[1, 0] = -s
    mat[1, 1] = c
    return mat
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Apply integer translation to a batch of 2D images. Corresponds to the
|
||||
# operator T_x in Appendix E.1.
|
||||
|
||||
def apply_integer_translation(x, tx, ty):
    """Shift a batch of 2D images by a whole number of pixels.

    Corresponds to the operator T_x in Appendix E.1. `tx` / `ty` are given
    as fractions of the image width / height and are rounded to the nearest
    integer pixel offset. Returns (shifted, valid_mask), where the mask is 1
    on pixels that received content from the source image.
    """
    _N, _C, H, W = x.shape
    tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    ix = tx.round().to(torch.int64)
    iy = ty.round().to(torch.int64)

    shifted = torch.zeros_like(x)
    mask = torch.zeros_like(x)
    if abs(ix) < W and abs(iy) < H:
        # Overlapping source window in x and destination window in the output.
        src = x[:, :, int(max(-iy, 0)) : H + int(min(-iy, 0)),
                      int(max(-ix, 0)) : W + int(min(-ix, 0))]
        dst_rows = slice(int(max(iy, 0)), H + int(min(iy, 0)))
        dst_cols = slice(int(max(ix, 0)), W + int(min(ix, 0)))
        shifted[:, :, dst_rows, dst_cols] = src
        mask[:, :, dst_rows, dst_cols] = 1
    return shifted, mask
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Apply integer translation to a batch of 2D images. Corresponds to the
|
||||
# operator T_x in Appendix E.2.
|
||||
|
||||
def apply_fractional_translation(x, tx, ty, a=3):
    """Translate a batch of 2D images by a fractional number of pixels.

    Corresponds to the operator T_x in Appendix E.2.

    Args:
        x:  Image batch of shape [N, C, H, W].
        tx: Horizontal translation as a fraction of the image width.
        ty: Vertical translation as a fraction of the image height.
        a:  Filter size parameter (2*a interpolation taps are used).

    Returns:
        (z, m) where z is the translated batch and m is a same-shape mask
        that is 1 where z contains valid (fully filtered) content.
    """
    _N, _C, H, W = x.shape
    # Convert fractional offsets to pixel units.
    tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
    ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
    # Split into integer part (ix, iy) and fractional remainder (fx, fy).
    ix = tx.floor().to(torch.int64)
    iy = ty.floor().to(torch.int64)
    fx = tx - ix
    fy = ty - iy
    b = a - 1  # filter extent on the leading side

    z = torch.zeros_like(x)
    # Destination rectangle that will receive valid resampled pixels.
    zx0 = max(ix - b, 0)
    zy0 = max(iy - b, 0)
    zx1 = min(ix + a, 0) + W
    zy1 = min(iy + a, 0) + H
    if zx0 < zx1 and zy0 < zy1:
        # Separable windowed-sinc interpolation taps for the fractional part
        # of the shift; each filter is normalized below to preserve DC.
        taps = torch.arange(a * 2, device=x.device) - b
        filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
        filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
        y = x
        y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
        y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
        # Crop the integer-shifted region out of the padded, filtered result.
        y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
        z[:, :, zy0:zy1, zx0:zx1] = y

    # Mask excluding pixels whose filter support crossed the image border.
    m = torch.zeros_like(x)
    mx0 = max(ix + a, 0)
    my0 = max(iy + a, 0)
    mx1 = min(ix - b, 0) + W
    my1 = min(iy - b, 0) + H
    if mx0 < mx1 and my0 < my1:
        m[:, :, my0:my1, mx0:mx1] = 1
    return z, m
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Construct an oriented low-pass filter that applies the appropriate
|
||||
# bandlimit with respect to the input and output of the given affine 2D
|
||||
# image transformation.
|
||||
|
||||
def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    """Construct an oriented low-pass FIR filter for the affine transform `mat`.

    The filter applies the appropriate bandlimit with respect to both the
    input and the output coordinate space of the transformation. The result
    is a 2D tensor of side (amax * 2 * up - 1), normalized so that each
    up-sampling phase sums to 1 / up**2.
    """
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    # Construct 2D filter taps in input & output coordinate spaces.
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    yi, xi = torch.meshgrid(taps, taps)
    # Map the taps through the linear part of the transform to get output-space
    # coordinates for each tap.
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)

    # Convolution of two oriented 2D sinc filters
    # (performed in the frequency domain via FFT).
    fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
    fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real

    # Convolution of two oriented 2D Lanczos windows.
    wi = lanczos_window(xi, a) * lanczos_window(yi, a)
    wo = lanczos_window(xo, a) * lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real

    # Construct windowed FIR filter.
    f = f * w

    # Finalize.
    # Center the response and crop to the maximum supported extent amax.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    # Normalize each of the up x up polyphase components separately, then
    # flatten back to a single 2D kernel.
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Apply the given affine transformation to a batch of 2D images.
|
||||
|
||||
def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Resample a batch of 2D images through the affine transform `mat`.

    The images are upsampled by factor `up` with a transform-matched
    band-limiting filter before bilinear resampling, to suppress aliasing.
    Returns (transformed, valid_mask); the mask marks pixels that were
    sampled from fully filtered regions of the source.
    """
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2  # filter radius, in upsampled pixels

    # Construct sampling grid.
    # Invert the transform and compensate for the upsampling factor and the
    # padding introduced by the filter so the grid addresses `y` correctly.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask.
    # Valid region excludes a border of one filter diameter, resampled with
    # nearest-neighbor so the mask stays binary.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Apply fractional rotation to a batch of 2D images. Corresponds to the
|
||||
# operator R_\alpha in Appendix E.3.
|
||||
|
||||
def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
    """Rotate a batch of 2D images by `angle` (operator R_alpha, Appendix
    E.3) via affine resampling with a size-`a` band-limiting filter."""
    theta = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    return apply_affine_transformation(
        x, rotation_matrix(theta), a=a, amax=a * 2, **filter_kwargs)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Modify the frequency content of a batch of 2D images as if they had undergo
|
||||
# fractional rotation -- but without actually rotating them. Corresponds to
|
||||
# the operator R^*_\alpha in Appendix E.3.
|
||||
|
||||
def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
    """Band-limit a batch of images as if rotated by `angle`, without
    actually rotating them (operator R*_alpha, Appendix E.3).
    Returns (filtered, valid_mask)."""
    angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    # Low-pass filter oriented for the inverse rotation, applied in place
    # (up=1: no resampling).
    lp = construct_affine_bandlimit_filter(
        rotation_matrix(-angle), a=a, amax=a * 2, up=1, **filter_kwargs)
    out = upfirdn2d.filter2d(x=x, f=lp)
    mask = torch.zeros_like(out)
    border = lp.shape[0] // 2
    mask[:, :, border:-border, border:-border] = 1
    return out, mask
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Compute the selected equivariance metrics for the given generator.
|
||||
|
||||
def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
    """Compute EQ-T / EQ-T_frac / EQ-R equivariance PSNRs for a generator.

    For each random latent, a reference image is rendered, then re-rendered
    with the generator's input transform set to a random translation or
    rotation, and compared against the correspondingly transformed reference.
    Returns a single float when exactly one metric is requested, otherwise a
    tuple in (eqt_int, eqt_frac, eqr) order. Higher PSNR = more equivariant.
    """
    assert compute_eqt_int or compute_eqt_frac or compute_eqr

    # Setup generator and labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    I = torch.eye(3, device=opts.device)
    # M is the user-specified input transform buffer; mutated in-place below.
    M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
    if M is None:
        raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop.
    sums = None
    progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        s = []  # interleaved [squared-error, mask] tensors for this batch

        # Randomize noise buffers, if any.
        for name, buf in G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Run mapping network.
        z = torch.randn([batch_size, G.z_dim], device=opts.device)
        c = next(c_iter)
        ws = G.mapping(z=z, c=c)

        # Generate reference image.
        M[:] = I
        orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)

        # Integer translation (EQ-T).
        if compute_eqt_int:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            # Snap to whole output pixels so the reference shift is exact.
            t = (t * G.img_resolution).round() / G.img_resolution
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_integer_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Fractional translation (EQ-T_frac).
        if compute_eqt_frac:
            t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
            M[:] = I
            M[:2, 2] = -t
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            ref, mask = apply_fractional_translation(orig, t[0], t[1])
            s += [(ref - img).square() * mask, mask]

        # Rotation (EQ-R).
        if compute_eqr:
            angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
            M[:] = rotation_matrix(-angle)
            img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
            # Rotate the reference and pseudo-rotate the output so both have
            # matching frequency content before comparison.
            ref, ref_mask = apply_fractional_rotation(orig, angle)
            pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
            mask = ref_mask * pseudo_mask
            s += [(ref - pseudo).square() * mask, mask]

        # Accumulate results.
        s = torch.stack([x.to(torch.float64).sum() for x in s])
        sums = sums + s if sums is not None else s
    progress.update(num_samples)

    # Compute PSNRs.
    if opts.num_gpus > 1:
        torch.distributed.all_reduce(sums)
    sums = sums.cpu()
    mses = sums[0::2] / sums[1::2]
    # Peak signal of 2 (generator outputs presumably span [-1, 1] — confirm).
    psnrs = np.log10(2) * 20 - mses.log10() * 10
    psnrs = tuple(psnrs.numpy())
    return psnrs[0] if len(psnrs) == 1 else psnrs
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Frechet Inception Distance (FID) from the paper
|
||||
"GANs trained by a two time-scale update rule converge to a local Nash
|
||||
equilibrium". Matches the original implementation by Heusel et al. at
|
||||
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
|
||||
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
from . import metric_utils
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def compute_fid(opts, max_real, num_gen, swav=False, sfid=False):
    """Frechet Inception Distance between real and generated feature stats."""
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    common = dict(opts=opts, detector_url=detector_url,
                  detector_kwargs=detector_kwargs, capture_mean_cov=True,
                  swav=swav, sfid=sfid)
    mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
        rel_lo=0, rel_hi=0, max_items=max_real, **common).get_mean_cov()
    mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
        rel_lo=0, rel_hi=1, max_items=num_gen, **common).get_mean_cov()

    # Only rank 0 reports the metric in multi-GPU runs.
    if opts.rank != 0:
        return float('nan')

    # FID = |mu_g - mu_r|^2 + Tr(S_g + S_r - 2 * sqrtm(S_g @ S_r)).
    m = np.square(mu_gen - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
    fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
    return float(fid)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Inception Score (IS) from the paper "Improved techniques for training
|
||||
GANs". Matches the original implementation by Salimans et al. at
|
||||
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
|
||||
|
||||
import numpy as np
|
||||
from . import metric_utils
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def compute_is(opts, num_gen, num_splits):
    """Inception Score: mean and std of exp(KL) over contiguous splits."""
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.

    gen_probs = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        capture_all=True, max_items=num_gen).get_all()

    # Only rank 0 reports the metric in multi-GPU runs.
    if opts.rank != 0:
        return float('nan'), float('nan')

    scores = []
    for split in range(num_splits):
        lo = split * num_gen // num_splits
        hi = (split + 1) * num_gen // num_splits
        part = gen_probs[lo:hi]
        # KL divergence of each sample's class distribution against the
        # split's marginal distribution.
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return float(np.mean(scores)), float(np.std(scores))
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
|
||||
GANs". Matches the original implementation by Binkowski et al. at
|
||||
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
|
||||
|
||||
import numpy as np
|
||||
from . import metric_utils
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
    """Kernel Inception Distance: unbiased polynomial-kernel MMD estimate,
    averaged over `num_subsets` random subsets of size <= max_subset_size."""
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()

    # Only rank 0 reports the metric in multi-GPU runs.
    if opts.rank != 0:
        return float('nan')

    n = real_features.shape[1]
    m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
    total = 0
    for _subset_idx in range(num_subsets):
        x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
        y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
        # Cubic polynomial kernel k(u, v) = (u.v / n + 1)^3.
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        # Unbiased MMD^2: exclude the diagonal of the same-set terms.
        total += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    return float(total / num_subsets / m)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,151 @@
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Main API for computing and reporting quality metrics."""
|
||||
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import torch
|
||||
import dnnlib
|
||||
|
||||
from . import metric_utils
|
||||
from . import frechet_inception_distance
|
||||
from . import kernel_inception_distance
|
||||
from . import precision_recall
|
||||
from . import perceptual_path_length
|
||||
from . import inception_score
|
||||
from . import equivariance
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_metric_dict = dict() # name => fn


def register_metric(fn):
    """Decorator: record `fn` in the metric registry under its __name__."""
    assert callable(fn)
    _metric_dict[fn.__name__] = fn
    return fn


def is_valid_metric(metric):
    """Whether `metric` is the name of a registered metric."""
    return metric in _metric_dict


def list_valid_metrics():
    """All registered metric names, in registration order."""
    return [name for name in _metric_dict]
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
    """Compute the named registered metric and return an EasyDict containing
    the results, timing information, and the number of GPUs used."""
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate.
    start_time = time.time()
    results = _metric_dict[metric](opts)
    total_time = time.time() - start_time

    # Broadcast results.
    # In multi-GPU runs the metric values are broadcast from rank 0 so every
    # rank returns identical numbers (non-zero ranks may hold placeholder
    # NaNs, as e.g. compute_fid does).
    for key, value in list(results.items()):
        if opts.num_gpus > 1:
            value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=value, src=0)
            value = float(value.cpu())
        results[key] = value

    # Decorate with metadata.
    return dnnlib.EasyDict(
        results = dnnlib.EasyDict(results),
        metric = metric,
        total_time = total_time,
        total_time_str = dnnlib.util.format_time(total_time),
        num_gpus = opts.num_gpus,
    )
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
    """Print a metric result as one JSON line and append it to
    metric-<name>.jsonl inside `run_dir` (when that directory exists)."""
    metric = result_dict['metric']
    assert is_valid_metric(metric)
    # Store the snapshot path relative to the run directory.
    if run_dir is not None and snapshot_pkl is not None:
        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)

    record = dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())
    jsonl_line = json.dumps(record)
    print(jsonl_line)
    if run_dir is not None and os.path.isdir(run_dir):
        with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
            f.write(jsonl_line + '\n')
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Recommended metrics.
|
||||
|
||||
@register_metric
def fid50k_full(opts):
    """FID: 50k generated images vs. the full (un-mirrored) dataset."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
    return dict(fid50k_full=fid)

@register_metric
def fid10k_full(opts):
    """FID: 10k generated images vs. the full dataset (faster, noisier)."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000)
    return dict(fid10k_full=fid)

@register_metric
def kid50k_full(opts):
    """KID: 50k generated images, up to 1M reals, 100 subsets of <=1000."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k_full=kid)

@register_metric
def pr50k3_full(opts):
    """Precision / recall with 3-nearest-neighbor manifolds, 50k generated."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)

@register_metric
def ppl2_wend(opts):
    """Perceptual path length in W space, endpoint sampling, eps=1e-4."""
    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
    return dict(ppl2_wend=ppl)

@register_metric
def eqt50k_int(opts):
    """Equivariance PSNR w.r.t. integer translation (EQ-T), fp32 forced."""
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
    return dict(eqt50k_int=psnr)

@register_metric
def eqt50k_frac(opts):
    """Equivariance PSNR w.r.t. fractional translation (EQ-T_frac)."""
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
    return dict(eqt50k_frac=psnr)

@register_metric
def eqr50k(opts):
    """Equivariance PSNR w.r.t. rotation (EQ-R)."""
    opts.G_kwargs.update(force_fp32=True)
    psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
    return dict(eqr50k=psnr)
|
||||
|
||||
# Legacy metrics.
|
||||
|
||||
# Legacy variants with capped real-set sizes.

@register_metric
def fid50k(opts):
    """Legacy FID: 50k generated vs. at most 50k real images."""
    opts.dataset_kwargs.update(max_size=None)
    fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
    return dict(fid50k=fid)

@register_metric
def kid50k(opts):
    """Legacy KID: 50k vs. 50k, 100 subsets of <=1000."""
    opts.dataset_kwargs.update(max_size=None)
    kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k=kid)

@register_metric
def pr50k3(opts):
    """Legacy precision / recall with 3-nearest-neighbor manifolds, 50k vs. 50k."""
    opts.dataset_kwargs.update(max_size=None)
    precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_precision=precision, pr50k3_recall=recall)

@register_metric
def is50k(opts):
    """Inception Score: 50k generated images, 10 splits."""
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
    return dict(is50k_mean=mean, is50k_std=std)
|
||||
@@ -0,0 +1,298 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Miscellaneous utilities used internally by the quality metrics."""
|
||||
|
||||
import os
|
||||
import time
|
||||
import hashlib
|
||||
import pickle
|
||||
import copy
|
||||
import uuid
|
||||
import numpy as np
|
||||
import torch
|
||||
import dnnlib
|
||||
from tqdm import tqdm
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class MetricOptions:
    """Bag of options shared by all metric computations.

    Collects the generator, dataset description, and distributed-execution
    context that the individual metrics read.
    """

    def __init__(self, G=None, G_kwargs=None, dataset_kwargs=None, num_gpus=1, rank=0, device=None, progress=None, cache=True, run_dir=None, cur_nimg=None, snapshot_pkl=None):
        assert 0 <= rank < num_gpus
        self.G = G
        # BUG FIX: previously G_kwargs / dataset_kwargs used mutable default
        # arguments ({}); use None sentinels instead (backward compatible —
        # the resulting EasyDicts are identical for callers).
        self.G_kwargs = dnnlib.EasyDict(G_kwargs if G_kwargs is not None else {})
        self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs if dataset_kwargs is not None else {})
        self.num_gpus = num_gpus
        self.rank = rank
        # Default to the CUDA device matching this process's rank.
        self.device = device if device is not None else torch.device('cuda', rank)
        # Only rank 0 reports progress; other ranks get a silent monitor.
        self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
        self.cache = cache
        self.run_dir = run_dir
        self.cur_nimg = cur_nimg
        self.snapshot_pkl = snapshot_pkl
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_feature_detector_cache = dict()
|
||||
|
||||
def get_feature_detector_name(url):
    """Return the stem (basename without extension) of the URL's last path
    component, used as a human-readable detector name."""
    basename = url.rsplit('/', 1)[-1]
    stem, _ext = os.path.splitext(basename)
    return stem
|
||||
|
||||
def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    """Load (and process-wide cache) a pickled feature-detector network.

    The cache key includes the device, so the same detector may be resident
    on several devices at once. In multi-GPU runs the barriers serialize the
    download: rank 0 fetches first, the other ranks then load (presumably
    hitting dnnlib's local download cache — confirm).
    """
    assert 0 <= rank < num_gpus
    key = (url, device)
    if key not in _feature_detector_cache:
        is_leader = (rank == 0)
        if not is_leader and num_gpus > 1:
            torch.distributed.barrier() # leader goes first
        # NOTE(review): pickle.load executes arbitrary code on load; only use
        # with trusted detector URLs.
        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
            _feature_detector_cache[key] = pickle.load(f).to(device)
        if is_leader and num_gpus > 1:
            torch.distributed.barrier() # others follow
    return _feature_detector_cache[key]
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def iterate_random_labels(opts, batch_size):
    """Infinite generator yielding batches of conditioning labels for G.

    Unconditional generators (c_dim == 0) get the same empty label tensor on
    every iteration; conditional ones get labels sampled uniformly at random
    from the dataset.
    """
    if opts.G.c_dim == 0:
        # Unconditional: one shared [batch_size, 0] tensor, yielded forever.
        empty = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
        while True:
            yield empty
    else:
        dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
        num_samples = len(dataset)
        while True:
            picks = [dataset.get_label(np.random.randint(num_samples)) for _ in range(batch_size)]
            yield torch.from_numpy(np.stack(picks)).pin_memory().to(opts.device)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class FeatureStats:
    """Accumulates feature vectors and/or their running mean and covariance.

    Used by the metric code to collect detector features for real and
    generated images. Can capture every row (`capture_all`), running
    mean/covariance sums (`capture_mean_cov`), or both, optionally capped
    at `max_items` rows.
    """

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0          # rows accepted so far
        self.num_features = None    # feature dimensionality; fixed on first append
        self.all_features = None    # list of np.float32 arrays when capture_all
        self.raw_mean = None        # running sum of features (float64)
        self.raw_cov = None         # running sum of outer products x^T x (float64)

    def set_num_features(self, num_features):
        """Fix the feature dimensionality on first use; later calls must match."""
        if self.num_features is not None:
            assert num_features == self.num_features
        else:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self):
        """True once `max_items` rows have been accepted (False if uncapped)."""
        return (self.max_items is not None) and (self.num_items >= self.max_items)

    def append(self, x):
        """Accept a 2-D [rows, features] array, truncating to honor `max_items`."""
        x = np.asarray(x, dtype=np.float32)
        assert x.ndim == 2
        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
            if self.num_items >= self.max_items:
                return
            # Keep only as many rows as still fit under the cap.
            x = x[:self.max_items - self.num_items]

        self.set_num_features(x.shape[1])
        self.num_items += x.shape[0]
        if self.capture_all:
            self.all_features.append(x)
        if self.capture_mean_cov:
            # Accumulate in float64 for numerical stability of the covariance.
            x64 = x.astype(np.float64)
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64

    def append_torch(self, x, num_gpus=1, rank=0):
        """Append a torch tensor, gathering rows from all GPUs first.

        With num_gpus > 1 every rank broadcasts its batch and the results are
        interleaved sample-by-sample, so truncation at `max_items` treats all
        ranks fairly.
        """
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            ys = []
            for src in range(num_gpus):
                y = x.clone()
                torch.distributed.broadcast(y, src=src)
                ys.append(y)
            x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        """Return all captured rows as one [num_items, num_features] array."""
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        """Return all captured rows as a torch tensor (CPU)."""
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        """Return (mean, covariance) of the accumulated features."""
        assert self.capture_mean_cov
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items
        cov = cov - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        """Pickle the full internal state to `pkl_file`."""
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        """Restore a FeatureStats previously written by `save`.

        NOTE: `capture_mean_cov` is not passed to the constructor, but the
        subsequent `__dict__.update` restores it along with everything else.
        """
        with open(pkl_file, 'rb') as f:
            s = dnnlib.EasyDict(pickle.load(f))
        obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
        obj.__dict__.update(s)
        return obj
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class ProgressMonitor:
    """Throttled progress reporter used by the metric computations.

    Prints a line at most every `flush_interval` items (when verbose) and
    forwards normalized progress to an optional callback `progress_fn`,
    mapping local progress into the [pfn_lo, pfn_hi] slice of pfn_total.
    """

    def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
        self.tag = tag
        self.num_items = num_items          # expected total, or None if unknown
        self.verbose = verbose
        self.flush_interval = flush_interval
        self.progress_fn = progress_fn      # callback(cur, total), e.g. for a GUI
        self.pfn_lo = pfn_lo                # this monitor's slice of the parent's range
        self.pfn_hi = pfn_hi
        self.pfn_total = pfn_total
        self.start_time = time.time()
        self.batch_time = self.start_time   # timestamp of the last flush
        self.batch_items = 0                # item count at the last flush
        if self.progress_fn is not None:
            self.progress_fn(self.pfn_lo, self.pfn_total)

    def update(self, cur_items):
        """Record that `cur_items` items are done; print/report when due."""
        assert (self.num_items is None) or (cur_items <= self.num_items)
        # Throttle: skip unless a full flush_interval has passed or we are at the end.
        if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
            return
        cur_time = time.time()
        total_time = cur_time - self.start_time
        # Per-item time for this flush window; max(...) guards division by zero.
        time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
        if (self.verbose) and (self.tag is not None):
            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
        self.batch_time = cur_time
        self.batch_items = cur_items

        if (self.progress_fn is not None) and (self.num_items is not None):
            # Map local fraction into this monitor's [pfn_lo, pfn_hi] slice.
            self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)

    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
        """Create a child monitor covering the [rel_lo, rel_hi] fraction of this one's range."""
        return ProgressMonitor(
            tag             = tag,
            num_items       = num_items,
            flush_interval  = flush_interval,
            verbose         = self.verbose,
            progress_fn     = self.progress_fn,
            pfn_lo          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
            pfn_hi          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
            pfn_total       = self.pfn_total,
        )
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, swav=False, sfid=False, **stats_kwargs):
    """Run the feature detector over the real dataset and return FeatureStats.

    Results are cached on disk under ./dnnlib/gan-metrics keyed by an MD5 of
    the dataset/detector/stats configuration, so repeated metric runs skip the
    expensive pass over the data.
    NOTE(review): `swav` and `sfid` are accepted but unused in this body —
    presumably consumed by callers/other variants; confirm before removing.
    """
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)

    # Try to lookup from cache.
    cache_file = None
    if opts.cache:
        det_name = get_feature_detector_name(detector_url)

        # Choose cache file name: hash of everything that affects the stats.
        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
        cache_tag = f'{dataset.name}-{det_name}-{md5.hexdigest()}'
        cache_file = os.path.join('.', 'dnnlib', 'gan-metrics', cache_tag + '.pkl')
        # cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')

        # Check if the file exists (all processes must agree).
        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
        if opts.num_gpus > 1:
            # Broadcast rank 0's verdict so every rank takes the same branch.
            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
            torch.distributed.broadcast(tensor=flag, src=0)
            flag = (float(flag.cpu()) != 0)

        # Load.
        if flag:
            return FeatureStats.load(cache_file)

        print('Calculating the stats for this dataset the first time\n')
        print(f'Saving them to {cache_file}')

    # Initialize.
    num_items = len(dataset)
    if max_items is not None:
        num_items = min(num_items, max_items)
    stats = FeatureStats(max_items=num_items, **stats_kwargs)
    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)

    # get detector
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop. Each rank processes an interleaved strided subset of items.
    item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
    for images, _labels in tqdm(torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs)):
        if images.shape[1] == 1:
            # Grayscale -> fake RGB so the detector always sees 3 channels.
            images = images.repeat([1, 3, 1, 1])

        with torch.no_grad():
            features = detector(images.to(opts.device), **detector_kwargs)

        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)

    # Save to cache (rank 0 only; write to a temp file then atomically rename).
    if cache_file is not None and opts.rank == 0:
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        temp_file = cache_file + '.' + uuid.uuid4().hex
        stats.save(temp_file)
        os.replace(temp_file, cache_file) # atomic
    return stats
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, swav=False, sfid=False, **stats_kwargs):
    """Sample images from G and accumulate detector features into FeatureStats.

    Images are generated in micro-batches of `batch_gen` (which must divide
    `batch_size`), quantized to uint8 to match real-image preprocessing, and
    fed to the detector until `stats.max_items` features have been collected.
    NOTE(review): `swav` and `sfid` are accepted but unused here; confirm
    against callers before removing.
    """
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and labels (deep copy so the training G is untouched).
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)

    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)

    # get detector
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop.
    while not stats.is_full():
        images = []
        for _i in range(batch_size // batch_gen):
            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
            # img = G(z=z, c=next(c_iter), truncation_psi=0.1, **opts.G_kwargs)
            img = G(z=z, c=next(c_iter), **opts.G_kwargs)
            # Map [-1, 1] output to uint8 [0, 255], mirroring real-image loading.
            img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            images.append(img)
        images = torch.cat(images)
        if images.shape[1] == 1:
            # Grayscale -> fake RGB for the detector.
            images = images.repeat([1, 3, 1, 1])

        with torch.no_grad():
            features = detector(images.to(opts.device), **detector_kwargs)

        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
|
||||
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
|
||||
Architecture for Generative Adversarial Networks". Matches the original
|
||||
implementation by Karras et al. at
|
||||
https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import torch
|
||||
from . import metric_utils
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
# Spherical interpolation of a batch of vectors.
|
||||
# Spherical interpolation of a batch of vectors.
def slerp(a, b, t):
    """Interpolate between unit-normalized `a` and `b` along the great circle.

    `t` broadcasts against the batch; the result is re-normalized to unit
    length to guard against numerical drift.
    """
    a_hat = a / a.norm(dim=-1, keepdim=True)
    b_hat = b / b.norm(dim=-1, keepdim=True)
    cos_angle = (a_hat * b_hat).sum(dim=-1, keepdim=True)
    theta = t * torch.acos(cos_angle)
    # Component of b orthogonal to a, normalized, spans the interpolation plane.
    ortho = b_hat - cos_angle * a_hat
    ortho = ortho / ortho.norm(dim=-1, keepdim=True)
    out = a_hat * torch.cos(theta) + ortho * torch.sin(theta)
    return out / out.norm(dim=-1, keepdim=True)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class PPLSampler(torch.nn.Module):
    """Samples pairs of nearby latents and returns their differential LPIPS distance.

    Used by compute_ppl: for each batch it draws interpolation endpoints,
    perturbs the interpolation parameter by `epsilon`, generates both images,
    and measures the squared LPIPS distance scaled by 1/epsilon^2.
    """

    def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
        assert space in ['z', 'w']          # latent space to interpolate in
        assert sampling in ['full', 'end']  # 'full': t ~ U[0,1); 'end': t = 0
        super().__init__()
        self.G = copy.deepcopy(G)
        self.G_kwargs = G_kwargs
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.crop = crop
        self.vgg16 = copy.deepcopy(vgg16)

    def forward(self, c):
        # Generate random latents and interpolation t-values.
        # 'end' sampling forces t = 0 (distances measured at segment endpoints).
        t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
        z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)

        # Interpolate in W or Z.
        if self.space == 'w':
            # Linear interpolation in W at t and t + epsilon.
            w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
            wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
            wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
        else: # space == 'z'
            # Spherical interpolation in Z, then map both points to W.
            zt0 = slerp(z0, z1, t.unsqueeze(1))
            zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
            wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)

        # Randomize noise buffers so noise does not dominate the distance.
        for name, buf in self.G.named_buffers():
            if name.endswith('.noise_const'):
                buf.copy_(torch.randn_like(buf))

        # Generate images (both perturbed and unperturbed in one batch).
        img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)

        # Center crop.
        # NOTE: `c` is reused here as a crop unit, shadowing the label argument
        # (which is no longer needed at this point).
        if self.crop:
            assert img.shape[2] == img.shape[3]
            c = img.shape[2] // 8
            img = img[:, :, c*3 : c*7, c*2 : c*6]

        # Downsample to 256x256 by box-filter averaging.
        factor = self.G.img_resolution // 256
        if factor > 1:
            img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])

        # Scale dynamic range from [-1,1] to [0,255].
        img = (img + 1) * (255 / 2)
        if self.G.img_channels == 1:
            img = img.repeat([1, 3, 1, 1])

        # Evaluate differential LPIPS: squared distance normalized by epsilon^2.
        lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
        dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
        return dist
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
    """Compute Perceptual Path Length over `num_samples` latent pairs.

    Returns the mean differential LPIPS distance after discarding the lowest
    1% and highest 99% outliers; non-leader ranks return NaN.
    """
    vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)

    # Setup sampler and labels.
    sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
    sampler.eval().requires_grad_(False).to(opts.device)
    c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)

    # Sampling loop: every rank samples; results are broadcast from each rank
    # in turn so all ranks accumulate the same `dist` list.
    dist = []
    progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
        progress.update(batch_start)
        x = sampler(next(c_iter))
        for src in range(opts.num_gpus):
            y = x.clone()
            if opts.num_gpus > 1:
                torch.distributed.broadcast(y, src=src)
            dist.append(y)
    progress.update(num_samples)

    # Compute PPL (leader rank only).
    if opts.rank != 0:
        return float('nan')
    dist = torch.cat(dist)[:num_samples].cpu().numpy()
    # FIX: np.percentile's 'interpolation' keyword was deprecated in NumPy 1.22
    # and removed in NumPy 2.0; 'method' is the replacement with identical
    # 'lower'/'higher' semantics.
    lo = np.percentile(dist, 1, method='lower')
    hi = np.percentile(dist, 99, method='higher')
    # Mean over the inlier band [lo, hi].
    ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
    return float(ppl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Precision/Recall (PR) from the paper "Improved Precision and Recall
|
||||
Metric for Assessing Generative Models". Matches the original implementation
|
||||
by Kynkaanniemi et al. at
|
||||
https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
|
||||
|
||||
import torch
|
||||
from . import metric_utils
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
    """Pairwise Euclidean distances between row and column feature sets.

    Columns are split into batches striped across GPUs; every rank computes
    its stripe and broadcasts it so rank 0 can assemble the full matrix.
    Returns a [rows, cols] CPU tensor on rank 0, None on other ranks.
    """
    assert 0 <= rank < num_gpus
    num_cols = col_features.shape[0]
    # Round the batch count up to a multiple of num_gpus so stripes align.
    num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
    padded = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches])
    col_batches = padded.chunk(num_batches)
    pieces = []
    for chunk in col_batches[rank :: num_gpus]:
        local = torch.cdist(row_features.unsqueeze(0), chunk.unsqueeze(0))[0]
        for src in range(num_gpus):
            shared = local.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(shared, src=src)
            pieces.append(shared.cpu() if rank == 0 else None)
    if rank != 0:
        return None
    # Trim the padding columns off the assembled matrix.
    return torch.cat(pieces, dim=1)[:, :num_cols]
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
    """Compute improved Precision and Recall (Kynkaanniemi et al.).

    Precision: fraction of generated samples inside the real-feature manifold;
    recall: fraction of real samples inside the generated manifold. Manifold
    membership uses each point's distance to its (nhood_size)-th nearest
    neighbor within its own set. Returns (precision, recall) as floats; on
    non-leader ranks both are NaN.
    """
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    detector_kwargs = dict(return_features=True)

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)

    results = dict()
    # Precision measures gen probes against the real manifold; recall swaps roles.
    for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
        # k-th nearest-neighbor radius within the manifold itself
        # (nhood_size + 1 because the nearest neighbor of a point is itself).
        kth = []
        for manifold_batch in manifold.split(row_batch_size):
            dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
        kth = torch.cat(kth) if opts.rank == 0 else None
        # A probe is "inside" if it lies within any manifold point's radius.
        pred = []
        for probes_batch in probes.split(row_batch_size):
            dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
        results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
    return results['precision'], results['recall']
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
+6
-5
@@ -5,7 +5,7 @@
|
||||
# Created Date: Thursday February 10th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 3rd March 2022 6:44:57 pm
|
||||
# Last Modified: Sunday, 3rd April 2022 6:39:26 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -23,7 +23,7 @@ if __name__ == '__main__':
|
||||
# script = "Generator_modulation_up"
|
||||
# script = "Generator_modulation_up"
|
||||
# script = "Generator_Invobn_config3"
|
||||
script = "Generator_ori_config"
|
||||
script = "Generator_maskhead_config"
|
||||
# script = "Generator_ori_config"
|
||||
class_name = "Generator"
|
||||
arcface_ckpt= "arcface_ckpt/arcface_checkpoint.tar"
|
||||
@@ -31,11 +31,12 @@ if __name__ == '__main__':
|
||||
"id_dim": 512,
|
||||
"g_kernel_size": 3,
|
||||
"in_channel":64,
|
||||
"res_num": 9,
|
||||
"res_num": 0,
|
||||
# "up_mode": "nearest",
|
||||
"up_mode": "bilinear",
|
||||
"aggregator": "eca_invo",
|
||||
"res_mode": "eca_invo"
|
||||
"res_mode": "eca_invo",
|
||||
"norm": "bn"
|
||||
}
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
|
||||
print("GPU used : ", os.environ['CUDA_VISIBLE_DEVICES'])
|
||||
@@ -57,7 +58,7 @@ if __name__ == '__main__':
|
||||
id_latent = torch.rand((4,512)).cuda()
|
||||
# cv2.imwrite(os.path.join("./swap_results", "id_%s.png"%(id_basename)),id_img_align_crop[0]
|
||||
|
||||
attr = torch.rand((4,3,224,224)).cuda()
|
||||
attr = torch.rand((4,3,512,512)).cuda()
|
||||
|
||||
import datetime
|
||||
start_time = time.time()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user