This commit is contained in:
chenxuanhong
2022-01-10 15:03:58 +08:00
parent 573689a591
commit 3783ef0e75
57 changed files with 9520 additions and 0 deletions
+6
@@ -112,3 +112,9 @@ dmypy.json
# Pyre type checker
.pyre/
/train_logs
/test_logs
/GUI
/benchmark
/reference
+1
@@ -0,0 +1 @@
python GUI.py
+924
@@ -0,0 +1,924 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: GUI copy 2.py
# Created Date: Wednesday December 22nd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 10th January 2022 1:47:55 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import sys
import time
import json
import tkinter
try:
    import paramiko
except ImportError:
    # Install paramiko with the current interpreter's pip, then retry the import.
    import subprocess as _subprocess
    _subprocess.check_call([sys.executable, "-m", "pip", "install", "paramiko"])
    import paramiko
import threading
import tkinter as tk
import tkinter.ttk as ttk
import subprocess
from pathlib import Path
#############################################################
# Predefined functions
#############################################################
def read_config(path):
    with open(path, 'r') as cf:
        config_str = cf.read()
    config_info = json.loads(config_str)
    # Some configs are double-encoded JSON; decode a second time if needed.
    if isinstance(config_info, str):
        config_info = json.loads(config_info)
    return config_info

def write_config(path, info):
    with open(path, 'w') as cf:
        cf.write(json.dumps(info, indent=4))
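# Example (a hypothetical sketch, not part of this commit): round-tripping a
# machine list through these helpers.
#
#     write_config("GUI/machines.json", [{"ip": "0.0.0.0", "user": "username"}])
#     machines = read_config("GUI/machines.json")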
class fileUploaderClass(object):
def __init__(self,serverIp,userName,passWd,port=22):
self.__ip__ = serverIp
self.__userName__ = userName
self.__passWd__ = passWd
self.__port__ = port
self.__ssh__ = paramiko.SSHClient()
self.__ssh__.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    def sshScpPut(self, localFile, remoteFile):
        self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
        remoteDir = remoteFile.split("/")
        if remoteFile[0] == '/':
            sftp.chdir('/')
        for item in remoteDir[0:-1]:
            if item == "":
                continue
            try:
                sftp.chdir(item)
            except IOError:
                # Create any missing directory level, then enter it.
                sftp.mkdir(item)
                sftp.chdir(item)
        sftp.put(localFile, remoteDir[-1])
        sftp.close()
        self.__ssh__.close()
        print("[To %s] local file %s -> remote file %s: success" % (self.__ip__, localFile, remoteFile))
    def sshScpPuts(self, localFiles, remoteFiles):
        self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
        for i_dir in range(len(remoteFiles)):
            remoteDir = remoteFiles[i_dir].split("/")
            if remoteFiles[i_dir][0] == '/':
                sftp.chdir('/')
            for item in remoteDir[0:-1]:
                if item == "":
                    continue
                try:
                    sftp.chdir(item)
                except IOError:
                    sftp.mkdir(item)
                    sftp.chdir(item)
            sftp.put(localFiles[i_dir], remoteDir[-1])
            print("[To %s] local file %s -> remote file %s: success" % (self.__ip__, localFiles[i_dir], remoteFiles[i_dir]))
        sftp.close()
        self.__ssh__.close()
def sshExec(self, cmd):
try:
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
_, stdout, _ = self.__ssh__.exec_command(cmd)
results = stdout.read().strip().decode('utf-8')
self.__ssh__.close()
return results
except Exception as e:
print(e)
finally:
self.__ssh__.close()
    def sshScpGetNames(self, remoteDir):
        self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
        names = sftp.listdir(remoteDir)
        roots = {}
        for item in names:
            attrs = sftp.stat(remoteDir + "/" + item)
            roots[item] = {
                "t": attrs.st_mtime,
                "p": remoteDir + "/" + item
            }
        sftp.close()
        self.__ssh__.close()
        return roots
    def sshScpGetRNames(self, remoteDir):
        # NOTE: despite the name, this currently mirrors sshScpGetNames and
        # does not recurse into subdirectories.
        self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
        names = sftp.listdir(remoteDir)
        roots = {}
        for item in names:
            attrs = sftp.stat(remoteDir + "/" + item)
            roots[item] = {
                "t": attrs.st_mtime,
                "p": remoteDir + "/" + item
            }
        sftp.close()
        self.__ssh__.close()
        return roots
def sshScpGet(self, remoteFile, localFile, showProgress=False):
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
if showProgress:
sftp.get(remoteFile, localFile,callback=self.__putCallBack__)
else:
sftp.get(remoteFile, localFile)
sftp.close()
self.__ssh__.close()
def sshScpGetFiles(self, remoteFiles, localFiles, showProgress=False):
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
for i in range(len(remoteFiles)):
if showProgress:
sftp.get(remoteFiles[i], localFiles[i],callback=self.__putCallBack__)
else:
sftp.get(remoteFiles[i], localFiles[i])
print("Get %s success!"%(remoteFiles[i]))
sftp.close()
self.__ssh__.close()
def sshScpGetDir(self, remoteDir, localDir, showProgress=False):
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
        files = sftp.listdir(remoteDir)
        for i_f in files:
            i_remote_file = Path(remoteDir, i_f).as_posix()
            local_file = str(Path(localDir, i_f))
            if showProgress:
                sftp.get(i_remote_file, local_file, callback=self.__putCallBack__)
            else:
                sftp.get(i_remote_file, local_file)
sftp.close()
self.__ssh__.close()
    def __putCallBack__(self, transferred, total):
        # Shared progress callback for put/get transfers.
        print("transferred %.1f%%" % (transferred / total * 100))
def sshScpRename(self, oldpath, newpath):
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
sftp.rename(oldpath,newpath)
sftp.close()
self.__ssh__.close()
print("ssh oldpath:%s newpath:%s success"%(oldpath,newpath))
def sshScpDelete(self,path):
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
sftp.remove(path)
sftp.close()
self.__ssh__.close()
print("ssh delete:%s success"%(path))
def sshScpDeleteDir(self,path):
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
self.__rm__(sftp,path)
sftp.close()
self.__ssh__.close()
    def __rm__(self, sftp, path):
        try:
            files = sftp.listdir(path=path)
        except IOError:
            # Not a directory: delete it as a regular file.
            sftp.remove(path)
            print("ssh delete:%s success" % path)
            return
        for f in files:
            filepath = os.path.join(path, f).replace('\\', '/')
            self.__rm__(sftp, filepath)
        sftp.rmdir(path)
        print("ssh delete:%s success" % path)
class TextRedirector(object):
def __init__(self, widget, tag="stdout"):
self.widget = widget
self.tag = tag
    def write(self, text):
        self.widget.configure(state="normal")
        self.widget.insert("end", text, (self.tag,))
        self.widget.configure(state="disabled")
        self.widget.see(tk.END)
def flush(self):
pass
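# Example (a hypothetical sketch, not part of this commit): routing print()
# output into a read-only Text widget.
#
#     root = tk.Tk()
#     text = tk.Text(root, wrap="word")
#     text.pack()
#     sys.stdout = TextRedirector(text, "stdout")
#     print("hello")  # appears in the widget instead of the console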
#############################################################
# Main Class
#############################################################
class Application(tk.Frame):
tab_info = []
tab_body = None
current_index = 0
gui_root = "GUI/"
machine_json = gui_root + "machines.json"
filesynlogroot = "file_sync/"
filesynlogroot = gui_root + filesynlogroot
ignore_json = gui_root + "guiignore.json"
machine_list = []
machine_dict = {}
ignore_text={
"white_list":{
"extension":["py",
"yaml"
],
"file":[],
"path":[]
},
"black_list":{
"extension":[
"png",
"yaml"
],
"file":[],
"path":["save/", "GUI/",]
}
}
env_text={
"train_log_root":"./train_logs",
"test_log_root":"./test_logs",
"systemLog":"./system/system_log.log",
"dataset_paths":{
"train_dataset_root":"",
"val_dataset_root":"",
"test_dataset_root":""
},
"train_config_path":"./train_yamls",
"train_scripts_path":"./train_scripts",
"test_scripts_path":"./test_scripts",
"config_json_name":"model_config.json"
}
machine_text = {
"ip": "0.0.0.0",
"user": "username",
"port": 22,
"passwd": "12345678",
"path": "/path/to/remote_host",
"ckp_path":"save",
"logfilename": "filestate_machine0.json"
}
current_log = {}
def __init__(self, master=None):
tk.Frame.__init__(self, master,bg='black')
# self.font_size = 16
self.font_list = ("Times New Roman",14)
self.padx = 5
self.pady = 5
if not Path(self.gui_root).exists():
Path(self.gui_root).mkdir(parents=True)
self.window_init()
def __label_text__(self, usr, root):
return "User Name: %s\nWorkspace: %s"%(usr, root)
def window_init(self):
cwd = os.getcwd()
self.master.title('File Synchronize - %s'%cwd)
# self.master.iconbitmap('./utilities/_logo.ico')
self.master.geometry("{}x{}".format(640, 800))
font_list = self.font_list
        try:
            self.machine_list = read_config(self.machine_json)
        except Exception:
            # Create a default machine configuration if none exists yet.
            self.machine_list = [self.machine_text, ]
            write_config(self.machine_json, self.machine_list)
# subprocess.call("start %s"%self.machine_json, shell=True)
#################################################################################################
list_frame = tk.Frame(self.master)
list_frame.pack(fill="both", padx=5,pady=5)
list_frame.columnconfigure(0, weight=1)
list_frame.columnconfigure(1, weight=1)
list_frame.columnconfigure(2, weight=1)
self.mac_var = tkinter.StringVar()
self.list_com = ttk.Combobox(list_frame, textvariable=self.mac_var)
self.list_com.grid(row=0,column=0,sticky=tk.EW)
open_button = tk.Button(list_frame, text = "Update",
font=font_list, command = self.Machines_Update, bg='#F4A460', fg='#F5F5F5')
open_button.grid(row=0,column=1,sticky=tk.EW)
open_button = tk.Button(list_frame, text = "Machines",
font=font_list, command = self.MachineConfig, bg='#F4A460', fg='#F5F5F5')
open_button.grid(row=0,column=2,sticky=tk.EW)
#################################################################################################
self.mac_text = tk.StringVar()
mac_label = tk.Label(self.master, textvariable=self.mac_text,font=self.font_list,justify="left")
mac_label.pack(fill="both", padx=5,pady=5)
self.mac_text.set(self.list_com.get())
self.machines_update()
def xFunc(event):
ip = self.list_com.get()
cur_mac = self.machine_dict[ip]
str_temp= self.__label_text__(cur_mac["user"],cur_mac["path"])
self.mac_text.set(str_temp)
self.update_log_task()
self.update_ckpt_task()
self.list_com.bind("<<ComboboxSelected>>",xFunc)
#################################################################################################
run_frame = tk.Frame(self.master)
run_frame.pack(fill="both", padx=5,pady=5)
run_frame.columnconfigure(0, weight=1)
run_frame.columnconfigure(1, weight=1)
run_test_button = tk.Button(run_frame, text = "Synch Files",
font=font_list, command = self.Synchronize, bg='#006400', fg='#FF0000')
run_test_button.grid(row=0,column=0,sticky=tk.EW)
open_button = tk.Button(run_frame, text = "Synch To All",
font=font_list, command = self.SynchronizeAll, bg='#F4A460', fg='#F5F5F5')
open_button.grid(row=0,column=1,sticky=tk.EW)
#################################################################################################
ssh_frame = tk.Frame(self.master)
ssh_frame.pack(fill="both", padx=5,pady=5)
ssh_frame.columnconfigure(0, weight=1)
ssh_frame.columnconfigure(1, weight=1)
ssh_button = tk.Button(ssh_frame, text = "Open SSH",
font=font_list, command = self.OpenSSH, bg='#990033', fg='#F5F5F5')
ssh_button.grid(row=0,column=0,sticky=tk.EW)
ssh_button = tk.Button(ssh_frame, text = "Pull Log",
font=font_list, command = self.PullLog, bg='#990033', fg='#F5F5F5')
ssh_button.grid(row=0,column=1,sticky=tk.EW)
#################################################################################################
config_frame = tk.Frame(self.master)
config_frame.pack(fill="both", padx=5,pady=5)
config_frame.columnconfigure(0, weight=1)
config_frame.columnconfigure(1, weight=1)
config_frame.columnconfigure(2, weight=1)
cmd_btn = tk.Button(config_frame, text = "Open CMD",
font=font_list, command = self.OpenCMD, bg='#0033FF', fg='#F5F5F5')
cmd_btn.grid(row=0,column=0,sticky=tk.EW)
cwd_btn = tk.Button(config_frame, text = "CWD",
font=font_list, command = self.CWD, bg='#0033FF', fg='#F5F5F5')
cwd_btn.grid(row=0,column=1,sticky=tk.EW)
gpu_btn = tk.Button(config_frame, text = "GPU Usage",
font=font_list, command = self.GPUUsage, bg='#0033FF', fg='#F5F5F5')
gpu_btn.grid(row=0,column=2,sticky=tk.EW)
################################################################################################
config_frame = tk.Frame(self.master)
config_frame.pack(fill="both", padx=5,pady=5)
config_frame.columnconfigure(0, weight=1)
config_frame.columnconfigure(1, weight=1)
# config_frame.columnconfigure(2, weight=1)
machine_btn = tk.Button(config_frame, text = "Ignore Conf",
font=font_list, command = self.IgnoreConfig, bg='#660099', fg='#F5F5F5')
machine_btn.grid(row=0,column=0,sticky=tk.EW)
machine_btn2 = tk.Button(config_frame, text = "Env Conf",
font=font_list, command = self.EnvConfig, bg='#660099', fg='#F5F5F5')
machine_btn2.grid(row=0,column=1,sticky=tk.EW)
#################################################################################################
log_frame = tk.Frame(self.master)
log_frame.pack(fill="both", padx=5,pady=5)
log_frame.columnconfigure(0, weight=1)
log_frame.columnconfigure(1, weight=1)
log_frame.columnconfigure(2, weight=1)
self.log_var = tkinter.StringVar()
self.log_com = ttk.Combobox(log_frame, textvariable=self.log_var)
self.log_com.grid(row=0,column=0,sticky=tk.EW)
def select_log(event):
self.update_ckpt_task()
self.log_com.bind("<<ComboboxSelected>>",select_log)
log_update_button = tk.Button(log_frame, text = "Fresh",
font=font_list, command = self.UpdateLog, bg='#F4A460', fg='#F5F5F5')
log_update_button.grid(row=0,column=1,sticky=tk.EW)
log_update_button = tk.Button(log_frame, text = "Pull Log",
font=font_list, command = self.PullLog, bg='#F4A460', fg='#F5F5F5')
log_update_button.grid(row=0,column=2,sticky=tk.EW)
#################################################################################################
test_frame = tk.Frame(self.master)
test_frame.pack(fill="both", padx=5,pady=5)
test_frame.columnconfigure(0, weight=1)
test_frame.columnconfigure(1, weight=1)
test_frame.columnconfigure(2, weight=1)
self.test_var = tkinter.StringVar()
self.test_com = ttk.Combobox(test_frame, textvariable=self.test_var)
self.test_com.grid(row=0,column=0,sticky=tk.EW)
test_update_button = tk.Button(test_frame, text = "Fresh CKPT",
font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5')
test_update_button.grid(row=0,column=1,sticky=tk.EW)
test_update_button = tk.Button(test_frame, text = "Test",
font=font_list, command = self.Test, bg='#F4A460', fg='#F5F5F5')
test_update_button.grid(row=0,column=2,sticky=tk.EW)
# #################################################################################################
# tbtext_frame = tk.Frame(self.master)
# tbtext_frame.pack(fill="both", padx=5,pady=5)
# tbtext_frame.columnconfigure(0, weight=5)
# self.tensorlog_str = tk.StringVar()
# tb_text = tk.Entry(tbtext_frame,font=font_list, textvariable=self.tensorlog_str )
# tb_text.grid(row=0,column=0,sticky=tk.EW)
# config_tb = tk.Button(tbtext_frame, text = "Config",
# font=font_list, command = self.OpenConfig, bg='#003472', fg='#F5F5F5')
# config_tb.grid(row=0,column=1,sticky=tk.EW)
# #################################################################################################
# #################################################################################################
# tb_frame = tk.Frame(self.master)
# tb_frame.pack(fill="both", padx=5,pady=5)
# tb_frame.columnconfigure(1, weight=1)
# tb_frame.columnconfigure(0, weight=1)
# open_tb = tk.Button(tb_frame, text = "Open Tensorboard",
# font=font_list, command = self.OpenTensorboard, bg='#003472', fg='#F5F5F5')
# open_tb.grid(row=0,column=0,sticky=tk.EW)
# download_tb = tk.Button(tb_frame, text = "Update Tensorboard Logs",
# font=font_list, command = self.DownloadTBLogs, bg='#003472', fg='#F5F5F5')
# download_tb.grid(row=0,column=1,sticky=tk.EW)
# #################################################################################################
text = tk.Text(self.master, wrap="word")
text.pack(fill="both",expand="yes", padx=5,pady=5)
sys.stdout = TextRedirector(text, "stdout")
self.master.protocol("WM_DELETE_WINDOW", self.on_closing)
# def __scaning_logs__(self):
def UpdateCKPT(self):
thread_update = threading.Thread(target=self.update_ckpt_task)
thread_update.start()
    def update_ckpt_task(self):
        ip = self.list_com.get()
        log = self.log_com.get()
        cur_mac = self.machine_dict[ip]
        files = Path('.', cur_mac["ckp_path"], log).glob('*.pth')
        all_files = [one_file.name for one_file in files]
        self.test_com["value"] = all_files
        if all_files:
            # current() raises on an empty value list.
            self.test_com.current(0)
def Test(self):
def test_task():
log = self.log_com.get()
ckpt = self.test_com.get()
cwd = os.getcwd()
files = str(Path(log, ckpt))
print(files)
subprocess.check_call("start cmd /k \"cd /d %s && conda activate base \
&& python test.py --model %s\""%(cwd, files), shell=True)
thread_update = threading.Thread(target=test_task)
thread_update.start()
def Machines_Update(self):
self.update_log_task()
thread_update = threading.Thread(target=self.machines_update)
thread_update.start()
def machines_update(self):
self.machine_list = read_config(self.machine_json)
ip_list = []
for item in self.machine_list:
self.machine_dict[item["ip"]] = item
ip_list.append(item["ip"])
self.list_com["value"] = ip_list
self.list_com.current(0)
ip = self.list_com.get()
cur_mac = self.machine_dict[ip]
str_temp= self.__label_text__(cur_mac["user"],cur_mac["path"])
self.mac_text.set(str_temp)
print("Machine list update success!")
def connection(self):
ip = self.list_com.get()
cur_mac = self.machine_dict[ip]
ssh_ip = cur_mac["ip"]
ssh_username = cur_mac["user"]
ssh_passwd = cur_mac["passwd"]
ssh_port = int(cur_mac["port"])
print(ssh_ip)
if ip.lower() == "local" or ip.lower() == "localhost":
print("localhost no need to connect!")
return [], cur_mac
remotemachine = fileUploaderClass(ssh_ip,ssh_username,ssh_passwd,ssh_port)
return remotemachine, cur_mac
def __decode_filestr__(self, filestr):
cells = filestr.split("\n")
print(cells)
    def update_log_task(self):
        remotemachine, mac = self.connection()
        remote_path = os.path.join(mac["path"], mac["ckp_path"]).replace("\\", "/")
        if remotemachine is None:
            first_level = {}
            for one_file in Path(remote_path).iterdir():
                if one_file.is_dir():
                    first_level[one_file.name] = {
                        "t": "",
                        "p": ""
                    }
        else:
            first_level = remotemachine.sshScpGetNames(remote_path)
        logs = sorted(first_level.keys())
        self.log_com["value"] = logs
        if logs:
            self.log_com.current(0)
        self.current_log = first_level
        self.update_ckpt_task()
def UpdateLog(self):
thread_update = threading.Thread(target=self.update_log_task)
thread_update.start()
def PullLog(self):
def pull_log_task():
remotemachine,mac = self.connection()
log = self.log_com.get()
remote_path = self.current_log[log]["p"]
            if remotemachine is None:
                return
all_level = remotemachine.sshScpGetRNames(remote_path)
file_need_download = []
local_position = []
local_dir = Path("./",mac["ckp_path"],log)
            if not local_dir.exists():
                local_dir.mkdir(parents=True)
for k,v in all_level.items():
local_file = Path("./",mac["ckp_path"],log,k)
if local_file.exists():
if int(local_file.stat().st_mtime) < v["t"]:
file_need_download.append(v["p"])
local_position.append(str(local_file))
# print(int(local_file.stat().st_mtime))
# print(v["t"])
else:
file_need_download.append(v["p"])
local_position.append(str(local_file))
if len(file_need_download) > 0 :
remotemachine.sshScpGetFiles(file_need_download, local_position)
else:
print("No file need to pull......")
self.update_ckpt_task()
thread_update = threading.Thread(target=pull_log_task)
thread_update.start()
def OpenCMD(self):
def open_cmd_task():
subprocess.call("start cmd", shell=True)
thread_update = threading.Thread(target=open_cmd_task)
thread_update.start()
def CWD(self):
def open_cmd_task():
cwd = os.getcwd()
subprocess.call("explorer "+cwd, shell=False)
thread_update = threading.Thread(target=open_cmd_task)
thread_update.start()
def OpenSSH(self):
def open_ssh_task():
ip = self.list_com.get()
if ip.lower() == "local" or ip.lower() == "localhost":
print("localhost no need to connect!")
cur_mac = self.machine_dict[ip]
ssh_ip = cur_mac["ip"]
ssh_username = cur_mac["user"]
ssh_passwd = cur_mac["passwd"]
ssh_port = cur_mac["port"]
# subprocess.call("start cmd", shell=True)
subprocess.call("start cmd /k ssh %s@%s -p %s"%(ssh_username, ssh_ip, ssh_port), shell=True)
# subprocess.call("start echo %s"%(ssh_passwd), shell=True)
# p = Popen("cp -rf a/* b/", shell=True, stdout=PIPE, stderr=PIPE)
# proc = subprocess.Popen("ssh %s@%s -p %s"%(ssh_username, ssh_ip, ssh_port),
# stdin=subprocess.PIPE, stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,creationflags =subprocess.CREATE_NEW_CONSOLE)
# # out, err = proc.communicate(ssh_passwd.encode("utf-8"))
# proc.stdin.write(ssh_passwd.encode('utf-8'))
# print(out.decode('utf-8'))
thread_update = threading.Thread(target=open_ssh_task)
thread_update.start()
def GPUUsage(self):
        def gpu_usage_task():
            remotemachine, _ = self.connection()
            if remotemachine is None:
                return
            results = remotemachine.sshExec("nvidia-smi")
            print(results)
thread_update = threading.Thread(target=gpu_usage_task)
thread_update.start()
def IgnoreConfig(self):
def ignore_config_task():
if not os.path.exists(self.ignore_json):
print("guiignore.json file does not exist...")
if not os.path.exists(self.gui_root):
os.makedirs(self.gui_root)
write_config(self.ignore_json,self.ignore_text)
subprocess.call("start %s"%self.ignore_json, shell=True)
thread_update = threading.Thread(target=ignore_config_task)
thread_update.start()
def EnvConfig(self):
def env_config_task():
root_dir = os.getcwd()
logs_dir = os.path.join(root_dir,"env","env.json")
if not os.path.exists(logs_dir):
print("env.json file does not exist...")
if not os.path.exists(os.path.join(root_dir,"env")):
os.makedirs(os.path.join(root_dir,"env"))
write_config(logs_dir,self.env_text)
subprocess.call("start env/env.json", shell=True)
thread_update = threading.Thread(target=env_config_task)
thread_update.start()
def MachineConfig(self):
def machine_config_task():
subprocess.call("start %s"%self.machine_json, shell=True)
thread_update = threading.Thread(target=machine_config_task)
thread_update.start()
def OpenConfig(self):
def open_config_task():
root_dir = os.getcwd()
logs_dir = os.path.join(root_dir,"env","logs_position.json")
if not os.path.exists(logs_dir):
print("logs configuration file does not exist...")
positions={
"template":{
"root_path":"./",
"machine_name":"localhost",
}
}
if not os.path.exists(os.path.join(root_dir,"env")):
os.makedirs(os.path.join(root_dir,"env"))
write_config(logs_dir,positions)
subprocess.call("start env/logs_position.json", shell=True)
# time.sleep(5)
# subprocess.call("start http://localhost:6006/", shell=True)
thread_update = threading.Thread(target=open_config_task)
thread_update.start()
def OpenTensorboard(self):
thread_update = threading.Thread(target=self.open_tensorboard_task)
thread_update.start()
def open_tensorboard_task(self):
self.download_tblogs()
root_dir = os.getcwd()
logs_dir = os.path.join(root_dir,"train_logs")
subprocess.call("start cmd /k tensorboard --logdir=\"%s\""%(logs_dir), shell=True)
time.sleep(5)
subprocess.call("start http://localhost:6006/", shell=True)
def DownloadTBLogs(self):
thread_update = threading.Thread(target=self.download_tblogs)
thread_update.start()
def download_tblogs(self):
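        # NOTE: self.tensorlog_str is created by the Tensorboard Entry widget,
        # which is currently commented out in window_init; this method raises
        # AttributeError until that widget is restored.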
tb_monitor_logs = self.tensorlog_str.get()
tb_monitor_logs = tb_monitor_logs.split(";")
root_dir = os.getcwd()
mach_dir = os.path.join(root_dir,"env","machine_config.json")
machines = read_config(mach_dir)
logs_dir = os.path.join(root_dir,"env","logs_position.json")
tb_logs = read_config(logs_dir)
for i_logs in tb_monitor_logs:
try:
mac_name = tb_logs[i_logs]["machine_name"]
i_mac = machines[mac_name]
i_mac["log_name"] = i_logs
# mac_list.append(i_mac)
remotemachine = fileUploaderClass(i_mac["ip"],i_mac["usrname"],i_mac["passwd"],i_mac["port"])
path_temp = Path(i_mac["root"],"train_logs",i_logs,"summary").as_posix()
local_dir = Path(root_dir,"train_logs",i_logs,"summary")
if not Path(local_dir).exists():
Path(local_dir).mkdir(parents=True)
remotemachine.sshScpGetDir(path_temp,local_dir)
print("%s log files download successful!"%i_logs)
except Exception as e:
print(e)
def Synchronize(self):
def update():
self.update_action()
thread_update = threading.Thread(target=update)
thread_update.start()
    def SynchronizeAll(self):
        def update_all():
            # Iterate over every configured machine; the original tab_info
            # structure is never populated.
            for ip in list(self.machine_dict.keys()):
                self.update_action(ip)
        thread_update = threading.Thread(target=update_all)
        thread_update.start()
    def update_action(self, ip=None):
        last_state = {}
        changed_files = []
        if ip is None:
            ip = self.list_com.get()
        cur_mac = self.machine_dict[ip]
        if ip.lower() in ("local", "localhost"):
            print("localhost: no need to update!")
            return
ssh_ip = cur_mac["ip"]
ssh_username = cur_mac["user"]
ssh_passwd = cur_mac["passwd"]
ssh_port = cur_mac["port"]
root_path = cur_mac["path"]
log_path = os.path.join(self.filesynlogroot,cur_mac["logfilename"])
if not Path(self.filesynlogroot).exists():
Path(self.filesynlogroot).mkdir(parents=True)
else:
if Path(log_path).exists():
with open(log_path,'r') as cf:
nodelocaltionstr = cf.read()
last_state = json.loads(nodelocaltionstr)
all_files = []
# scan files
file_filter = read_config("./GUI/guiignore.json")
white_list = file_filter["white_list"]
black_list = file_filter["black_list"]
white_ext = white_list["extension"]
black_path = black_list["path"]
black_file = black_list["file"]
for item in white_ext:
if item=="":
print("something error in the white list")
continue
files = Path('.').rglob('*.%s'%item) # ./*
for one_file in files:
all_files.append(one_file)
            for i_dir in black_path:
                files = Path('.', i_dir).rglob('*.%s' % item)
                for one_file in files:
                    if one_file in all_files:
                        all_files.remove(one_file)
        for item in black_file:
            try:
                all_files.remove(Path('.', item))
            except ValueError:
                print("%s does not exist!" % item)
        # check updated files
        for item in all_files:
            mtime = item.stat().st_mtime
            key = str(item)
            if key in last_state and last_state[key] == mtime:
                continue
            changed_files.append(key)
            last_state[key] = mtime
print("[To %s]"%ssh_ip,changed_files)
localfiles = []
remotefiles = []
for item in changed_files:
localfiles.append(item)
remotefiles.append(Path(root_path,item).as_posix())
try:
remotemachine = fileUploaderClass(ssh_ip,ssh_username,ssh_passwd,ssh_port)
remotemachine.sshScpPuts(localfiles,remotefiles)
            write_config(log_path, last_state)
except Exception as e:
print(e)
print("File Synchronize Failed!")
# def __save_config__(self):
# previous_info = read_config(self.log_path)
# for i in range(len(self.tab_info["names"])):
# databind = self.tab_info["databind"][i]
# data_aquire = {
# "name": self.tab_info["names"][i],
# "remote_ip": databind["remote_ip"].get(),
# "remote_user": databind["remote_user"].get(),
# "remote_port": databind["remote_port"].get(),
# "remote_passwd":databind["remote_passwd"].get(),
# "remote_path": databind["remote_path"].get(),
# "logfilename": "filestate_%s.json"%self.tab_info["names"][i]
# }
# if self.tab_info["names"][i] in previous_info["names"]:
# location = previous_info["names"].index(self.tab_info["names"][i])
# previous_info["configs"][location] = data_aquire
# else:
# previous_info["names"].append(self.tab_info["names"][i])
# previous_info["configs"].append(data_aquire)
# previous_info["databind"] = []
# write_config(self.log_path,previous_info)
def on_closing(self):
# self.__save_config__()
self.master.destroy()
if __name__ == "__main__":
app = Application()
app.mainloop()
+124
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Conditional_Discriminator copy.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 29th June 2021 4:26:33 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from torch.nn import utils
class Discriminator(nn.Module):
def __init__(self, chn=32, k_size=3, n_class=3):
super().__init__()
# padding_size = int((k_size -1)/2)
slop = 0.2
enable_bias = True
# stage 1
self.block1 = nn.Sequential(
utils.spectral_norm(nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size= k_size, stride = 2, padding=2,bias= enable_bias)),
nn.LeakyReLU(slop),
utils.spectral_norm(nn.Conv2d(in_channels = chn, out_channels = chn * 2 , kernel_size= k_size, stride = 2,padding=2, bias= enable_bias)), # 1/4
nn.LeakyReLU(slop)
)
self.aux_classfier1 = nn.Sequential(
utils.spectral_norm(nn.Conv2d(in_channels = chn * 2 , out_channels = chn , kernel_size= 5, bias=enable_bias)),
nn.LeakyReLU(slop),
nn.AdaptiveAvgPool2d(1),
)
self.embed1 = utils.spectral_norm(nn.Embedding(n_class, chn))
self.linear1= utils.spectral_norm(nn.Linear(chn, 1))
# stage 2
self.block2 = nn.Sequential(
utils.spectral_norm(nn.Conv2d(in_channels = chn * 2 , out_channels = chn * 4 , kernel_size= k_size, stride = 2, padding=2, bias= enable_bias)),# 1/8
nn.LeakyReLU(slop),
utils.spectral_norm(nn.Conv2d(in_channels = chn * 4, out_channels = chn * 8 , kernel_size= k_size, stride = 2,padding=2, bias= enable_bias)),# 1/16
nn.LeakyReLU(slop)
)
self.aux_classfier2 = nn.Sequential(
utils.spectral_norm(nn.Conv2d(in_channels = chn * 8 , out_channels = chn , kernel_size= 5, bias= enable_bias)),
nn.LeakyReLU(slop),
nn.AdaptiveAvgPool2d(1),
)
self.embed2 = utils.spectral_norm(nn.Embedding(n_class, chn))
self.linear2= utils.spectral_norm(nn.Linear(chn, 1))
# stage 3
self.block3 = nn.Sequential(
utils.spectral_norm(nn.Conv2d(in_channels = chn * 8 , out_channels = chn * 8 , kernel_size= k_size, stride = 2,padding=3, bias= enable_bias)),# 1/32
nn.LeakyReLU(slop),
utils.spectral_norm(nn.Conv2d(in_channels = chn * 8, out_channels = chn * 16 , kernel_size= k_size, stride = 2,padding=3, bias= enable_bias)),# 1/64
nn.LeakyReLU(slop)
)
self.aux_classfier3 = nn.Sequential(
utils.spectral_norm(nn.Conv2d(in_channels = chn * 16 , out_channels = chn, kernel_size= 5, bias= enable_bias)),
nn.LeakyReLU(slop),
nn.AdaptiveAvgPool2d(1),
)
self.embed3 = utils.spectral_norm(nn.Embedding(n_class, chn))
self.linear3= utils.spectral_norm(nn.Linear(chn, 1))
self.__weights_init__()
    def __weights_init__(self):
        print("Init weights")
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            if isinstance(m, nn.Embedding):
                nn.init.xavier_uniform_(m.weight)
def forward(self, input, condition):
h = self.block1(input)
prep1 = self.aux_classfier1(h)
prep1 = prep1.view(prep1.size()[0], -1)
y1 = self.embed1(condition)
y1 = torch.sum(y1 * prep1, dim=1, keepdim=True)
prep1 = self.linear1(prep1) + y1
h = self.block2(h)
prep2 = self.aux_classfier2(h)
prep2 = prep2.view(prep2.size()[0], -1)
y2 = self.embed2(condition)
y2 = torch.sum(y2 * prep2, dim=1, keepdim=True)
prep2 = self.linear2(prep2) + y2
h = self.block3(h)
prep3 = self.aux_classfier3(h)
prep3 = prep3.view(prep3.size()[0], -1)
y3 = self.embed3(condition)
y3 = torch.sum(y3 * prep3, dim=1, keepdim=True)
prep3 = self.linear3(prep3) + y3
out_prep = [prep1,prep2,prep3]
return out_prep
def get_outputs_len(self):
num = 0
for m in self.modules():
if isinstance(m,nn.Linear):
num+=1
return num
if __name__ == "__main__":
wocao = Discriminator().cuda()
from torchsummary import summary
summary(wocao, input_size=(3, 512, 512))
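# Example (a hypothetical sketch, not part of this commit): the three
# per-stage predictions can feed a multi-scale hinge loss; `real`, `fake`,
# and `labels` are placeholder batches.
#
#     def hinge_d_loss(d, real, fake, labels):
#         loss = 0.0
#         for p_real, p_fake in zip(d(real, labels), d(fake.detach(), labels)):
#             loss = loss + torch.relu(1.0 - p_real).mean() \
#                         + torch.relu(1.0 + p_fake).mean()
#         return loss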
+114
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Conditional_Generator_tanh.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 6th July 2021 1:16:46 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from components.ResBlock import ResBlock
from components.DeConv import DeConv
from components.Conditional_ResBlock_ModulaConv import Conditional_ResBlock
class Generator(nn.Module):
def __init__(
self,
chn=32,
k_size=3,
res_num = 5,
class_num = 3,
**kwargs):
super().__init__()
padding_size = int((k_size -1)/2)
self.resblock_list = []
self.n_class = class_num
self.encoder1 = nn.Sequential(
# nn.InstanceNorm2d(3, affine=True),
# nn.ReflectionPad2d(padding_size),
nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size= k_size, stride=1, padding=1, bias= False),
nn.InstanceNorm2d(chn, affine=True, momentum=0),
# nn.ReLU(),
nn.LeakyReLU(),
# nn.ReflectionPad2d(padding_size),
nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size= k_size, stride=2, padding=1,bias =False), #
nn.InstanceNorm2d(chn*2, affine=True, momentum=0),
# nn.ReLU(),
nn.LeakyReLU(),
# nn.ReflectionPad2d(padding_size),
nn.Conv2d(in_channels = chn*2, out_channels = chn * 4, kernel_size= k_size, stride=2, padding=1,bias =False),
nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
# nn.ReLU(),
nn.LeakyReLU(),
# nn.ReflectionPad2d(padding_size),
nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size= k_size, stride=2, padding=1,bias =False),
nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.ReLU(),
nn.LeakyReLU(),
# # nn.ReflectionPad2d(padding_size),
nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size= k_size, stride=2, padding=1,bias =False),
nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.ReLU(),
nn.LeakyReLU()
)
res_size = chn * 8
for _ in range(res_num-1):
self.resblock_list += [ResBlock(res_size,k_size),]
self.resblocks = nn.Sequential(*self.resblock_list)
self.conditional_res = Conditional_ResBlock(res_size, k_size, class_num)
self.decoder1 = nn.Sequential(
DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size= k_size),
nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.ReLU(),
nn.LeakyReLU(),
DeConv(in_channels = chn * 8, out_channels = chn *4, kernel_size= k_size),
nn.InstanceNorm2d(chn *4, affine=True, momentum=0),
# nn.ReLU(),
nn.LeakyReLU(),
DeConv(in_channels = chn * 4, out_channels = chn * 2 , kernel_size= k_size),
nn.InstanceNorm2d(chn*2, affine=True, momentum=0),
# nn.ReLU(),
nn.LeakyReLU(),
DeConv(in_channels = chn *2, out_channels = chn, kernel_size= k_size),
nn.InstanceNorm2d(chn, affine=True, momentum=0),
# nn.ReLU(),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn, out_channels =3, kernel_size= k_size, stride=1, padding=1,bias =True)
# nn.Tanh()
)
self.__weights_init__()
def __weights_init__(self):
for layer in self.encoder1:
if isinstance(layer,nn.Conv2d):
nn.init.xavier_uniform_(layer.weight)
# for layer in self.encoder2:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
def forward(self, input, condition=None, get_feature = False):
feature = self.encoder1(input)
if get_feature:
return feature
out = self.conditional_res(feature, condition)
out = self.resblocks(out)
# n, _,h,w = out.size()
# attr = condition.view((n, self.n_class, 1, 1)).expand((n, self.n_class, h, w))
# out = torch.cat([out, attr], dim=1)
out = self.decoder1(out)
return out,feature
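if __name__ == "__main__":
    # Smoke test (an illustrative addition, not part of the original commit);
    # the hyper-parameters are placeholders and the components.* imports
    # must resolve.
    model = Generator(chn=32, k_size=3, res_num=5, class_num=3)
    x = torch.randn(1, 3, 256, 256)
    cond = torch.randint(0, 3, (1,))
    out, feature = model(x, cond)
    print(out.shape, feature.shape)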
+82
@@ -0,0 +1,82 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Conditional_ResBlock_v2.py
# Created Date: Tuesday June 29th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 29th June 2021 3:59:44 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
###################################################################
### @FilePath: \ASMegaGAN\components\Conditional_ResBlock_v2.py
### @Author: Ziang Liu
### @Date: 2021-06-28 21:30:17
### @LastEditors: Ziang Liu
### @LastEditTime: 2021-06-28 21:46:24
### @Copyright (C) 2021 SJTU. All rights reserved.
###################################################################
import torch
from torch import nn
import torch.nn.functional as F
# from ops.Conditional_BN import Conditional_BN
# from components.Adain import Adain
class Conv2DMod(nn.Module):
def __init__(self, in_channels, out_channels, kernel, demod=True, stride=1, dilation=1, eps = 1e-8, **kwargs):
super().__init__()
self.filters = out_channels
self.demod = demod
self.kernel = kernel
self.stride = stride
self.dilation = dilation
self.weight = nn.Parameter(torch.randn((out_channels, in_channels, kernel, kernel)))
self.eps = eps
padding_size = int((kernel -1)/2)
self.same_padding = nn.ReplicationPad2d(padding_size)
nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
    def forward(self, x, y):
        # NOTE: self.stride and self.dilation are stored but never passed to
        # F.conv2d below, so the convolution is effectively stride-1.
        b, c, h, w = x.shape
        # Modulate the shared weight with a per-sample style vector y.
        w1 = y[:, None, :, None, None]
        w2 = self.weight[None, :, :, :, :]
        weights = w2 * (w1 + 1)
        if self.demod:
            # Demodulate: rescale each output filter to unit norm.
            d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weights = weights * d
        # Fold the batch into the channel axis and use grouped convolution so
        # every sample is convolved with its own modulated weights.
        x = x.reshape(1, -1, h, w)
        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.filters, *ws)
        x = self.same_padding(x)
        x = F.conv2d(x, weights, groups=b)
        x = x.reshape(-1, self.filters, h, w)
        return x
class Conditional_ResBlock(nn.Module):
def __init__(self, in_channel, k_size = 3, n_class = 2, stride=1):
super().__init__()
self.embed1 = nn.Embedding(n_class, in_channel)
self.embed2 = nn.Embedding(n_class, in_channel)
self.conv1 = Conv2DMod(in_channels = in_channel , out_channels = in_channel, kernel= k_size, stride=stride)
self.conv2 = Conv2DMod(in_channels = in_channel , out_channels = in_channel, kernel= k_size, stride=stride)
def forward(self, input, condition):
res = input
style1 = self.embed1(condition)
h = self.conv1(res, style1)
style2 = self.embed2(condition)
h = self.conv2(h, style2)
out = h + res
return out
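if __name__ == "__main__":
    # Smoke test (an illustrative addition, not part of the original commit):
    # the modulated residual block should preserve the input's shape.
    block = Conditional_ResBlock(in_channel=64, k_size=3, n_class=2)
    x = torch.randn(2, 64, 32, 32)
    cond = torch.randint(0, 2, (2,))
    print(block(x, cond).shape)  # torch.Size([2, 64, 32, 32])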
+20
@@ -0,0 +1,20 @@
import torch
from torch import nn
class DeConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size = 3, upsampl_scale = 2):
super().__init__()
self.upsampling = nn.UpsamplingNearest2d(scale_factor=upsampl_scale)
padding_size = int((kernel_size -1)/2)
# self.same_padding = nn.ReflectionPad2d(padding_size)
self.conv = nn.Conv2d(in_channels = in_channels ,padding=padding_size, out_channels = out_channels , kernel_size= kernel_size, bias= False)
self.__weights_init__()
def __weights_init__(self):
nn.init.xavier_uniform_(self.conv.weight)
def forward(self, input):
h = self.upsampling(input)
# h = self.same_padding(h)
h = self.conv(h)
return h
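if __name__ == "__main__":
    # Smoke test (an illustrative addition, not part of the original commit).
    # Nearest-neighbor upsampling followed by a stride-1 convolution is a
    # common alternative to ConvTranspose2d that avoids checkerboard artifacts.
    up = DeConv(in_channels=64, out_channels=32, kernel_size=3)
    x = torch.randn(1, 64, 16, 16)
    print(up(x).shape)  # torch.Size([1, 32, 32, 32])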
+156
@@ -0,0 +1,156 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Conditional_Generator_gpt_LN_encoder copy.py
# Created Date: Saturday October 9th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 26th October 2021 3:25:47 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from components.DeConv import DeConv
from components.network_swin import SwinTransformerBlock, PatchEmbed, PatchUnEmbed
class ImageLN(nn.Module):
def __init__(self, dim) -> None:
super().__init__()
self.layer = nn.LayerNorm(dim)
def forward(self, x):
y = self.layer(x.permute(0,2,3,1)).permute(0,3,1,2)
return y
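# ImageLN applies LayerNorm over the channel axis of an NCHW tensor by
# permuting to NHWC for nn.LayerNorm and back; e.g.
#     ImageLN(32)(torch.randn(1, 32, 8, 8))  # shape preserved: (1, 32, 8, 8)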
class Generator(nn.Module):
def __init__(
self,
**kwargs
):
super().__init__()
chn = kwargs["g_conv_dim"]
k_size = kwargs["g_kernel_size"]
res_num = kwargs["res_num"]
class_num = kwargs["n_class"]
window_size = kwargs["window_size"]
image_size = kwargs["image_size"]
padding_size = int((k_size -1)/2)
self.resblock_list = []
embed_dim = 96
        window_size = 8  # NOTE: overrides the value read from kwargs above
num_heads = 8
mlp_ratio = 2.
norm_layer = nn.LayerNorm
qk_scale = None
qkv_bias = True
self.patch_norm = True
self.lnnorm = norm_layer(embed_dim)
self.encoder = nn.Sequential(
nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False),
ImageLN(chn),
nn.ReLU(),
nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), #
ImageLN(chn * 2),
nn.ReLU(),
nn.Conv2d(in_channels = chn*2, out_channels = embed_dim, kernel_size=k_size, stride=2, padding=1,bias =False),
ImageLN(embed_dim),
nn.ReLU(),
)
# self.encoder2 = nn.Sequential(
# nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU()
# )
self.fea_size = (image_size//4, image_size//4)
# self.conditional_GPT = GPT_Spatial(2, res_dim, res_num, class_num)
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=embed_dim, input_resolution=self.fea_size,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=0.0, attn_drop=0.0,
drop_path=0.1,
norm_layer=norm_layer)
for i in range(res_num)])
self.decoder = nn.Sequential(
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
# DeConv(in_channels = chn * 8, out_channels = chn *4, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
# nn.LeakyReLU(),
DeConv(in_channels = embed_dim, out_channels = chn * 2 , kernel_size=k_size),
# nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
ImageLN(chn * 2),
nn.ReLU(),
DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size),
ImageLN(chn),
nn.ReLU(),
nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
)
self.patch_embed = PatchEmbed(
img_size=self.fea_size[0], patch_size=1, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
self.patch_unembed = PatchUnEmbed(
img_size=self.fea_size[0], patch_size=1, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# self.__weights_init__()
# def __weights_init__(self):
# for layer in self.encoder:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
# for layer in self.encoder2:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
def forward(self, input):
x2 = self.encoder(input)
x2 = self.patch_embed(x2)
for blk in self.blocks:
x2 = blk(x2,self.fea_size)
x2 = self.lnnorm(x2)
x2 = self.patch_unembed(x2,self.fea_size)
out = self.decoder(x2)
return out
if __name__ == '__main__':
    height = 1024
    width = 1024
    # Illustrative hyper-parameters: Generator() with no kwargs would raise
    # KeyError on the required entries.
    model = Generator(g_conv_dim=32, g_kernel_size=3, res_num=4, n_class=3,
                      window_size=8, image_size=height)
    print(model)
    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
+129
@@ -0,0 +1,129 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Conditional_Generator_gpt_LN_encoder copy.py
# Created Date: Saturday October 9th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 11th October 2021 5:22:22 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from components.ResBlock import ResBlock
from components.DeConv import DeConv
class ImageLN(nn.Module):
def __init__(self, dim) -> None:
super().__init__()
self.layer = nn.LayerNorm(dim)
def forward(self, x):
y = self.layer(x.permute(0,2,3,1)).permute(0,3,1,2)
return y
class Generator(nn.Module):
def __init__(
self,
**kwargs
):
super().__init__()
chn = kwargs["g_conv_dim"]
k_size = kwargs["g_kernel_size"]
res_num = kwargs["res_num"]
class_num = kwargs["n_class"]
window_size = kwargs["window_size"]
image_size = kwargs["image_size"]
padding_size = int((k_size -1)/2)
self.resblock_list = []
embed_dim = 96
window_size = 8
num_heads = 8
mlp_ratio = 2.
norm_layer = nn.LayerNorm
qk_scale = None
qkv_bias = True
self.patch_norm = True
self.lnnorm = norm_layer(embed_dim)
self.encoder = nn.Sequential(
nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False),
nn.InstanceNorm2d(chn),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), #
nn.InstanceNorm2d(chn * 2),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn*2, out_channels = embed_dim, kernel_size=k_size, stride=2, padding=1,bias =False),
nn.InstanceNorm2d(embed_dim),
nn.LeakyReLU(),
)
# self.encoder2 = nn.Sequential(
# nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU()
# )
self.decoder = nn.Sequential(
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
# DeConv(in_channels = chn * 8, out_channels = chn *4, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
# nn.LeakyReLU(),
DeConv(in_channels = embed_dim, out_channels = chn * 2 , kernel_size=k_size),
# nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
nn.InstanceNorm2d(chn * 2),
nn.LeakyReLU(),
DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size),
nn.InstanceNorm2d(chn),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
)
# self.__weights_init__()
# def __weights_init__(self):
# for layer in self.encoder:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
# for layer in self.encoder2:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
def forward(self, input):
x2 = self.encoder(input)
out = self.decoder(x2)
return out
if __name__ == '__main__':
    height = 1024
    width = 1024
    # Illustrative hyper-parameters: Generator() with no kwargs would raise
    # KeyError on the required entries.
    model = Generator(g_conv_dim=32, g_kernel_size=3, res_num=4, n_class=3,
                      window_size=8, image_size=height)
    print(model)
    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
+110
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Conditional_Generator_gpt_LN_encoder copy.py
# Created Date: Saturday October 9th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 7:35:08 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from components.ResBlock import ResBlock
from components.DeConv import DeConv
class ImageLN(nn.Module):
def __init__(self, dim) -> None:
super().__init__()
self.layer = nn.LayerNorm(dim)
def forward(self, x):
y = self.layer(x.permute(0,2,3,1)).permute(0,3,1,2)
return y
class Generator(nn.Module):
def __init__(
self,
**kwargs
):
super().__init__()
chn = kwargs["g_conv_dim"]
k_size = kwargs["g_kernel_size"]
res_num = kwargs["res_num"]
padding_size = int((k_size -1)/2)
self.resblock_list = []
self.encoder = nn.Sequential(
nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False),
nn.InstanceNorm2d(chn, affine=True, momentum=0),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), #
nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn*2, out_channels = chn*4, kernel_size=k_size, stride=2, padding=1,bias =False),
nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn*4 , out_channels = chn * 4, kernel_size=k_size, stride=2, padding=1,bias =False),
nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
nn.LeakyReLU(),
)
for _ in range(res_num):
self.resblock_list += [ResBlock(chn * 4,k_size),]
self.resblocks = nn.Sequential(*self.resblock_list)
self.decoder = nn.Sequential(
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
DeConv(in_channels = chn * 4, out_channels = chn *2, kernel_size=k_size),
nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
nn.LeakyReLU(),
DeConv(in_channels = chn * 2, out_channels = chn * 2 , kernel_size=k_size),
nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
nn.ReLU(),
DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size),
nn.InstanceNorm2d(chn, affine=True, momentum=0),
nn.ReLU(),
nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
)
# self.__weights_init__()
# def __weights_init__(self):
# for layer in self.encoder:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
# for layer in self.encoder2:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
def forward(self, input):
x2 = self.encoder(input)
x2 = self.resblocks(x2)
out = self.decoder(x2)
return out
if __name__ == '__main__':
    height = 1024
    width = 1024
    # Illustrative hyper-parameters: Generator() with no kwargs would raise
    # KeyError on the required entries.
    model = Generator(g_conv_dim=32, g_kernel_size=3, res_num=4)
    print(model)
    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
+144
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: FastNST_Liif.py
# Created Date: Thursday October 14th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 2:39:09 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from components.ResBlock import ResBlock
from components.DeConv import DeConv
from components.Liif import LIIF
class ImageLN(nn.Module):
def __init__(self, dim) -> None:
super().__init__()
self.layer = nn.LayerNorm(dim)
def forward(self, x):
y = self.layer(x.permute(0,2,3,1)).permute(0,3,1,2)
return y
class Generator(nn.Module):
def __init__(
self,
**kwargs
):
super().__init__()
chn = kwargs["g_conv_dim"]
k_size = kwargs["g_kernel_size"]
res_num = kwargs["res_num"]
class_num = kwargs["n_class"]
window_size = kwargs["window_size"]
image_size = kwargs["image_size"]
batch_size = kwargs["batch_size"]
# mlp_in_dim = kwargs["mlp_in_dim"]
# mlp_out_dim = kwargs["mlp_out_dim"]
mlp_hidden_list = kwargs["mlp_hidden_list"]
padding_size = int((k_size -1)/2)
self.resblock_list = []
embed_dim = 96
window_size = 8
num_heads = 8
mlp_ratio = 2.
norm_layer = nn.LayerNorm
qk_scale = None
qkv_bias = True
self.patch_norm = True
self.lnnorm = norm_layer(embed_dim)
self.encoder = nn.Sequential(
nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False),
nn.InstanceNorm2d(chn),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), #
nn.InstanceNorm2d(chn * 2),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn*2, out_channels = chn*4, kernel_size=k_size, stride=2, padding=1,bias =False),
nn.InstanceNorm2d(chn * 4),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn*4 , out_channels = chn * 4, kernel_size=k_size, stride=2, padding=1,bias =False),
ImageLN(chn * 4),
nn.LeakyReLU(),
)
for _ in range(res_num):
self.resblock_list += [ResBlock(chn * 4,k_size),]
self.resblocks = nn.Sequential(*self.resblock_list)
# self.encoder2 = nn.Sequential(
# nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU()
# )
self.decoder = nn.Sequential(
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
DeConv(in_channels = chn * 4, out_channels = chn *2, kernel_size=k_size),
nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
nn.LeakyReLU(),
DeConv(in_channels = chn * 2, out_channels = chn, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
nn.InstanceNorm2d(chn),
nn.LeakyReLU()
# DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size),
# nn.InstanceNorm2d(chn),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
)
self.upsample = LIIF(chn, 3, mlp_hidden_list)
self.upsample.gen_coord((batch_size, \
chn,image_size//2,image_size//2),(image_size,image_size))
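        # NOTE (assumption from the call signature): gen_coord appears to
        # precompute a coordinate grid for a fixed batch size and image size,
        # so inputs of other shapes would need the grid regenerated.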
# self.__weights_init__()
# def __weights_init__(self):
# for layer in self.encoder:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
# for layer in self.encoder2:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
def forward(self, input):
x2 = self.encoder(input)
x2 = self.resblocks(x2)
out = self.decoder(x2)
out = self.upsample(out)
return out
if __name__ == '__main__':
    height = 1024
    width = 1024
    # Illustrative hyper-parameters: Generator() with no kwargs would raise
    # KeyError on the required entries. The mlp_hidden_list value is a
    # placeholder; the real shape depends on components.Liif.LIIF.
    model = Generator(g_conv_dim=32, g_kernel_size=3, res_num=4, n_class=3,
                      window_size=8, image_size=height, batch_size=1,
                      mlp_hidden_list=[256, 256, 256])
    print(model)
    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
+150
@@ -0,0 +1,150 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: FastNST_Liif.py
# Created Date: Thursday October 14th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 4:33:51 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from components.ResBlock import ResBlock
from components.DeConv import DeConv
from components.Liif_conv import LIIF
class ImageLN(nn.Module):
def __init__(self, dim) -> None:
super().__init__()
self.layer = nn.LayerNorm(dim)
def forward(self, x):
y = self.layer(x.permute(0,2,3,1)).permute(0,3,1,2)
return y
class Generator(nn.Module):
def __init__(
self,
**kwargs
):
super().__init__()
chn = kwargs["g_conv_dim"]
k_size = kwargs["g_kernel_size"]
res_num = kwargs["res_num"]
class_num = kwargs["n_class"]
window_size = kwargs["window_size"]
image_size = kwargs["image_size"]
batch_size = kwargs["batch_size"]
# mlp_in_dim = kwargs["mlp_in_dim"]
# mlp_out_dim = kwargs["mlp_out_dim"]
padding_size = int((k_size -1)/2)
self.resblock_list = []
embed_dim = 96
window_size = 8
num_heads = 8
mlp_ratio = 2.
norm_layer = nn.LayerNorm
qk_scale = None
qkv_bias = True
self.patch_norm = True
self.lnnorm = norm_layer(embed_dim)
self.encoder = nn.Sequential(
nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False),
nn.InstanceNorm2d(chn, affine=True, momentum=0),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), #
nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn*2, out_channels = chn*4, kernel_size=k_size, stride=2, padding=1,bias =False),
nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn*4 , out_channels = chn * 4, kernel_size=k_size, stride=2, padding=1,bias =False),
nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
nn.LeakyReLU(),
)
for _ in range(res_num):
self.resblock_list += [ResBlock(chn * 4,k_size),]
self.resblocks = nn.Sequential(*self.resblock_list)
# self.encoder2 = nn.Sequential(
# nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
# ImageLN(chn * 8),
# nn.LeakyReLU()
# )
self.decoder = nn.Sequential(
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
# DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# nn.LeakyReLU(),
DeConv(in_channels = chn * 4, out_channels = chn *2, kernel_size=k_size),
nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
nn.LeakyReLU(),
# DeConv(in_channels = chn * 2, out_channels = chn, kernel_size=k_size),
# # nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
# nn.InstanceNorm2d(chn, affine=True, momentum=0),
# nn.LeakyReLU()
# DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size),
# nn.InstanceNorm2d(chn),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
)
self.upsample1 = LIIF(chn*2, chn)
self.upsample1.gen_coord((batch_size, \
chn,image_size//4,image_size//4),(image_size//2,image_size//2))
self.upsample2 = LIIF(chn, chn)
self.upsample2.gen_coord((batch_size, \
chn,image_size//2,image_size//2),(image_size,image_size))
self.out_conv = nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
# self.__weights_init__()
# def __weights_init__(self):
# for layer in self.encoder:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
# for layer in self.encoder2:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
def forward(self, input):
x2 = self.encoder(input)
x2 = self.resblocks(x2)
out = self.decoder(x2)
out = self.upsample1(out)
out = self.upsample2(out)
out = self.out_conv(out)
return out
if __name__ == '__main__':
    # Smoke test; Generator() without kwargs would raise a KeyError, so the
    # hyper-parameters below are illustrative assumptions only, and a CUDA
    # device is assumed since the LIIF coordinate buffers live on the GPU.
    height = 1024
    width = 1024
    model = Generator(g_conv_dim=32, g_kernel_size=3, res_num=4, n_class=1,
                      window_size=8, image_size=height, batch_size=1).cuda()
    print(model)
    x = torch.randn((1, 3, height, width), device='cuda')
    x = model(x)
    print(x.shape)
+146
View File
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: FastNST_Liif.py
# Created Date: Thursday October 14th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 8:47:28 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from components.ResBlock import ResBlock
from components.DeConv import DeConv
from components.Liif_invo import LIIF
class Generator(nn.Module):
def __init__(
self,
**kwargs
):
super().__init__()
chn = kwargs["g_conv_dim"]
k_size = kwargs["g_kernel_size"]
res_num = kwargs["res_num"]
class_num = kwargs["n_class"]
window_size = kwargs["window_size"]
image_size = kwargs["image_size"]
batch_size = kwargs["batch_size"]
# mlp_in_dim = kwargs["mlp_in_dim"]
# mlp_out_dim = kwargs["mlp_out_dim"]
padding_size = int((k_size -1)/2)
self.resblock_list = []
embed_dim = 96
norm_layer = nn.LayerNorm
self.img_token = nn.Sequential(
nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False),
nn.InstanceNorm2d(chn, affine=True, momentum=0),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), #
nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn*2, out_channels = chn*4, kernel_size=k_size, stride=2, padding=1,bias =False),
# nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
# nn.LeakyReLU(),
# nn.Conv2d(in_channels = chn*4 , out_channels = chn * 4, kernel_size=k_size, stride=2, padding=1,bias =False),
# nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
# nn.LeakyReLU(),
)
image_size = image_size // 2
self.downsample1 = LIIF(chn * 2, chn * 4)
self.downsample1.gen_coord((batch_size, \
chn,image_size,image_size),(image_size//2,image_size//2))
image_size = image_size // 2
self.downsample2 = LIIF(chn * 4, chn * 4)
self.downsample2.gen_coord((batch_size, \
chn,image_size,image_size),(image_size//2,image_size//2))
for _ in range(res_num):
self.resblock_list += [ResBlock(chn * 4,k_size),]
self.resblocks = nn.Sequential(*self.resblock_list)
# self.decoder = nn.Sequential(
# # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# # nn.LeakyReLU(),
# # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
# # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
# # nn.LeakyReLU(),
# DeConv(in_channels = chn * 4, out_channels = chn *2, kernel_size=k_size),
# nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
# nn.LeakyReLU(),
# # DeConv(in_channels = chn * 2, out_channels = chn, kernel_size=k_size),
# # # nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
# # nn.InstanceNorm2d(chn, affine=True, momentum=0),
# # nn.LeakyReLU()
# # DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size),
# # nn.InstanceNorm2d(chn),
# # nn.LeakyReLU(),
# # nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
# )
image_size = image_size // 2
self.upsample1 = LIIF(chn*4, chn * 4)
self.upsample1.gen_coord((batch_size, \
chn,image_size,image_size),(image_size*2,image_size*2))
image_size = image_size * 2
self.upsample2 = LIIF(chn*4, chn * 2)
self.upsample2.gen_coord((batch_size, \
chn,image_size,image_size),(image_size*2,image_size*2))
# image_size = image_size * 2
# self.upsample2 = LIIF(chn, chn)
# self.upsample2.gen_coord((batch_size, \
# chn,image_size,image_size),(image_size*2,image_size*2))
self.decoder = nn.Sequential(
DeConv(in_channels = chn * 2, out_channels = chn, kernel_size=k_size),
nn.InstanceNorm2d(chn, affine=True, momentum=0),
nn.LeakyReLU(),
nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
)
# self.out_conv = nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
# self.__weights_init__()
# def __weights_init__(self):
# for layer in self.encoder:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
# for layer in self.encoder2:
# if isinstance(layer,nn.Conv2d):
# nn.init.xavier_uniform_(layer.weight)
def forward(self, input):
out = self.img_token(input)
out = self.downsample1(out)
out = self.downsample2(out)
out = self.resblocks(out)
out = self.upsample1(out)
out = self.upsample2(out)
out = self.decoder(out)
return out
if __name__ == '__main__':
    # Smoke test; the kwargs below are illustrative assumptions (a bare
    # Generator() raises a KeyError), and CUDA is mandatory here because the
    # involution-based LIIF only has a GPU kernel.
    height = 1024
    width = 1024
    model = Generator(g_conv_dim=32, g_kernel_size=3, res_num=4, n_class=1,
                      window_size=8, image_size=height, batch_size=1).cuda()
    print(model)
    x = torch.randn((1, 3, height, width), device='cuda')
    x = model(x)
    print(x.shape)
+303
View File
@@ -0,0 +1,303 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Involution.py
# Created Date: Tuesday July 20th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 20th July 2021 10:35:52 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
from torch.autograd import Function
import torch
from torch.nn.modules.utils import _pair
import torch.nn.functional as F
import torch.nn as nn
from mmcv.cnn import ConvModule
from collections import namedtuple
import cupy
from string import Template
Stream = namedtuple('Stream', ['ptr'])
def Dtype(t):
    # Map the CUDA tensor type to the C scalar name substituted into the
    # kernel templates; fail loudly instead of silently returning None.
    if isinstance(t, torch.cuda.FloatTensor):
        return 'float'
    elif isinstance(t, torch.cuda.DoubleTensor):
        return 'double'
    else:
        raise TypeError('involution kernels only support CUDA float/double tensors')
@cupy._util.memoize(for_each_device=True)
def load_kernel(kernel_name, code, **kwargs):
code = Template(code).substitute(**kwargs)
kernel_code = cupy.cuda.compile_with_cache(code)
return kernel_code.get_function(kernel_name)
CUDA_NUM_THREADS = 1024
kernel_loop = '''
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
'''
def GET_BLOCKS(N):
return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS
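# Example: with CUDA_NUM_THREADS = 1024, GET_BLOCKS(2500) = 3, and the
# grid-stride loop in CUDA_KERNEL_LOOP advances each thread by
# blockDim.x * gridDim.x, so any element count n is covered.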
_involution_kernel = kernel_loop + '''
extern "C"
__global__ void involution_forward_kernel(
const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) {
CUDA_KERNEL_LOOP(index, ${nthreads}) {
const int n = index / ${channels} / ${top_height} / ${top_width};
const int c = (index / ${top_height} / ${top_width}) % ${channels};
const int h = (index / ${top_width}) % ${top_height};
const int w = index % ${top_width};
const int g = c / (${channels} / ${groups});
${Dtype} value = 0;
#pragma unroll
for (int kh = 0; kh < ${kernel_h}; ++kh) {
#pragma unroll
for (int kw = 0; kw < ${kernel_w}; ++kw) {
const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
if ((h_in >= 0) && (h_in < ${bottom_height})
&& (w_in >= 0) && (w_in < ${bottom_width})) {
const int offset = ((n * ${channels} + c) * ${bottom_height} + h_in)
* ${bottom_width} + w_in;
const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h)
* ${top_width} + w;
value += weight_data[offset_weight] * bottom_data[offset];
}
}
}
top_data[index] = value;
}
}
'''
_involution_kernel_backward_grad_input = kernel_loop + '''
extern "C"
__global__ void involution_backward_grad_input_kernel(
const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* const bottom_diff) {
CUDA_KERNEL_LOOP(index, ${nthreads}) {
const int n = index / ${channels} / ${bottom_height} / ${bottom_width};
const int c = (index / ${bottom_height} / ${bottom_width}) % ${channels};
const int h = (index / ${bottom_width}) % ${bottom_height};
const int w = index % ${bottom_width};
const int g = c / (${channels} / ${groups});
${Dtype} value = 0;
#pragma unroll
for (int kh = 0; kh < ${kernel_h}; ++kh) {
#pragma unroll
for (int kw = 0; kw < ${kernel_w}; ++kw) {
const int h_out_s = h + ${pad_h} - kh * ${dilation_h};
const int w_out_s = w + ${pad_w} - kw * ${dilation_w};
if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) {
const int h_out = h_out_s / ${stride_h};
const int w_out = w_out_s / ${stride_w};
if ((h_out >= 0) && (h_out < ${top_height})
&& (w_out >= 0) && (w_out < ${top_width})) {
const int offset = ((n * ${channels} + c) * ${top_height} + h_out)
* ${top_width} + w_out;
const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h_out)
* ${top_width} + w_out;
value += weight_data[offset_weight] * top_diff[offset];
}
}
}
}
bottom_diff[index] = value;
}
}
'''
_involution_kernel_backward_grad_weight = kernel_loop + '''
extern "C"
__global__ void involution_backward_grad_weight_kernel(
const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* const buffer_data) {
CUDA_KERNEL_LOOP(index, ${nthreads}) {
const int h = (index / ${top_width}) % ${top_height};
const int w = index % ${top_width};
const int kh = (index / ${kernel_w} / ${top_height} / ${top_width})
% ${kernel_h};
const int kw = (index / ${top_height} / ${top_width}) % ${kernel_w};
const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
if ((h_in >= 0) && (h_in < ${bottom_height})
&& (w_in >= 0) && (w_in < ${bottom_width})) {
const int g = (index / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${groups};
const int n = (index / ${groups} / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${num};
${Dtype} value = 0;
#pragma unroll
for (int c = g * (${channels} / ${groups}); c < (g + 1) * (${channels} / ${groups}); ++c) {
const int top_offset = ((n * ${channels} + c) * ${top_height} + h)
* ${top_width} + w;
const int bottom_offset = ((n * ${channels} + c) * ${bottom_height} + h_in)
* ${bottom_width} + w_in;
value += top_diff[top_offset] * bottom_data[bottom_offset];
}
buffer_data[index] = value;
} else {
buffer_data[index] = 0;
}
}
}
'''
class _involution(Function):
@staticmethod
def forward(ctx, input, weight, stride, padding, dilation):
assert input.dim() == 4 and input.is_cuda
assert weight.dim() == 6 and weight.is_cuda
batch_size, channels, height, width = input.size()
kernel_h, kernel_w = weight.size()[2:4]
output_h = int((height + 2 * padding[0] - (dilation[0] * (kernel_h - 1) + 1)) / stride[0] + 1)
output_w = int((width + 2 * padding[1] - (dilation[1] * (kernel_w - 1) + 1)) / stride[1] + 1)
output = input.new(batch_size, channels, output_h, output_w)
n = output.numel()
with torch.cuda.device_of(input):
f = load_kernel('involution_forward_kernel', _involution_kernel, Dtype=Dtype(input), nthreads=n,
num=batch_size, channels=channels, groups=weight.size()[1],
bottom_height=height, bottom_width=width,
top_height=output_h, top_width=output_w,
kernel_h=kernel_h, kernel_w=kernel_w,
stride_h=stride[0], stride_w=stride[1],
dilation_h=dilation[0], dilation_w=dilation[1],
pad_h=padding[0], pad_w=padding[1])
f(block=(CUDA_NUM_THREADS,1,1),
grid=(GET_BLOCKS(n),1,1),
args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
ctx.save_for_backward(input, weight)
ctx.stride, ctx.padding, ctx.dilation = stride, padding, dilation
return output
@staticmethod
def backward(ctx, grad_output):
assert grad_output.is_cuda and grad_output.is_contiguous()
input, weight = ctx.saved_tensors
stride, padding, dilation = ctx.stride, ctx.padding, ctx.dilation
batch_size, channels, height, width = input.size()
kernel_h, kernel_w = weight.size()[2:4]
output_h, output_w = grad_output.size()[2:]
grad_input, grad_weight = None, None
opt = dict(Dtype=Dtype(grad_output),
num=batch_size, channels=channels, groups=weight.size()[1],
bottom_height=height, bottom_width=width,
top_height=output_h, top_width=output_w,
kernel_h=kernel_h, kernel_w=kernel_w,
stride_h=stride[0], stride_w=stride[1],
dilation_h=dilation[0], dilation_w=dilation[1],
pad_h=padding[0], pad_w=padding[1])
with torch.cuda.device_of(input):
if ctx.needs_input_grad[0]:
grad_input = input.new(input.size())
n = grad_input.numel()
opt['nthreads'] = n
f = load_kernel('involution_backward_grad_input_kernel',
_involution_kernel_backward_grad_input, **opt)
f(block=(CUDA_NUM_THREADS,1,1),
grid=(GET_BLOCKS(n),1,1),
args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
if ctx.needs_input_grad[1]:
grad_weight = weight.new(weight.size())
n = grad_weight.numel()
opt['nthreads'] = n
f = load_kernel('involution_backward_grad_weight_kernel',
_involution_kernel_backward_grad_weight, **opt)
f(block=(CUDA_NUM_THREADS,1,1),
grid=(GET_BLOCKS(n),1,1),
args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
return grad_input, grad_weight, None, None, None
def _involution_cuda(input, weight, bias=None, stride=1, padding=0, dilation=1):
""" involution kernel
"""
assert input.size(0) == weight.size(0)
assert input.size(-2)//stride == weight.size(-2)
assert input.size(-1)//stride == weight.size(-1)
if input.is_cuda:
out = _involution.apply(input, weight, _pair(stride), _pair(padding), _pair(dilation))
if bias is not None:
out += bias.view(1,-1,1,1)
else:
raise NotImplementedError
return out
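# Hedged usage sketch for _involution_cuda; shapes are illustrative
# assumptions, and a CUDA device is required:
#   x = torch.randn(2, 16, 32, 32, device='cuda')
#   # weight layout: (B, groups, k_h, k_w, H_out, W_out); here k=3, stride=1
#   w = torch.randn(2, 2, 3, 3, 32, 32, device='cuda')
#   y = _involution_cuda(x, w, stride=1, padding=1)   # -> (2, 16, 32, 32)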
class involution(nn.Module):
def __init__(self,
channels,
kernel_size,
stride):
super(involution, self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.channels = channels
reduction_ratio = 4
self.group_channels = 8
self.groups = self.channels // self.group_channels
self.seblock = nn.Sequential(
nn.Conv2d(in_channels = channels, out_channels = channels // reduction_ratio, kernel_size= 1),
nn.InstanceNorm2d(channels // reduction_ratio, affine=True, momentum=0),
nn.ReLU(),
nn.Conv2d(in_channels = channels // reduction_ratio, out_channels = kernel_size**2 * self.groups, kernel_size= 1)
)
# self.conv1 = ConvModule(
# in_channels=channels,
# out_channels=channels // reduction_ratio,
# kernel_size=1,
# conv_cfg=None,
# norm_cfg=dict(type='BN'),
# act_cfg=dict(type='ReLU'))
# self.conv2 = ConvModule(
# in_channels=channels // reduction_ratio,
# out_channels=kernel_size**2 * self.groups,
# kernel_size=1,
# stride=1,
# conv_cfg=None,
# norm_cfg=None,
# act_cfg=None)
if stride > 1:
self.avgpool = nn.AvgPool2d(stride, stride)
def forward(self, x):
# weight = self.conv2(self.conv1(x if self.stride == 1 else self.avgpool(x)))
        # Downsample first when stride > 1 so the predicted kernel map matches
        # the output resolution (mirrors the commented ConvModule path; the
        # shape asserts in _involution_cuda fail otherwise).
        weight = self.seblock(x if self.stride == 1 else self.avgpool(x))
b, c, h, w = weight.shape
weight = weight.view(b, self.groups, self.kernel_size, self.kernel_size, h, w)
out = _involution_cuda(x, weight, stride=self.stride, padding=(self.kernel_size-1)//2)
return out
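# Minimal, hedged example for the involution layer (CUDA only; channels must
# be divisible by group_channels = 8):
#   inv = involution(channels=32, kernel_size=7, stride=1).cuda()
#   x = torch.randn(1, 32, 64, 64, device='cuda')
#   y = inv(x)   # per-position kernels come from the SE-style branch
#   assert y.shape == x.shape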
+146
View File
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Liif.py
# Created Date: Monday October 18th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 10:27:09 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_coord(shape, ranges=None, flatten=True):
""" Make coordinates at grid centers.
"""
coord_seqs = []
for i, n in enumerate(shape):
print("i: %d, n: %d"%(i,n))
if ranges is None:
v0, v1 = -1, 1
else:
v0, v1 = ranges[i]
r = (v1 - v0) / (2 * n)
seq = v0 + r + (2 * r) * torch.arange(n).float()
coord_seqs.append(seq)
ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
if flatten:
ret = ret.view(-1, ret.shape[-1])
return ret
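# Example: make_coord((2, 2)) returns the four grid-cell centers of [-1, 1]^2,
#   tensor([[-0.5, -0.5], [-0.5,  0.5], [ 0.5, -0.5], [ 0.5,  0.5]])
# i.e. coordinates sit at cell centers, not cell corners.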
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_list):
super().__init__()
layers = []
lastv = in_dim
for hidden in hidden_list:
layers.append(nn.Linear(lastv, hidden))
layers.append(nn.ReLU())
lastv = hidden
layers.append(nn.Linear(lastv, out_dim))
self.layers = nn.Sequential(*layers)
def forward(self, x):
shape = x.shape[:-1]
x = self.layers(x.view(-1, x.shape[-1]))
return x.view(*shape, -1)
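# The MLP flattens every leading dimension, applies the linear stack, and
# restores the original shape, e.g. (illustrative):
#   mlp = MLP(in_dim=4, out_dim=3, hidden_list=[8, 8])
#   mlp(torch.randn(2, 10, 4)).shape   # -> torch.Size([2, 10, 3])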
class LIIF(nn.Module):
def __init__(self, mlp_in_dim, mlp_out_dim, mlp_hidden_list):
super().__init__()
imnet_in_dim = mlp_in_dim
imnet_in_dim *= 9
imnet_in_dim += 2 # attach coord
imnet_in_dim += 2
self.imnet = MLP(imnet_in_dim, mlp_out_dim, mlp_hidden_list).cuda()
def gen_coord(self, in_shape, output_size):
self.vx_lst = [-1, 1]
self.vy_lst = [-1, 1]
eps_shift = 1e-6
self.image_size=output_size
# field radius (global: [-1, 1])
rx = 2 / in_shape[-2] / 2
ry = 2 / in_shape[-1] / 2
coord = make_coord(output_size,flatten=False) \
.expand(in_shape[0],output_size[0],output_size[1],2) \
.view(in_shape[0],output_size[0]*output_size[1],2)
cell = torch.ones_like(coord)
cell[:, :, 0] *= 2 / coord.shape[-2]
cell[:, :, 1] *= 2 / coord.shape[-1]
feat_coord = make_coord(in_shape[-2:], flatten=False) \
.permute(2, 0, 1) \
.unsqueeze(0).expand(in_shape[0], 2, *in_shape[-2:])
areas = []
self.rel_coord = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
self.rel_cell = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
self.coord_ = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
for vx in self.vx_lst:
for vy in self.vy_lst:
self.coord_[(vx+1)//2,(vy+1)//2,:, :, :] = coord.clone()
self.coord_[(vx+1)//2,(vy+1)//2,:, :, 0] += vx * rx + eps_shift
self.coord_[(vx+1)//2,(vy+1)//2,:, :, 1] += vy * ry + eps_shift
self.coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
q_coord = F.grid_sample(
feat_coord, self.coord_[(vx+1)//2,(vy+1)//2,:, :, :].flip(-1).unsqueeze(1),
mode='nearest', align_corners=False)[:, :, 0, :] \
.permute(0, 2, 1)
self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, :] = coord - q_coord
self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2]
self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1]
self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, :] = cell.clone()
self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2]
self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1]
area = torch.abs(self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] * self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1])
areas.append(area + 1e-9)
tot_area = torch.stack(areas).sum(dim=0)
        # Swap so each corner prediction is weighted by the area spanned to the
        # *opposite* corner, i.e. standard bilinear weighting of the four samples.
        areas[0], areas[3] = areas[3], areas[0]
        areas[1], areas[2] = areas[2], areas[1]
self.area_weights = []
for item in areas:
self.area_weights.append((item / tot_area).unsqueeze(-1).cuda())
self.rel_coord = self.rel_coord.cuda()
self.rel_cell = self.rel_cell.cuda()
self.coord_ = self.coord_.cuda()
def forward(self, feat):
# B K*K*Cin H W
feat = F.unfold(feat, 3, padding=1).view(
feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
preds = []
for vx in [0,1]:
for vy in [0,1]:
q_feat = F.grid_sample(
feat, self.coord_[vx,vy,:,:,:].flip(-1).unsqueeze(1),
mode='nearest', align_corners=False)[:, :, 0, :] \
.permute(0, 2, 1)
inp = torch.cat([q_feat, self.rel_coord[vx,vy,:,:,:], self.rel_cell[vx,vy,:,:,:]], dim=-1)
bs, q = self.coord_[0,0,:,:,:].shape[:2]
pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
# print("pred shape: ",pred.shape)
preds.append(pred)
ret = 0
for pred, area in zip(preds, self.area_weights):
ret = ret + pred * area
return ret.permute(0, 2, 1).view(-1,3,self.image_size[0],self.image_size[1])
+156
View File
@@ -0,0 +1,156 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Liif.py
# Created Date: Monday October 18th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 4:26:26 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_coord(shape, ranges=None, flatten=True):
""" Make coordinates at grid centers.
"""
coord_seqs = []
for i, n in enumerate(shape):
print("i: %d, n: %d"%(i,n))
if ranges is None:
v0, v1 = -1, 1
else:
v0, v1 = ranges[i]
r = (v1 - v0) / (2 * n)
seq = v0 + r + (2 * r) * torch.arange(n).float()
coord_seqs.append(seq)
ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
if flatten:
ret = ret.view(-1, ret.shape[-1])
return ret
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_list):
super().__init__()
layers = []
lastv = in_dim
for hidden in hidden_list:
layers.append(nn.Linear(lastv, hidden))
layers.append(nn.ReLU())
lastv = hidden
layers.append(nn.Linear(lastv, out_dim))
self.layers = nn.Sequential(*layers)
def forward(self, x):
shape = x.shape[:-1]
x = self.layers(x.view(-1, x.shape[-1]))
return x.view(*shape, -1)
class LIIF(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
imnet_in_dim = in_dim
# imnet_in_dim += 2 # attach coord
# imnet_in_dim += 2
self.imnet = nn.Sequential( \
nn.Conv2d(in_channels = imnet_in_dim, out_channels = out_dim, kernel_size= 3,padding=1),
nn.InstanceNorm2d(out_dim, affine=True, momentum=0),
nn.LeakyReLU(),
# nn.Conv2d(in_channels = out_dim, out_channels = out_dim, kernel_size= 3,padding=1),
# nn.InstanceNorm2d(out_dim),
# nn.LeakyReLU(),
)
def gen_coord(self, in_shape, output_size):
self.vx_lst = [-1, 1]
self.vy_lst = [-1, 1]
eps_shift = 1e-6
self.image_size=output_size
# field radius (global: [-1, 1])
rx = 2 / in_shape[-2] / 2
ry = 2 / in_shape[-1] / 2
self.coord = make_coord(output_size,flatten=False) \
.expand(in_shape[0],output_size[0],output_size[1],2)
# cell = torch.ones_like(coord)
# cell[:, :, 0] *= 2 / coord.shape[-2]
# cell[:, :, 1] *= 2 / coord.shape[-1]
# feat_coord = make_coord(in_shape[-2:], flatten=False) \
# .permute(2, 0, 1) \
# .unsqueeze(0).expand(in_shape[0], 2, *in_shape[-2:])
# areas = []
# self.rel_coord = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
# self.rel_cell = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
# self.coord_ = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
# for vx in self.vx_lst:
# for vy in self.vy_lst:
# self.coord_[(vx+1)//2,(vy+1)//2,:, :, :] = coord.clone()
# self.coord_[(vx+1)//2,(vy+1)//2,:, :, 0] += vx * rx + eps_shift
# self.coord_[(vx+1)//2,(vy+1)//2,:, :, 1] += vy * ry + eps_shift
# self.coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
# q_coord = F.grid_sample(
# feat_coord, self.coord_[(vx+1)//2,(vy+1)//2,:, :, :].flip(-1).unsqueeze(1),
# mode='nearest', align_corners=False)[:, :, 0, :] \
# .permute(0, 2, 1)
# self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, :] = coord - q_coord
# self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2]
# self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1]
# self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, :] = cell.clone()
# self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2]
# self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1]
# area = torch.abs(self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] * self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1])
# areas.append(area + 1e-9)
# tot_area = torch.stack(areas).sum(dim=0)
# t = areas[0]; areas[0] = areas[3]; areas[3] = t
# t = areas[1]; areas[1] = areas[2]; areas[2] = t
# self.area_weights = []
# for item in areas:
# self.area_weights.append((item / tot_area).unsqueeze(-1).cuda())
# self.rel_coord = self.rel_coord.cuda()
# self.rel_cell = self.rel_cell.cuda()
# self.coord_ = self.coord_.cuda()
self.coord = self.coord.cuda()
def forward(self, feat):
# B K*K*Cin H W
# feat = F.unfold(feat, 3, padding=1).view(
# feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
# preds = []
# for vx in [0,1]:
# for vy in [0,1]:
# print("feat shape: ", feat.shape)
# print("coor shape: ", self.coord.shape)
q_feat = F.grid_sample(
feat, self.coord,
mode='bilinear', align_corners=False)
out = self.imnet(q_feat)
# inp = torch.cat([q_feat, self.rel_coord[vx,vy,:,:,:], self.rel_cell[vx,vy,:,:,:]], dim=-1)
# bs, q = self.coord_[0,0,:,:,:].shape[:2]
# pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
# # print("pred shape: ",pred.shape)
# preds.append(pred)
# ret = 0
# for pred, area in zip(preds, self.area_weights):
# ret = ret + pred * area
# print("warp output shape: ",out.shape)
return out
+164
View File
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Liif.py
# Created Date: Monday October 18th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 8:25:18 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from components.Involution import involution
def make_coord(shape, ranges=None, flatten=True):
""" Make coordinates at grid centers.
"""
coord_seqs = []
for i, n in enumerate(shape):
print("i: %d, n: %d"%(i,n))
if ranges is None:
v0, v1 = -1, 1
else:
v0, v1 = ranges[i]
r = (v1 - v0) / (2 * n)
seq = v0 + r + (2 * r) * torch.arange(n).float()
coord_seqs.append(seq)
ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
if flatten:
ret = ret.view(-1, ret.shape[-1])
return ret
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_list):
super().__init__()
layers = []
lastv = in_dim
for hidden in hidden_list:
layers.append(nn.Linear(lastv, hidden))
layers.append(nn.ReLU())
lastv = hidden
layers.append(nn.Linear(lastv, out_dim))
self.layers = nn.Sequential(*layers)
def forward(self, x):
shape = x.shape[:-1]
x = self.layers(x.view(-1, x.shape[-1]))
return x.view(*shape, -1)
class LIIF(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
imnet_in_dim = in_dim
# imnet_in_dim += 2 # attach coord
# imnet_in_dim += 2
self.conv1x1 = nn.Conv2d(in_channels = imnet_in_dim, out_channels = out_dim, kernel_size= 1)
# self.same_padding = nn.ReflectionPad2d(padding_size)
# self.conv = involution(out_dim,5,1)
self.imnet = nn.Sequential( \
# nn.Conv2d(in_channels = imnet_in_dim, out_channels = out_dim, kernel_size= 3,padding=1),
involution(out_dim,5,1),
nn.InstanceNorm2d(out_dim, affine=True, momentum=0),
nn.LeakyReLU(),
# nn.Conv2d(in_channels = out_dim, out_channels = out_dim, kernel_size= 3,padding=1),
# nn.InstanceNorm2d(out_dim),
# nn.LeakyReLU(),
)
def gen_coord(self, in_shape, output_size):
self.vx_lst = [-1, 1]
self.vy_lst = [-1, 1]
eps_shift = 1e-6
self.image_size=output_size
# field radius (global: [-1, 1])
rx = 2 / in_shape[-2] / 2
ry = 2 / in_shape[-1] / 2
self.coord = make_coord(output_size,flatten=False) \
.expand(in_shape[0],output_size[0],output_size[1],2)
# cell = torch.ones_like(coord)
# cell[:, :, 0] *= 2 / coord.shape[-2]
# cell[:, :, 1] *= 2 / coord.shape[-1]
# feat_coord = make_coord(in_shape[-2:], flatten=False) \
# .permute(2, 0, 1) \
# .unsqueeze(0).expand(in_shape[0], 2, *in_shape[-2:])
# areas = []
# self.rel_coord = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
# self.rel_cell = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
# self.coord_ = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
# for vx in self.vx_lst:
# for vy in self.vy_lst:
# self.coord_[(vx+1)//2,(vy+1)//2,:, :, :] = coord.clone()
# self.coord_[(vx+1)//2,(vy+1)//2,:, :, 0] += vx * rx + eps_shift
# self.coord_[(vx+1)//2,(vy+1)//2,:, :, 1] += vy * ry + eps_shift
# self.coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
# q_coord = F.grid_sample(
# feat_coord, self.coord_[(vx+1)//2,(vy+1)//2,:, :, :].flip(-1).unsqueeze(1),
# mode='nearest', align_corners=False)[:, :, 0, :] \
# .permute(0, 2, 1)
# self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, :] = coord - q_coord
# self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2]
# self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1]
# self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, :] = cell.clone()
# self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2]
# self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1]
# area = torch.abs(self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] * self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1])
# areas.append(area + 1e-9)
# tot_area = torch.stack(areas).sum(dim=0)
# t = areas[0]; areas[0] = areas[3]; areas[3] = t
# t = areas[1]; areas[1] = areas[2]; areas[2] = t
# self.area_weights = []
# for item in areas:
# self.area_weights.append((item / tot_area).unsqueeze(-1).cuda())
# self.rel_coord = self.rel_coord.cuda()
# self.rel_cell = self.rel_cell.cuda()
# self.coord_ = self.coord_.cuda()
self.coord = self.coord.cuda()
def forward(self, feat):
# B K*K*Cin H W
# feat = F.unfold(feat, 3, padding=1).view(
# feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
# preds = []
# for vx in [0,1]:
# for vy in [0,1]:
# print("feat shape: ", feat.shape)
# print("coor shape: ", self.coord.shape)
q_feat = self.conv1x1(feat)
q_feat = F.grid_sample(
q_feat, self.coord,
mode='bilinear', align_corners=False)
out = self.imnet(q_feat)
# inp = torch.cat([q_feat, self.rel_coord[vx,vy,:,:,:], self.rel_cell[vx,vy,:,:,:]], dim=-1)
# bs, q = self.coord_[0,0,:,:,:].shape[:2]
# pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
# # print("pred shape: ",pred.shape)
# preds.append(pred)
# ret = 0
# for pred, area in zip(preds, self.area_weights):
# ret = ret + pred * area
# print("warp output shape: ",out.shape)
return out
+38
View File
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: ResBlock.py
# Created Date: Monday July 5th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 5th July 2021 12:18:18 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
from torch import nn
class ResBlock(nn.Module):
def __init__(self, in_channel, k_size = 3, stride=1):
super().__init__()
padding_size = int((k_size -1)/2)
self.block = nn.Sequential(
nn.ReflectionPad2d(padding_size),
nn.Conv2d(in_channels = in_channel , out_channels = in_channel , kernel_size= k_size, stride=stride, bias= False),
nn.InstanceNorm2d(in_channel, affine=True, momentum=0),
nn.ReflectionPad2d(padding_size),
nn.Conv2d(in_channels = in_channel , out_channels = in_channel , kernel_size= k_size, stride=stride, bias= False),
nn.InstanceNorm2d(in_channel, affine=True, momentum=0)
)
self.__weights_init__()
def __weights_init__(self):
for m in self.modules():
if isinstance(m,nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
def forward(self, input):
res = input
h = self.block(input)
out = h + res
return out
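# ResBlock is shape-preserving (reflection padding keeps H and W); a quick
# hedged check:
#   blk = ResBlock(in_channel=64, k_size=3)
#   blk(torch.randn(1, 64, 32, 32)).shape   # -> torch.Size([1, 64, 32, 32])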
+14
View File
@@ -0,0 +1,14 @@
import torch
from torch import nn
class Transform_block(nn.Module):
def __init__(self, k_size = 10):
super().__init__()
padding_size = int((k_size -1)/2)
# self.padding = nn.ReplicationPad2d(padding_size)
self.pool = nn.AvgPool2d(k_size, stride=1,padding=padding_size)
def forward(self, input_image):
# h = self.padding(input)
out = self.pool(input_image)
return out
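# Note: with the default k_size = 10 the padding is int(9/2) = 4, so the pooled
# output shrinks by one pixel per spatial dim ((H + 8 - 10) + 1 = H - 1); an odd
# k_size preserves the size exactly. Illustrative check:
#   Transform_block(k_size=10)(torch.randn(1, 3, 64, 64)).shape
#   # -> torch.Size([1, 3, 63, 63])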
+854
View File
@@ -0,0 +1,854 @@
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
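# Example (illustrative): a (1, 16, 16, 96) map with window_size = 8 splits
# into (16/8) * (16/8) = 4 windows, i.e. a (4, 8, 8, 96) tensor; window_reverse
# below is the exact inverse.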
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
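# Hedged round-trip check of the two helpers above:
#   x = torch.randn(2, 16, 16, 32)
#   w = window_partition(x, 8)                      # (8, 8, 8, 32)
#   assert torch.equal(window_reverse(w, 8, 16, 16), x)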
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
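# For window_size = (7, 7) the bias table above holds (2*7 - 1)**2 = 169 rows,
# one learnable per-head bias per possible relative offset; the precomputed
# 49 x 49 relative_position_index simply gathers rows from it at forward time.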
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA (use the precomputed mask at the training resolution;
        # recompute it for other test-time image sizes)
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self,
input_resolution,
dim, norm_layer = nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
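# PatchMerging trades resolution for channels: e.g. a (B, 56*56, 96) input
# becomes (B, 28*28, 192), since the four interleaved sub-grids are
# concatenated to 4*C and then linearly reduced to 2*C.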
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
img_size=224, patch_size=4, resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
def forward(self, x, x_size):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
H, W = self.img_size
if self.norm is not None:
flops += H * W * self.embed_dim
return flops
class PatchUnEmbed(nn.Module):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
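# Example: Upsample(scale=4, num_feat=64) stacks two (Conv2d -> PixelShuffle(2))
# stages; each conv expands 64 -> 256 channels so PixelShuffle(2) folds the
# 4x channel surplus into a 2x spatial upscale.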
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class SwinIR(nn.Module):
r""" SwinIR
    A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on the Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(SwinIR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
#####################################################################################################
################################### 1, shallow feature extraction ###################################
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(dim=embed_dim,
input_resolution=(patches_resolution[0],
patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
assert self.upscale == 4, 'only support x4 now.'
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
if __name__ == '__main__':
upscale = 4
window_size = 8
height = (1024 // upscale // window_size + 1) * window_size
width = (720 // upscale // window_size + 1) * window_size
    model = SwinIR(upscale=upscale, img_size=(height, width),
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
print(model)
print(height, width, model.flops() / 1e9)
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
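    # Hedged follow-on checks under the same __main__ guard (a sketch, not
    # part of the benchmark above).
    # 1) The two upsamplers: Upsample(4, 64) stacks two conv + PixelShuffle(2)
    #    stages, while UpsampleOneStep uses one conv + PixelShuffle(4).
    feat = torch.randn(1, 64, 16, 16)
    print(Upsample(4, 64)(feat).shape)            # torch.Size([1, 64, 64, 64])
    print(UpsampleOneStep(4, 64, 3)(feat).shape)  # torch.Size([1, 3, 64, 64])
    # 2) The denoising configuration: upsampler='' keeps the input resolution.
    denoiser = SwinIR(upscale=1, img_size=64, window_size=8, img_range=1.,
                      depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
                      mlp_ratio=2, upsampler='')
    print(denoiser(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 3, 64, 64])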
+45
View File
@@ -0,0 +1,45 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: warp_invo.py
# Created Date: Tuesday October 19th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 11:27:13 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
from torch import nn
from components.Involution import involution
class DeConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampl_scale=2, padding="reflect"):
        super().__init__()
        self.upsampling = nn.UpsamplingNearest2d(scale_factor=upsampl_scale)
        self.conv1x1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        # Note: kernel_size is kept for API compatibility only; the involution
        # below uses a fixed 5x5 kernel, and both padding modes currently map
        # to the same block.
        if padding.lower() == "reflect":
            self.conv = involution(out_channels, 5, 1)
        elif padding.lower() == "zero":
            self.conv = involution(out_channels, 5, 1)
def forward(self, input):
h = self.conv1x1(input)
h = self.upsampling(h)
h = self.conv(h)
return h
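# Hedged usage sketch for DeConv (assumes components.Involution is on the
# import path and that the involution block preserves spatial size):
if __name__ == "__main__":
    import torch
    block = DeConv(in_channels=64, out_channels=32, upsampl_scale=2)
    out = block(torch.randn(1, 64, 16, 16))
    print(out.shape)  # expected: torch.Size([1, 32, 32, 32])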
+36
View File
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: StyleResize.py
# Created Date: Friday April 17th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 18th April 2020 1:39:53 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
from PIL import Image
import torchvision.transforms.functional as F
class StyleResize(object):
def __call__(self, images):
        tw, th = images.size  # PIL returns (width, height)
        if max(th,tw) > 1800:
            # downscale so that the shorter side becomes 1800px
            alpha = 1800. / float(min(th,tw))
h = int(th*alpha)
w = int(tw*alpha)
images = F.resize(images, (h, w))
if max(th,tw) < 800:
# Resize the smallest side of the image to 800px
alpha = 800. / float(min(th,tw))
if alpha < 4.:
h = int(th*alpha)
w = int(tw*alpha)
images = F.resize(images, (h, w))
else:
images = F.resize(images, (800, 800))
return images
def __repr__(self):
return self.__class__.__name__ + '()'
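# Hedged usage sketch: StyleResize scales small images up so the shorter side
# approaches 800px (capped at 4x) and scales very large images down toward
# 1800px on the shorter side.
if __name__ == "__main__":
    resizer = StyleResize()
    small = Image.new("RGB", (400, 300))  # shorter side 300 -> scaled by 800/300
    print(resizer(small).size)            # roughly (1066, 800)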
+269
View File
@@ -0,0 +1,269 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: data_loader.py
# Created Date: Saturday April 4th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 5th January 2021 2:12:29 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import glob
import torch
import random
from PIL import Image
from pathlib import Path
from torch.utils import data
import torchvision.datasets as dsets
from torchvision import transforms as T
import torchvision.transforms.functional as F
class StyleResize(object):
def __call__(self, images):
        tw, th = images.size  # PIL returns (width, height)
        if max(th,tw) > 1800:
            # downscale so that the shorter side becomes 1800px
            alpha = 1800. / float(min(th,tw))
h = int(th*alpha)
w = int(tw*alpha)
images = F.resize(images, (h, w))
if max(th,tw) < 800:
# Resize the smallest side of the image to 800px
alpha = 800. / float(min(th,tw))
if alpha < 4.:
h = int(th*alpha)
w = int(tw*alpha)
images = F.resize(images, (h, w))
else:
images = F.resize(images, (800, 800))
return images
def __repr__(self):
return self.__class__.__name__ + '()'
class DataPrefetcher():
def __init__(self, loader):
self.loader = loader
self.dataiter = iter(loader)
self.stream = torch.cuda.Stream()
# self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
# self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.mean = self.mean.half()
# self.std = self.std.half()
self.__preload__()
def __preload__(self):
try:
self.content, self.style, self.label = next(self.dataiter)
except StopIteration:
self.dataiter = iter(self.loader)
self.content, self.style, self.label = next(self.dataiter)
with torch.cuda.stream(self.stream):
self.content= self.content.cuda(non_blocking=True)
self.style = self.style.cuda(non_blocking=True)
self.label = self.label.cuda(non_blocking=True)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.next_input = self.next_input.half()
# else:
# self.next_input = self.next_input.float()
# self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
content = self.content
style = self.style
label = self.label
self.__preload__()
return content, style, label
def __len__(self):
"""Return the number of images."""
return len(self.loader)
class TotalDataset(data.Dataset):
"""Dataset class for the Artworks dataset and content dataset."""
def __init__(self, content_image_dir,style_image_dir,
selectedContent,selectedStyle,
content_transform,style_transform,
subffix='jpg', random_seed=1234):
"""Initialize and preprocess the CelebA dataset."""
self.content_image_dir= content_image_dir
self.style_image_dir = style_image_dir
self.content_transform= content_transform
self.style_transform = style_transform
self.selectedContent = selectedContent
self.selectedStyle = selectedStyle
self.subffix = subffix
self.content_dataset = []
self.art_dataset = []
self.random_seed= random_seed
self.__preprocess__()
self.num_images = len(self.content_dataset)
self.art_num = len(self.art_dataset)
def __preprocess__(self):
"""Preprocess the Artworks dataset."""
print("processing content images...")
for dir_item in self.selectedContent:
            join_path = Path(self.content_image_dir, dir_item)
if join_path.exists():
print("processing %s"%dir_item)
images = join_path.glob('*.%s'%(self.subffix))
for item in images:
self.content_dataset.append(item)
else:
print("%s dir does not exist!"%dir_item)
label_index = 0
print("processing style images...")
for class_item in self.selectedStyle:
images = Path(self.style_image_dir).glob('%s/*.%s'%(class_item, self.subffix))
for item in images:
self.art_dataset.append([item, label_index])
label_index += 1
random.seed(self.random_seed)
random.shuffle(self.content_dataset)
random.shuffle(self.art_dataset)
# self.dataset = images
print('Finished preprocessing the Art Works dataset, total image number: %d...'%len(self.art_dataset))
print('Finished preprocessing the Content dataset, total image number: %d...'%len(self.content_dataset))
def __getitem__(self, index):
"""Return one image and its corresponding attribute label."""
filename = self.content_dataset[index]
image = Image.open(filename)
content = self.content_transform(image)
art_index = random.randint(0,self.art_num-1)
filename,label = self.art_dataset[art_index]
image = Image.open(filename)
style = self.style_transform(image)
return content,style,label
def __len__(self):
"""Return the number of images."""
return self.num_images
def denorm(x):
out = (x + 1) / 2
return out.clamp_(0, 1)
def GetLoader(s_image_dir,c_image_dir,
style_selected_dir, content_selected_dir,
crop_size=178, batch_size=16, num_workers=8,
colorJitterEnable=True, colorConfig={"brightness":0.05,"contrast":0.05,"saturation":0.05,"hue":0.05}):
"""Build and return a data loader."""
s_transforms = []
c_transforms = []
s_transforms.append(T.Resize(768))
# s_transforms.append(T.Resize(900))
c_transforms.append(T.Resize(768))
s_transforms.append(T.RandomCrop(crop_size,pad_if_needed=True,padding_mode='reflect'))
c_transforms.append(T.RandomCrop(crop_size))
s_transforms.append(T.RandomHorizontalFlip())
c_transforms.append(T.RandomHorizontalFlip())
s_transforms.append(T.RandomVerticalFlip())
c_transforms.append(T.RandomVerticalFlip())
if colorJitterEnable:
if colorConfig is not None:
print("Enable color jitter!")
colorBrightness = colorConfig["brightness"]
colorContrast = colorConfig["contrast"]
colorSaturation = colorConfig["saturation"]
colorHue = (-colorConfig["hue"],colorConfig["hue"])
s_transforms.append(T.ColorJitter(brightness=colorBrightness,\
contrast=colorContrast,saturation=colorSaturation, hue=colorHue))
c_transforms.append(T.ColorJitter(brightness=colorBrightness,\
contrast=colorContrast,saturation=colorSaturation, hue=colorHue))
s_transforms.append(T.ToTensor())
c_transforms.append(T.ToTensor())
s_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
s_transforms = T.Compose(s_transforms)
c_transforms = T.Compose(c_transforms)
    content_dataset = TotalDataset(c_image_dir, s_image_dir,
                        content_selected_dir, style_selected_dir,
                        c_transforms, s_transforms)
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)
prefetcher = DataPrefetcher(content_data_loader)
return prefetcher
def GetValiDataTensors(
image_dir=None,
selected_imgs=[],
crop_size=178,
mean = (0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5)
):
transforms = []
transforms.append(T.Resize(768))
transforms.append(T.RandomCrop(crop_size,pad_if_needed=True,padding_mode='reflect'))
transforms.append(T.ToTensor())
transforms.append(T.Normalize(mean=mean, std=std))
transforms = T.Compose(transforms)
result_img = []
print("Start to read validation data......")
if len(selected_imgs) != 0:
for s_img in selected_imgs:
            if image_dir is None:
temp_img = s_img
else:
temp_img = os.path.join(image_dir, s_img)
temp_img = Image.open(temp_img)
temp_img = transforms(temp_img).cuda().unsqueeze(0)
result_img.append(temp_img)
else:
s_imgs = glob.glob(os.path.join(image_dir, '*.jpg'))
s_imgs = s_imgs + glob.glob(os.path.join(image_dir, '*.png'))
for s_img in s_imgs:
            temp_img = Image.open(s_img)  # glob already returns the full path
temp_img = transforms(temp_img).cuda().unsqueeze(0)
result_img.append(temp_img)
print("Finish to read validation data......")
print("Total validation images: %d"%len(result_img))
return result_img
def ScanAbnormalImg(image_dir, selected_imgs):
"""Scan the dataset, this function is designed to exclude or remove the non-RGB images."""
print("processing images...")
subffix = "jpg"
for dir_item in selected_imgs:
        join_path = Path(image_dir, dir_item)
if join_path.exists():
print("processing %s"%dir_item)
images = join_path.glob('*.%s'%(subffix))
for item in images:
# print(str(item.name)[0:6])
# temp = cv2.imread(str(item))
temp = Image.open(item)
# exclude the abnormal images
if temp.mode!="RGB":
print(temp.mode)
print("Found one abnormal image!")
print(item)
os.remove(str(item))
else:
print("%s dir does not exist!"%dir_item)
+253
View File
@@ -0,0 +1,253 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: data_loader_modify.py
# Created Date: Saturday April 4th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 4th July 2021 11:12:42 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import torch
import random
from PIL import Image
from pathlib import Path
from torch.utils import data
import torchvision.datasets as dsets
from torchvision import transforms as T
from data_tools.StyleResize import StyleResize
# from StyleResize import StyleResize
class data_prefetcher():
def __init__(self, loader):
self.loader = loader
self.dataiter = iter(loader)
self.stream = torch.cuda.Stream()
# self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
# self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.mean = self.mean.half()
# self.std = self.std.half()
self.preload()
def preload(self):
try:
self.content, self.style, self.label = next(self.dataiter)
except StopIteration:
self.dataiter = iter(self.loader)
self.content, self.style, self.label = next(self.dataiter)
with torch.cuda.stream(self.stream):
self.content= self.content.cuda(non_blocking=True)
self.style = self.style.cuda(non_blocking=True)
self.label = self.label.cuda(non_blocking=True)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.next_input = self.next_input.half()
# else:
# self.next_input = self.next_input.float()
# self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
content = self.content
style = self.style
label = self.label
self.preload()
return content, style, label
class TotalDataset(data.Dataset):
"""Dataset class for the Artworks dataset and content dataset."""
def __init__(self, content_image_dir,style_image_dir,
selectedContent,selectedStyle,
content_transform,style_transform,
subffix='jpg', random_seed=1234):
"""Initialize and preprocess the CelebA dataset."""
self.content_image_dir = content_image_dir
self.style_image_dir = style_image_dir
self.content_transform = content_transform
self.style_transform = style_transform
self.selectedContent = selectedContent
self.selectedStyle = selectedStyle
self.subffix = subffix
self.content_dataset = []
self.art_dataset = []
self.random_seed = random_seed
self.preprocess()
self.num_images = len(self.content_dataset)
self.art_num = len(self.art_dataset)
def preprocess(self):
"""Preprocess the Artworks dataset."""
print("processing content images...")
for dir_item in self.selectedContent:
join_path = Path(self.content_image_dir,dir_item)
if join_path.exists():
print("processing %s"%dir_item,end='\r')
images = join_path.glob('*.%s'%(self.subffix))
for item in images:
self.content_dataset.append(item)
else:
print("%s dir does not exist!"%dir_item,end='\r')
label_index = 0
print("processing style images...")
for class_item in self.selectedStyle:
images = Path(self.style_image_dir).glob('%s/*.%s'%(class_item, self.subffix))
for item in images:
self.art_dataset.append([item, label_index])
label_index += 1
random.seed(self.random_seed)
random.shuffle(self.content_dataset)
random.shuffle(self.art_dataset)
# self.dataset = images
print('Finished preprocessing the Art Works dataset, total image number: %d...'%len(self.art_dataset))
print('Finished preprocessing the Content dataset, total image number: %d...'%len(self.content_dataset))
def __getitem__(self, index):
"""Return one image and its corresponding attribute label."""
filename = self.content_dataset[index]
image = Image.open(filename)
content = self.content_transform(image)
art_index = random.randint(0,self.art_num-1)
filename,label = self.art_dataset[art_index]
image = Image.open(filename)
style = self.style_transform(image)
return content,style,label
def __len__(self):
"""Return the number of images."""
return self.num_images
def GetLoader( dataset_roots,
batch_size=16,
crop_size=512,
**kwargs
):
"""Build and return a data loader."""
    if not kwargs:
        raise ValueError("Input params error!")
colorJitterEnable = kwargs["color_jitter"]
colorConfig = kwargs["color_config"]
num_workers = kwargs["dataloader_workers"]
num_workers = kwargs["dataloader_workers"]
place365_root = dataset_roots["Place365_big"]
wikiart_root = dataset_roots["WikiArt"]
selected_c_dir = kwargs["selected_content_dir"]
selected_s_dir = kwargs["selected_style_dir"]
random_seed = kwargs["random_seed"]
s_transforms = []
c_transforms = []
s_transforms.append(StyleResize())
# s_transforms.append(T.Resize(900))
c_transforms.append(T.Resize(900))
s_transforms.append(T.RandomCrop(crop_size, pad_if_needed=True, padding_mode='reflect'))
c_transforms.append(T.RandomCrop(crop_size))
s_transforms.append(T.RandomHorizontalFlip())
c_transforms.append(T.RandomHorizontalFlip())
s_transforms.append(T.RandomVerticalFlip())
c_transforms.append(T.RandomVerticalFlip())
if colorJitterEnable:
if colorConfig is not None:
print("Enable color jitter!")
colorBrightness = colorConfig["brightness"]
colorContrast = colorConfig["contrast"]
colorSaturation = colorConfig["saturation"]
colorHue = (-colorConfig["hue"],colorConfig["hue"])
s_transforms.append(T.ColorJitter(brightness=colorBrightness,\
contrast=colorContrast,saturation=colorSaturation, hue=colorHue))
c_transforms.append(T.ColorJitter(brightness=colorBrightness,\
contrast=colorContrast,saturation=colorSaturation, hue=colorHue))
s_transforms.append(T.ToTensor())
c_transforms.append(T.ToTensor())
s_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
s_transforms = T.Compose(s_transforms)
c_transforms = T.Compose(c_transforms)
content_dataset = TotalDataset(place365_root,wikiart_root,
selected_c_dir, selected_s_dir,
c_transforms, s_transforms, "jpg", random_seed)
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)
prefetcher = data_prefetcher(content_data_loader)
return prefetcher
def denorm(x):
out = (x + 1) / 2
return out.clamp_(0, 1)
if __name__ == "__main__":
from torchvision.utils import save_image
style_class = ["vangogh","picasso","samuel"]
categories_names = \
['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
'b/building_facade', 'c/cemetery']
s_datapath = "D:\\F_Disk\\data_set\\Art_Data\\data_art_backup"
c_datapath = "D:\\Downloads\\data_large"
savepath = "D:\\PatchFace\\PleaseWork\\multi-style-gan\\StyleTransfer\\dataloader_test"
imsize = 512
    # GetLoader in this file takes a dataset_roots dict plus keyword options,
    # and data_prefetcher exposes next() rather than the iterator protocol.
    dataset_roots = {"Place365_big": c_datapath, "WikiArt": s_datapath}
    s_datasetloader = GetLoader(dataset_roots, batch_size=16, crop_size=imsize,
                        color_jitter=True,
                        color_config={"brightness":0.05,"contrast":0.05,"saturation":0.05,"hue":0.05},
                        dataloader_workers=4,
                        selected_content_dir=categories_names,
                        selected_style_dir=style_class,
                        random_seed=1234)
    for i in range(500):
        print("new batch")
        content, style, label = s_datasetloader.next()
        print(label)
+223
View File
@@ -0,0 +1,223 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: data_loader_modify.py
# Created Date: Saturday April 4th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 11th October 2021 12:17:58 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import torch
import random
from PIL import Image
from pathlib import Path
from torch.utils import data
import torchvision.datasets as dsets
from torchvision import transforms as T
from data_tools.StyleResize import StyleResize
# from StyleResize import StyleResize
class data_prefetcher():
def __init__(self, loader):
self.loader = loader
self.dataiter = iter(loader)
self.stream = torch.cuda.Stream()
# self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
# self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.mean = self.mean.half()
# self.std = self.std.half()
self.num_images = len(loader)
self.preload()
def preload(self):
try:
self.content = next(self.dataiter)
except StopIteration:
self.dataiter = iter(self.loader)
self.content = next(self.dataiter)
with torch.cuda.stream(self.stream):
self.content= self.content.cuda(non_blocking=True)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.next_input = self.next_input.half()
# else:
# self.next_input = self.next_input.float()
# self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
content = self.content
self.preload()
return content
def __len__(self):
"""Return the number of images."""
return self.num_images
class Place365Dataset(data.Dataset):
"""Dataset class for the Artworks dataset and content dataset."""
def __init__(self,
content_image_dir,
selectedContent,
content_transform,
subffix='jpg',
random_seed=1234):
"""Initialize and preprocess the CelebA dataset."""
self.content_image_dir = content_image_dir
self.content_transform = content_transform
self.selectedContent = selectedContent
self.subffix = subffix
self.content_dataset = []
self.random_seed = random_seed
self.preprocess()
self.num_images = len(self.content_dataset)
def preprocess(self):
"""Preprocess the Artworks dataset."""
print("processing content images...")
for dir_item in self.selectedContent:
join_path = Path(self.content_image_dir,dir_item)
if join_path.exists():
print("processing %s"%dir_item,end='\r')
images = join_path.glob('*.%s'%(self.subffix))
for item in images:
self.content_dataset.append(item)
else:
print("%s dir does not exist!"%dir_item,end='\r')
random.seed(self.random_seed)
random.shuffle(self.content_dataset)
print('Finished preprocessing the Content dataset, total image number: %d...'%len(self.content_dataset))
def __getitem__(self, index):
"""Return one image and its corresponding attribute label."""
filename = self.content_dataset[index]
image = Image.open(filename)
content = self.content_transform(image)
return content
def __len__(self):
"""Return the number of images."""
return self.num_images
def GetLoader( dataset_roots,
batch_size=16,
crop_size=512,
**kwargs
):
"""Build and return a data loader."""
    if not kwargs:
        raise ValueError("Input params error!")
colorJitterEnable = kwargs["color_jitter"]
colorConfig = kwargs["color_config"]
num_workers = kwargs["dataloader_workers"]
num_workers = kwargs["dataloader_workers"]
place365_root = dataset_roots["Place365_big"]
selected_c_dir = kwargs["selected_content_dir"]
random_seed = kwargs["random_seed"]
c_transforms = []
# s_transforms.append(T.Resize(900))
c_transforms.append(T.Resize(900))
c_transforms.append(T.RandomCrop(crop_size))
c_transforms.append(T.RandomHorizontalFlip())
c_transforms.append(T.RandomVerticalFlip())
if colorJitterEnable:
if colorConfig is not None:
print("Enable color jitter!")
colorBrightness = colorConfig["brightness"]
colorContrast = colorConfig["contrast"]
colorSaturation = colorConfig["saturation"]
colorHue = (-colorConfig["hue"],colorConfig["hue"])
c_transforms.append(T.ColorJitter(brightness=colorBrightness,\
contrast=colorContrast,saturation=colorSaturation, hue=colorHue))
c_transforms.append(T.ToTensor())
c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
c_transforms = T.Compose(c_transforms)
content_dataset = Place365Dataset(
place365_root,
selected_c_dir,
c_transforms,
"jpg",
random_seed)
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)
prefetcher = data_prefetcher(content_data_loader)
return prefetcher
def denorm(x):
out = (x + 1) / 2
return out.clamp_(0, 1)
if __name__ == "__main__":
from torchvision.utils import save_image
style_class = ["vangogh","picasso","samuel"]
categories_names = \
['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
'b/building_facade', 'c/cemetery']
s_datapath = "D:\\F_Disk\\data_set\\Art_Data\\data_art_backup"
c_datapath = "D:\\Downloads\\data_large"
savepath = "D:\\PatchFace\\PleaseWork\\multi-style-gan\\StyleTransfer\\dataloader_test"
imsize = 512
    # This loader is content-only: GetLoader here takes a dataset_roots dict
    # plus keyword options, and the prefetcher returns a single content batch.
    dataset_roots = {"Place365_big": c_datapath}
    s_datasetloader = GetLoader(dataset_roots, batch_size=16, crop_size=imsize,
                        color_jitter=True,
                        color_config={"brightness":0.05,"contrast":0.05,"saturation":0.05,"hue":0.05},
                        dataloader_workers=4,
                        selected_content_dir=categories_names,
                        random_seed=1234)
    for i in range(500):
        print("new batch")
        content = s_datasetloader.next()
        print(content.shape)
+81
View File
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: eval_dataloader_DIV2K.py
# Created Date: Tuesday January 12th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 8:29:51 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import glob
import torch
class TestDataset:
def __init__( self,
path,
batch_size = 16,
subffix=['png','jpg']):
"""Initialize and preprocess the setX dataset."""
self.path = path
self.subffix = subffix
self.dataset = []
self.pointer = 0
self.batch_size = batch_size
self.__preprocess__()
self.num_images = len(self.dataset)
def __preprocess__(self):
"""Preprocess the SetX dataset."""
print("processing content images...")
for i_suf in self.subffix:
temp_path = os.path.join(self.path,'*.%s'%(i_suf))
images = glob.glob(temp_path)
for item in images:
file_name = os.path.basename(item)
file_name = os.path.splitext(file_name)
file_name = file_name[0]
# lr_name = os.path.join(set5lr_path, file_name)
self.dataset.append([item,file_name])
# self.dataset = images
print('Finished preprocessing the content dataset, total image number: %d...'%len(self.dataset))
def __call__(self):
"""Return one batch images."""
        if self.pointer >= self.num_images:
            self.pointer = 0
            raise StopIteration("The end of the story!")
elif (self.pointer+self.batch_size) > self.num_images:
end = self.num_images
else:
end = self.pointer+self.batch_size
for i in range(self.pointer, end):
filename = self.dataset[i][0]
hr_img = cv2.imread(filename)
hr_img = cv2.cvtColor(hr_img,cv2.COLOR_BGR2RGB)
hr_img = hr_img.transpose((2,0,1))#.astype(np.float)
hr_img = torch.from_numpy(hr_img)
hr_img = hr_img/255.0
hr_img = 2 * (hr_img - 0.5)
if (i-self.pointer) == 0:
hr_ls = hr_img.unsqueeze(0)
nm_ls = [self.dataset[i][1],]
else:
hr_ls = torch.cat((hr_ls,hr_img.unsqueeze(0)),0)
nm_ls += [self.dataset[i][1],]
self.pointer = end
return hr_ls, nm_ls
def __len__(self):
return self.num_images
def __repr__(self):
return self.__class__.__name__ + ' (' + self.path + ')'
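# Hedged usage sketch: TestDataset is callable and returns one batch of
# normalized tensors plus the matching file names. Because batches are built
# with torch.cat, keep batch_size=1 when the test images differ in size.
# The directory path is a placeholder.
if __name__ == "__main__":
    dataset = TestDataset("./DIV2K_valid_HR", batch_size=1)
    hr_batch, names = dataset()
    print(hr_batch.shape, names)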
+248
View File
@@ -0,0 +1,248 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: PerceptualLoss.py
# Created Date: Wednesday January 13th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Saturday, 6th March 2021 4:42:26 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn as nn
from torch.nn import functional as F
from torchvision.models import vgg as vgg
from collections import OrderedDict
NAMES = {
'vgg11': [
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1',
'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1',
'conv5_2', 'relu5_2', 'pool5'
],
'vgg13': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
],
'vgg16': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1',
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4',
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
'pool5'
],
'vgg19': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4',
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4',
'pool5'
]
}
def insert_bn(names):
"""Insert bn layer after each conv.
Args:
names (list): The list of layer names.
Returns:
list: The list of layer names with bn layers.
"""
names_bn = []
for name in names:
names_bn.append(name)
if 'conv' in name:
position = name.replace('conv', '')
names_bn.append('bn' + position)
return names_bn
class VGGFeatureExtractor(nn.Module):
"""VGG network for feature extraction.
In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.
Args:
layer_name_list (list[str]): Forward function returns the corresponding
features according to the layer_name_list.
            Example: ['relu1_1', 'relu2_1', 'relu3_1'].
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image. Importantly,
the input feature must in the range [0, 1]. Default: True.
requires_grad (bool): If true, the parameters of VGG network will be
optimized. Default: False.
remove_pooling (bool): If true, the max pooling operations in VGG net
will be removed. Default: False.
pooling_stride (int): The stride of max pooling operation. Default: 2.
"""
def __init__(self,
layer_name_list,
vgg_type='vgg19',
use_input_norm=True,
requires_grad=False,
remove_pooling=False,
pooling_stride=2):
super(VGGFeatureExtractor, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
self.names = NAMES[vgg_type.replace('_bn', '')]
if 'bn' in vgg_type:
self.names = insert_bn(self.names)
# only borrow layers that will be used to avoid unused params
max_idx = 0
for v in layer_name_list:
idx = self.names.index(v)
if idx > max_idx:
max_idx = idx
features = getattr(vgg,
vgg_type)(pretrained=True).features[:max_idx + 1]
modified_net = OrderedDict()
for k, v in zip(self.names, features):
if 'pool' in k:
# if remove_pooling is true, pooling operation will be removed
if remove_pooling:
continue
else:
# in some cases, we may want to change the default stride
modified_net[k] = nn.MaxPool2d(
kernel_size=2, stride=pooling_stride)
else:
modified_net[k] = v
self.vgg_net = nn.Sequential(modified_net)
if not requires_grad:
self.vgg_net.eval()
for param in self.parameters():
param.requires_grad = False
else:
self.vgg_net.train()
for param in self.parameters():
param.requires_grad = True
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer(
'mean',
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
# the std is for image with range [0, 1]
self.register_buffer(
'std',
torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}
for key, layer in self.vgg_net._modules.items():
x = layer(x)
if key in self.layer_name_list:
output[key] = x.clone()
return output
class PerceptualLoss(nn.Module):
"""Perceptual loss with commonly used style loss.
Args:
layer_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'conv5_4': 1.}, which means the conv5_4
feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and multiplied by this weight.
            Default: 1.0.
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
"""
def __init__(self,
layer_weights,
vgg_type='vgg19',
use_input_norm=True,
perceptual_weight=1.0,
criterion='l1'):
super(PerceptualLoss, self).__init__()
self.perceptual_weight = perceptual_weight
self.layer_weights = layer_weights
self.vgg = VGGFeatureExtractor(
layer_name_list=list(layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm)
self.criterion_type = criterion
if self.criterion_type == 'l1':
self.criterion = torch.nn.L1Loss()
elif self.criterion_type == 'l2':
            self.criterion = torch.nn.MSELoss()
else:
raise NotImplementedError(
f'{criterion} criterion has not been supported.')
def forward(self, x, gt):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
# calculate perceptual loss
if self.perceptual_weight > 0:
percep_loss = 0
for k in x_features.keys():
percep_loss += self.criterion(
x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss *= self.perceptual_weight
else:
percep_loss = None
return percep_loss
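# Hedged usage sketch (the first call downloads the torchvision VGG19
# weights; inputs are expected in [0, 1] when use_input_norm is True):
if __name__ == "__main__":
    loss_fn = PerceptualLoss(layer_weights={'conv5_4': 1.0})
    x = torch.rand(1, 3, 64, 64)
    gt = torch.rand(1, 3, 64, 64)
    print(loss_fn(x, gt))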
+54
View File
@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: SliceWassersteinDistance.py
# Created Date: Tuesday October 12th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 3:11:23 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn
import torch.nn.functional as F
class SWD(nn.Module):
""" Slicing layer: computes projections and returns sorted vector """
def __init__(self, channel, direction_num=16):
super().__init__()
# Number of directions
self.direc_num = direction_num
self.channel = channel
        self.seed = nn.Parameter(torch.normal(mean=0.0, std=torch.ones(self.direc_num, self.channel)), requires_grad=False)
        # draw an initial set of directions so forward() works before update()
        self.update()
    def update(self):
        """Draw a fresh set of random projection directions."""
        self.seed.normal_()
        # unit-normalize each direction vector
        self.directions = F.normalize(self.seed)
def forward(self, input):
""" Implementation of figure 2 """
input = input.flatten(-2)
sliced = self.directions @ input
sliced, _ = sliced.sort()
return sliced
if __name__ == "__main__":
wocao = torch.ones((4,3,5,5))
slice = SWD(wocao.shape[1])
slice.update()
wocao_slice = slice(wocao)
print(wocao_slice.shape)
print(wocao_slice)
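    # Hedged follow-on sketch: SWD as a distance between two feature maps --
    # project both onto the same random directions, sort, and compare the
    # sorted projections with an L2 criterion.
    feats_a = torch.randn(4, 3, 8, 8)
    feats_b = torch.randn(4, 3, 8, 8)
    dist = F.mse_loss(slice(feats_a), slice(feats_b))
    print("sliced distance:", dist.item())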
+266
View File
@@ -0,0 +1,266 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: test.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 7:44:02 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import argparse
from torch.backends import cudnn
from utilities.json_config import readConfig
from utilities.reporter import Reporter
from utilities.sshupload import fileUploaderClass
def str2bool(v):
return v.lower() in ('true')
####################################################################################
# To configure the settings of training / finetune / test
#
####################################################################################
def getParameters():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('-v', '--version', type=str, default='fastnst_3',
help="version name for train, test, finetune")
    parser.add_argument('-c', '--cuda', type=int, default=-1) # if set to -1, the program will use the CPU
parser.add_argument('-e', '--checkpoint_epoch', type=int, default=19,
help="checkpoint epoch for test phase or finetune phase")
# test
parser.add_argument('-t', '--test_script_name', type=str, default='FastNST')
parser.add_argument('-b', '--batch_size', type=int, default=1)
parser.add_argument('-n', '--node_name', type=str, default='localhost',
choices=['localhost', '4card','8card','new4card'])
parser.add_argument('--save_test_result', action='store_false')
parser.add_argument('--test_dataloader', type=str, default='dir')
parser.add_argument('-p', '--test_data_path', type=str, default='G:\\UltraHighStyleTransfer\\benchmark')
parser.add_argument('--use_specified_data', action='store_true')
parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')
    # # logs (rarely needs to be changed)
# parser.add_argument('--dataloader_workers', type=int, default=6)
# parser.add_argument('--use_tensorboard', type=str2bool, default='True',
# choices=['True', 'False'], help='enable the tensorboard')
# parser.add_argument('--log_step', type=int, default=100)
# parser.add_argument('--sample_step', type=int, default=100)
    # # template (once editing is finished, it should be deleted)
# parser.add_argument('--str_parameter', type=str, default="default", help='str parameter')
# parser.add_argument('--str_parameter_choices', type=str,
# default="default", choices=['choice1', 'choice2','choice3'], help='str parameter with choices list')
# parser.add_argument('--int_parameter', type=int, default=0, help='int parameter')
# parser.add_argument('--float_parameter', type=float, default=0.0, help='float parameter')
# parser.add_argument('--bool_parameter', type=str2bool, default='True', choices=['True', 'False'], help='bool parameter')
# parser.add_argument('--list_str_parameter', type=str, nargs='+', default=["element1","element2"], help='str list parameter')
# parser.add_argument('--list_int_parameter', type=int, nargs='+', default=[0,1], help='int list parameter')
return parser.parse_args()
ignoreKey = [
"dataloader_workers",
"log_root_path",
"project_root",
"project_summary",
"project_checkpoints",
"project_samples",
"project_scripts",
"reporter_path",
"use_specified_data",
"specified_data_paths",
"dataset_path","cuda",
"test_script_name",
"test_dataloader",
"test_dataset_path",
"save_test_result",
"test_batch_size",
"node_name",
"checkpoint_epoch",
"test_dataset_path",
"test_dataset_name",
"use_my_test_date"]
####################################################################################
# This function will create the related directories before the
# training / finetune / test starts
# Your_log_root (version name)
# |---summary/...
# |---samples/... (save evaluated images)
# |---checkpoints/...
# |---scripts/...
#
####################################################################################
def createDirs(sys_state):
# the base dir
if not os.path.exists(sys_state["log_root_path"]):
os.makedirs(sys_state["log_root_path"])
# create dirs
sys_state["project_root"] = os.path.join(sys_state["log_root_path"],
sys_state["version"])
project_root = sys_state["project_root"]
if not os.path.exists(project_root):
os.makedirs(project_root)
sys_state["project_summary"] = os.path.join(project_root, "summary")
if not os.path.exists(sys_state["project_summary"]):
os.makedirs(sys_state["project_summary"])
sys_state["project_checkpoints"] = os.path.join(project_root, "checkpoints")
if not os.path.exists(sys_state["project_checkpoints"]):
os.makedirs(sys_state["project_checkpoints"])
sys_state["project_samples"] = os.path.join(project_root, "samples")
if not os.path.exists(sys_state["project_samples"]):
os.makedirs(sys_state["project_samples"])
sys_state["project_scripts"] = os.path.join(project_root, "scripts")
if not os.path.exists(sys_state["project_scripts"]):
os.makedirs(sys_state["project_scripts"])
sys_state["reporter_path"] = os.path.join(project_root,sys_state["version"]+"_report")
def main():
config = getParameters()
# speed up the program
cudnn.benchmark = True
sys_state = {}
# set the GPU number
if config.cuda >= 0:
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.cuda)
# read system environment paths
env_config = readConfig('env/env.json')
env_config = env_config["path"]
# obtain all configurations in argparse
config_dic = vars(config)
for config_key in config_dic.keys():
sys_state[config_key] = config_dic[config_key]
#=======================Test Phase=========================#
# TODO modify below lines to obtain the configuration
sys_state["log_root_path"] = env_config["train_log_root"]
sys_state["test_samples_path"] = os.path.join(env_config["test_log_root"],
sys_state["version"] , "samples")
# if not config.use_my_test_date:
# print("Use public benchmark...")
# data_key = config.test_dataset_name.lower()
# sys_state["test_dataset_path"] = env_config["test_dataset_paths"][data_key]
# if config.test_dataset_name.lower() == "set5" or config.test_dataset_name.lower() =="set14":
# sys_state["test_dataloader"] = "setx"
# else:
# sys_state["test_dataloader"] = config.test_dataset_name.lower()
# sys_state["test_dataset_name"] = config.test_dataset_name
if not os.path.exists(sys_state["test_samples_path"]):
os.makedirs(sys_state["test_samples_path"])
# Create dirs
createDirs(sys_state)
config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
# Read model_config.json from remote machine
if sys_state["node_name"]!="localhost":
remote_mac = env_config["remote_machine"]
nodeinf = remote_mac[sys_state["node_name"]]
print("ready to fetch related files from server: %s ......"%nodeinf["ip"])
uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])
remotebase = os.path.join(nodeinf['base_path'],"train_logs",sys_state["version"]).replace('\\','/')
# Get the config.json
print("ready to get the config.json...")
remoteFile = os.path.join(remotebase, env_config["config_json_name"]).replace('\\','/')
localFile = config_json
ssh_state = uploader.sshScpGet(remoteFile, localFile)
if not ssh_state:
raise Exception(print("Get file %s failed! config.json does not exist!"%remoteFile))
print("success get the config.json from server %s"%nodeinf['ip'])
# Read model_config.json
json_obj = readConfig(config_json)
for item in json_obj.items():
if item[0] in ignoreKey:
pass
else:
sys_state[item[0]] = item[1]
# Read scripts from remote machine
if sys_state["node_name"]!="localhost":
# # Get scripts
# remoteFile = os.path.join(remotebase, "scripts", sys_state["gScriptName"]+".py").replace('\\','/')
# localFile = os.path.join(sys_state["project_scripts"], sys_state["gScriptName"]+".py")
# ssh_state = uploader.sshScpGet(remoteFile, localFile)
# if not ssh_state:
# raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
# print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
# Get checkpoint of generator
localFile = os.path.join(sys_state["project_checkpoints"],
"epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"]))
if not os.path.exists(localFile):
remoteFile = os.path.join(remotebase, "checkpoints",
"epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"])).replace('\\','/')
ssh_state = uploader.sshScpGet(remoteFile, localFile, True)
if not ssh_state:
raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile))
print("Get the checkpoint %s successfully!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"])))
else:
print("%s exists!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"]["generator_name"])))
# TODO get the checkpoint file path
sys_state["ckp_name"] = {}
for data_key in sys_state["checkpoint_names"].keys():
sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"],
"%d_%s.pth"%(sys_state["checkpoint_epoch"],
sys_state["checkpoint_names"][data_key]))
# Get the test configurations
sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"]
# make a reporter
report_path = os.path.join(env_config["test_log_root"], sys_state["version"],
sys_state["version"]+"_report")
reporter = Reporter(report_path)
reporter.writeConfig(sys_state)
# Display the test information
# TODO modify below lines to display your configuration information
moduleName = "test_scripts.tester_" + sys_state["test_script_name"]
print("Start to run test script: {}".format(moduleName))
print("Test version: %s"%sys_state["version"])
print("Test Script Name: %s"%sys_state["test_script_name"])
package = __import__(moduleName, fromlist=True)
testerClass = getattr(package, 'Tester')
tester = testerClass(sys_state,reporter)
tester.test()
if __name__ == '__main__':
main()
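# Hedged sketch of the env/env.json layout this script expects; the field
# names are inferred from the reads above and the values are placeholders.
# {
#     "path": {
#         "train_log_root": "./train_logs",
#         "test_log_root": "./test_logs",
#         "config_json_name": "model_config.json",
#         "remote_machine": {
#             "4card": {"ip": "x.x.x.x", "user": "name",
#                       "passwd": "***", "base_path": "/path/to/project"}
#         }
#     }
# }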
+123
View File
@@ -0,0 +1,123 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 8:22:37 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import torch
from utilities.utilities import tensor2img
# from utilities.Reporter import Reporter
from tqdm import tqdm
class Tester(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
#============build evaluation dataloader==============#
print("Prepare the test dataloader...")
dlModulename = config["test_dataloader"]
package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'TestDataset')
dataloader = dataloaderClass(config["test_data_path"],
1,
["png","jpg"])
self.test_loader= dataloader
self.test_iter = len(dataloader)
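# The dataloader above is built with batch size 1 and png/jpg filters, so test_iter equals the number of test images.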
# if len(dataloader)%config["batch_size"]>0:
# self.test_iter+=1
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
script_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(script_name, fromlist=True)
network_class = getattr(package, class_name)
# TODO replace below lines to define the model framework
self.network = network_class(**model_config["g_model"]["module_params"])
# print and record the model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_epoch"],
self.config["checkpoint_names"]["generator_name"]))
self.network.load_state_dict(torch.load(model_path))
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
def test(self):
# save_result = self.config["saveTestResult"]
save_dir = self.config["test_samples_path"]
ckp_epoch = self.config["checkpoint_epoch"]
version = self.config["version"]
batch_size = self.config["batch_size"]
win_size = self.config["model_configs"]["g_model"]["module_params"]["window_size"]
# models
self.__init_framework__()
total = len(self.test_loader)
print("total:", total)
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print('Start =================================== test...')
start_time = time.time()
self.network.eval()
with torch.no_grad():
for _ in tqdm(range(total)):
contents, img_names = self.test_loader()
B, C, H, W = contents.shape
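# Center-crop to the largest square whose side is a multiple of 32, presumably so the network's
# repeated 2x resampling (and the window partitioning implied by window_size above) divides evenly.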
crop_h = H - H%32
crop_w = W - W%32
crop_s = min(crop_h, crop_w)
contents = contents[:,:,(H//2 - crop_s//2):(crop_s//2 + H//2),
(W//2 - crop_s//2):(crop_s//2 + W//2)]
if self.config["cuda"] >=0:
contents = contents.cuda()
res = self.network(contents, (crop_s, crop_s))
print("res shape:", res.shape)
res = tensor2img(res.cpu())
temp_img = res[0,:,:,:]
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
print(save_dir)
print(img_names[0])
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}.png'.format(
img_names[0], version, ckp_epoch)),temp_img)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))
+124
View File
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: tester_commonn.py
# Created Date: Saturday July 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 4th July 2021 11:32:14 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import time
import torch
from utilities.utilities import tensor2img
# from utilities.Reporter import Reporter
from tqdm import tqdm
class Tester(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
#============build evaluation dataloader==============#
print("Prepare the test dataloader...")
dlModulename = config["test_dataloader"]
package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'TestDataset')
dataloader = dataloaderClass(config["test_data_path"],
config["batch_size"],
["png","jpg"])
self.test_loader= dataloader
self.test_iter = len(dataloader)//config["batch_size"]
if len(dataloader)%config["batch_size"]>0:
self.test_iter+=1
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
script_name = "components."+self.config["module_script_name"]
class_name = self.config["class_name"]
package = __import__(script_name, fromlist=True)
network_class = getattr(package, class_name)
n_class = len(self.config["selectedStyleDir"])
# TODO replace below lines to define the model framework
self.network = network_class(self.config["GConvDim"],
self.config["GKS"],
self.config["resNum"],
n_class
#**self.config["module_params"]
)
# print and record the model structure
self.reporter.writeInfo("Model structure:")
self.reporter.writeModel(self.network.__str__())
# train in GPU
if self.config["cuda"] >=0:
self.network = self.network.cuda()
# loader1 = torch.load(self.config["ckp_name"]["generator_name"])
# print(loader1.key())
# pathwocao = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\repsr_pixel_0\\checkpoints\\epoch%d_RepSR_Plain.pth"%self.config["checkpoint_epoch"]
self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"])
print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))
def test(self):
# save_result = self.config["saveTestResult"]
save_dir = self.config["test_samples_path"]
ckp_epoch = self.config["checkpoint_epoch"]
version = self.config["version"]
batch_size = self.config["batch_size"]
style_names = self.config["selectedStyleDir"]
n_class = len(style_names)
# models
self.__init_framework__()
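# Build one condition label tensor per style class: shape (n_class, batch_size, 1), where slice i
# is filled with class index i and passed to the generator in the test loop below.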
condition_labels = torch.ones((n_class, batch_size, 1)).long()
for i in range(n_class):
condition_labels[i,:,:] = condition_labels[i,:,:]*i
if self.config["cuda"] >=0:
condition_labels = condition_labels.cuda()
total = len(self.test_loader)
# Start time
import datetime
print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print('Start =================================== test...')
start_time = time.time()
self.network.eval()
with torch.no_grad():
for _ in tqdm(range(total//batch_size)):
contents, img_names = self.test_loader()
for i in range(n_class):
if self.config["cuda"] >=0:
contents = contents.cuda()
res, _ = self.network(contents, condition_labels[i, 0, :])
res = tensor2img(res.cpu())
for t in range(batch_size):
temp_img = res[t,:,:,:]
temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format(
img_names[t], version, ckp_epoch, style_names[i])),temp_img)
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Elapsed [{}]".format(elapsed))
+240
View File
@@ -0,0 +1,240 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: train.py
# Created Date: Tuesday April 28th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 8:50:15 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import shutil
import argparse
from torch.backends import cudnn
from utilities.json_config import readConfig, writeConfig
from utilities.reporter import Reporter
from utilities.yaml_config import getConfigYaml
def str2bool(v):
return v.lower() in ('true')
####################################################################################
# To configure the settings of training/finetune/test
#
####################################################################################
def getParameters():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('-v', '--version', type=str, default='liff_warpinvo_0',
help="version name for train, test, finetune")
parser.add_argument('-p', '--phase', type=str, default="train",
choices=['train', 'finetune','debug'],
help="The phase of current project")
parser.add_argument('-c', '--cuda', type=int, default=1) # if set to a value < 0, the program will run on the CPU
parser.add_argument('-e', '--checkpoint_epoch', type=int, default=74,
help="checkpoint epoch for test phase or finetune phase")
# training
parser.add_argument('--experiment_description', type=str,
default="尝试使用Liif+Invo作为上采样和降采样的算子,降采样两个DSF算子,上采样两个DSF算子")
parser.add_argument('--train_yaml', type=str, default="train_FastNST_CNN_Resblock.yaml")
# # logs (usually does not need to be changed)
# parser.add_argument('--dataloader_workers', type=int, default=6)
# parser.add_argument('--use_tensorboard', type=str2bool, default='True',
# choices=['True', 'False'], help='enable the tensorboard')
# parser.add_argument('--log_step', type=int, default=100)
# parser.add_argument('--sample_step', type=int, default=100)
# # template (once editing is finished, it should be deleted)
# parser.add_argument('--str_parameter', type=str, default="default", help='str parameter')
# parser.add_argument('--str_parameter_choices', type=str,
# default="default", choices=['choice1', 'choice2','choice3'], help='str parameter with choices list')
# parser.add_argument('--int_parameter', type=int, default=0, help='int parameter')
# parser.add_argument('--float_parameter', type=float, default=0.0, help='float parameter')
# parser.add_argument('--bool_parameter', type=str2bool, default='True', choices=['True', 'False'], help='bool parameter')
# parser.add_argument('--list_str_parameter', type=str, nargs='+', default=["element1","element2"], help='str list parameter')
# parser.add_argument('--list_int_parameter', type=int, nargs='+', default=[0,1], help='int list parameter')
return parser.parse_args()
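# Example invocation (hypothetical version name; assumes env/env.json and the yaml under train_config_path exist):
#   python train.py -v my_version_0 -p train -c 0 --train_yaml train_FastNST_CNN_Resblock.yaml
#   python train.py -v my_version_0 -p finetune -c 0 -e 74
# The keys below are machine- or phase-specific (paths, CUDA id, test options); when an archived
# config.json is loaded back for finetune/test, these keys are skipped so the current run's values win.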
ignoreKey = [
"dataloader_workers",
"log_root_path",
"project_root",
"project_summary",
"project_checkpoints",
"project_samples",
"project_scripts",
"reporter_path",
"use_specified_data",
"specified_data_paths",
"dataset_path","cuda",
"test_script_name",
"test_dataloader",
"test_dataset_path",
"save_test_result",
"test_batch_size",
"node_name",
"checkpoint_epoch",
"test_dataset_path",
"test_dataset_name",
"use_my_test_date"]
####################################################################################
# This function will create the related directories before the
# training/finetune/test starts
# Your_log_root (version name)
# |---summary/...
# |---samples/... (save evaluated images)
# |---checkpoints/...
# |---scripts/...
#
####################################################################################
def createDirs(sys_state):
# the base dir
if not os.path.exists(sys_state["log_root_path"]):
os.makedirs(sys_state["log_root_path"])
# create dirs
sys_state["project_root"] = os.path.join(sys_state["log_root_path"],
sys_state["version"])
project_root = sys_state["project_root"]
if not os.path.exists(project_root):
os.makedirs(project_root)
sys_state["project_summary"] = os.path.join(project_root, "summary")
if not os.path.exists(sys_state["project_summary"]):
os.makedirs(sys_state["project_summary"])
sys_state["project_checkpoints"] = os.path.join(project_root, "checkpoints")
if not os.path.exists(sys_state["project_checkpoints"]):
os.makedirs(sys_state["project_checkpoints"])
sys_state["project_samples"] = os.path.join(project_root, "samples")
if not os.path.exists(sys_state["project_samples"]):
os.makedirs(sys_state["project_samples"])
sys_state["project_scripts"] = os.path.join(project_root, "scripts")
if not os.path.exists(sys_state["project_scripts"]):
os.makedirs(sys_state["project_scripts"])
sys_state["reporter_path"] = os.path.join(project_root,sys_state["version"]+"_report")
def main():
config = getParameters()
# speed up the program
cudnn.benchmark = True
from utilities.logo_class import logo_class
logo_class.print_group_logo()
sys_state = {}
# set the GPU number
if config.cuda >= 0:
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.cuda)
# read system environment paths
env_config = readConfig('env/env.json')
env_config = env_config["path"]
# obtain all configurations in argparse
config_dic = vars(config)
for config_key in config_dic.keys():
sys_state[config_key] = config_dic[config_key]
#=======================Train Phase=========================#
if config.phase == "train":
# read training configurations from yaml file
yaml_config = getConfigYaml(os.path.join(env_config["train_config_path"], config.train_yaml))
for item in yaml_config.items():
sys_state[item[0]] = item[1]
# create related dirs
sys_state["log_root_path"] = env_config["train_log_root"]
createDirs(sys_state)
# create reporter file
reporter = Reporter(sys_state["reporter_path"])
# save the config json
config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
writeConfig(config_json, sys_state)
# save the dependent scripts
# TODO: copy any other required scripts to the project dir
# save the trainer script into [train_logs_root]/[version name]/scripts/
file1 = os.path.join(env_config["train_scripts_path"],
"trainer_%s.py"%sys_state["train_script_name"])
tgtfile1 = os.path.join(sys_state["project_scripts"],
"trainer_%s.py"%sys_state["train_script_name"])
shutil.copyfile(file1,tgtfile1)
# save the yaml file
file1 = os.path.join(env_config["train_config_path"], config.train_yaml)
tgtfile1 = os.path.join(sys_state["project_scripts"], config.train_yaml)
shutil.copyfile(file1,tgtfile1)
# TODO: replace these lines to archive your own critical scripts
#=====================Finetune Phase=====================#
elif config.phase == "finetune":
sys_state["log_root_path"] = env_config["train_log_root"]
sys_state["project_root"] = os.path.join(sys_state["log_root_path"], sys_state["version"])
config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
train_config = readConfig(config_json)
for item in train_config.items():
if item[0] in ignoreKey:
pass
else:
sys_state[item[0]] = item[1]
createDirs(sys_state)
reporter = Reporter(sys_state["reporter_path"])
sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"]
# get the dataset path
sys_state["dataset_paths"] = {}
for data_key in env_config["dataset_paths"].keys():
sys_state["dataset_paths"][data_key] = env_config["dataset_paths"][data_key]
# display the training information
moduleName = "train_scripts.trainer_" + sys_state["train_script_name"]
if config.phase == "finetune":
moduleName = sys_state["com_base"] + "trainer_" + sys_state["train_script_name"]
# print some important information
# TODO
print("Start to run training script: {}".format(moduleName))
print("Traning version: %s"%sys_state["version"])
print("Dataloader Name: %s"%sys_state["dataloader"])
# print("Image Size: %d"%sys_state["imsize"])
print("Batch size: %d"%(sys_state["batch_size"]))
# Load the training script and start to train
reporter.writeConfig(sys_state)
package = __import__(moduleName, fromlist=True)
trainerClass= getattr(package, 'Trainer')
trainer = trainerClass(sys_state, reporter)
trainer.train()
if __name__ == '__main__':
main()
+307
View File
@@ -0,0 +1,307 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 2:18:26 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from components.Transform import Transform_block
from utilities.utilities import denorm, Gram, img2tensor255
from pretrained_weights.vgg import VGG16
class Trainer(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
# Data loader
#============build train dataloader==============#
# TODO to modify the key: "your_train_dataset" to get your train dataset path
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
#================================================#
print("Prepare the train dataloader...")
dlModulename = config["dataloader"]
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'GetLoader')
self.dataloader_class = dataloaderClass
dataloader = self.dataloader_class(self.train_dataset,
config["batch_size"],
config["imcrop_size"],
**config["dataset_params"])
self.train_loader= dataloader
#========build evaluation dataloader=============#
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
# #================================================#
# print("Prepare the evaluation dataloader...")
# dlModulename = config["eval_dataloader"]
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
# dataloaderClass = getattr(package, 'EvalDataset')
# dataloader = dataloaderClass(eval_dataset,
# config["eval_batch_size"])
# self.eval_loader= dataloader
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
# if len(dataloader)%config["eval_batch_size"]>0:
# self.eval_iter+=1
#==============build tensorboard=================#
if self.config["use_tensorboard"]:
from utilities.utilities import build_tensorboard
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
# TODO modify this function to build your models
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
if self.config["phase"] == "train":
gscript_name = "components." + model_config["g_model"]["script"]
# TODO To save the important scripts
# save the yaml file
import shutil
file1 = os.path.join("components", "%s.py"%model_config["g_model"]["script"])
tgtfile1 = os.path.join(self.config["project_scripts"], "%s.py"%model_config["g_model"]["script"])
shutil.copyfile(file1,tgtfile1)
elif self.config["phase"] == "finetune":
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.gen = gen_class(**model_config["g_model"]["module_params"])
# print and record the model structure
self.reporter.writeInfo("Generator structure:")
self.reporter.writeModel(self.gen.__str__())
# train in GPU
if self.config["cuda"] >=0:
self.gen = self.gen.cuda()
# if in finetune phase, load the pretrained checkpoint
if self.config["phase"] == "finetune":
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.gen.load_state_dict(torch.load(model_path))
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
# TODO modify this function to evaluate your model
def __evaluation__(self, epoch, step = 0):
# Evaluate the checkpoint
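# NOTE: this method uses self.network / self.eval_loader / self.eval_iter, which are only created
# by the evaluation-dataloader block commented out in __init__; as written it is not wired to self.gen.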
self.network.eval()
total_psnr = 0
total_num = 0
with torch.no_grad():
for _ in range(self.eval_iter):
hr, lr = self.eval_loader()
if self.config["cuda"] >=0:
hr = hr.cuda()
lr = lr.cuda()
hr = (hr + 1.0)/2.0 * 255.0
hr = torch.clamp(hr,0.0,255.0)
lr = (lr + 1.0)/2.0 * 255.0
lr = torch.clamp(lr,0.0,255.0)
res = self.network(lr)
# res = (res + 1.0)/2.0 * 255.0
# hr = (hr + 1.0)/2.0 * 255.0
res = torch.clamp(res,0.0,255.0)
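# diff below is the per-image RMSE over (C, H, W); PSNR = 20 * log10(255 / RMSE) for the 8-bit [0, 255] range.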
diff = (res-hr) ** 2
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
psnrs = 20. * (255. / diff).log10()
total_psnr+= psnrs.sum()
total_num+=res.shape[0]
final_psnr = total_psnr/total_num
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
epoch, final_psnr))
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configure the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
g_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supported yet.')
# self.optimizers.append(self.optimizer_g)
def train(self):
ckpt_dir = self.config["project_checkpoints"]
log_frep = self.config["log_step"]
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
batch_size = self.config["batch_size"]
style_img = self.config["style_img_path"]
# prep_weights= self.config["layersWeight"]
content_w = self.config["content_weight"]
style_w = self.config["style_weight"]
crop_size = self.config["imcrop_size"]
sample_dir = self.config["project_samples"]
#===============build framework================#
self.__init_framework__()
#===============build optimizer================#
# Optimizer
# TODO replace below lines to build your optimizer
print("build the optimizer...")
self.__setup_optimizers__()
#===============build losses===================#
# TODO replace below lines to build your losses
MSE_loss = torch.nn.MSELoss()
# set the start point for training loop
if self.config["phase"] == "finetune":
start = self.config["checkpoint_epoch"] - 1
else:
start = 0
# print("prepare the fixed labels...")
# fix_label = [i for i in range(n_class)]
# fix_label = torch.tensor(fix_label).long().cuda()
# fix_label = fix_label.view(n_class,1)
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
# Start time
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from utilities.logo_class import logo_class
logo_class.print_start_training()
start_time = time.time()
# Calculate the number of steps per epoch
step_epoch = len(self.train_loader)
step_epoch = step_epoch // batch_size
print("Total step = %d in each epoch"%step_epoch)
VGG = VGG16().cuda()
MEAN_VAL = 127.5
SCALE_VAL= 127.5
# Get Style Features
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
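# Assuming tensors live in [-1, 1] (consistent with denorm at sampling time): x*SCALE_VAL + MEAN_VAL
# maps to [0, 255], and imagenet_neg_mean_11 folds that +127.5 shift together with the Caffe-style
# BGR VGG mean subtraction into a single add.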
style_tensor = img2tensor255(style_img).cuda()
style_tensor = style_tensor.add(imagenet_neg_mean)
B, C, H, W = style_tensor.shape
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
# style_features = VGG(style_tensor)
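# The Gram matrix of a VGG feature map captures channel co-activation statistics (texture);
# matching Gram matrices between generated and style features is the classic Gatys-style style loss.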
style_gram = {}
for key, value in style_features.items():
style_gram[key] = Gram(value)
del style_tensor
# step_epoch = 2
for epoch in range(start, total_epoch):
for step in range(step_epoch):
self.gen.train()
content_images = self.train_loader.next()
fake_image = self.gen(content_images)
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
style_loss = 0.0
for key, value in generated_features.items():
s_loss = MSE_loss(Gram(value), style_gram[key])
style_loss += s_loss
# backward & optimize
g_loss = content_loss* content_w + style_loss* style_w
self.g_optimizer.zero_grad()
g_loss.backward()
self.g_optimizer.step()
# Print out log info
if (step + 1) % log_freq == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# cumulative steps
cum_step = (step_epoch * epoch + step + 1)
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
print(epochinformation)
self.reporter.writeInfo(epochinformation)
if self.config["use_tensorboard"]:
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
#===============adjust learning rate============#
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
# print("Learning rate decay")
# for p in self.optimizer.param_groups:
# p['lr'] *= self.config["lr_decay"]
# print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if (epoch+1) % model_freq==0:
print("Save epoch %d model checkpoint!"%(epoch+1))
torch.save(self.gen.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.cuda.empty_cache()
print('Sample images {}_fake.jpg'.format(epoch + 1))
self.gen.eval()
with torch.no_grad():
sample = fake_image
saved_image1 = denorm(sample.cpu().data)
save_image(saved_image1,
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
+297
View File
@@ -0,0 +1,297 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 7:38:36 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from components.Transform import Transform_block
from utilities.utilities import denorm, Gram, img2tensor255crop
from pretrained_weights.vgg import VGG16
class Trainer(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
# Data loader
#============build train dataloader==============#
# TODO to modify the key: "your_train_dataset" to get your train dataset path
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
#================================================#
print("Prepare the train dataloader...")
dlModulename = config["dataloader"]
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'GetLoader')
self.dataloader_class = dataloaderClass
dataloader = self.dataloader_class(self.train_dataset,
config["batch_size"],
config["imcrop_size"],
**config["dataset_params"])
self.train_loader= dataloader
#========build evaluation dataloader=============#
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
# #================================================#
# print("Prepare the evaluation dataloader...")
# dlModulename = config["eval_dataloader"]
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
# dataloaderClass = getattr(package, 'EvalDataset')
# dataloader = dataloaderClass(eval_dataset,
# config["eval_batch_size"])
# self.eval_loader= dataloader
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
# if len(dataloader)%config["eval_batch_size"]>0:
# self.eval_iter+=1
#==============build tensorboard=================#
if self.config["use_tensorboard"]:
from utilities.utilities import build_tensorboard
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
# TODO modify this function to build your models
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
if self.config["phase"] == "train":
gscript_name = "components." + model_config["g_model"]["script"]
elif self.config["phase"] == "finetune":
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.gen = gen_class(**model_config["g_model"]["module_params"])
# print and record the model structure
self.reporter.writeInfo("Generator structure:")
self.reporter.writeModel(self.gen.__str__())
# train in GPU
if self.config["cuda"] >=0:
self.gen = self.gen.cuda()
# if in finetune phase, load the pretrained checkpoint
if self.config["phase"] == "finetune":
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.gen.load_state_dict(torch.load(model_path))
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
# TODO modify this function to evaluate your model
def __evaluation__(self, epoch, step = 0):
# Evaluate the checkpoint
self.network.eval()
total_psnr = 0
total_num = 0
with torch.no_grad():
for _ in range(self.eval_iter):
hr, lr = self.eval_loader()
if self.config["cuda"] >=0:
hr = hr.cuda()
lr = lr.cuda()
hr = (hr + 1.0)/2.0 * 255.0
hr = torch.clamp(hr,0.0,255.0)
lr = (lr + 1.0)/2.0 * 255.0
lr = torch.clamp(lr,0.0,255.0)
res = self.network(lr)
# res = (res + 1.0)/2.0 * 255.0
# hr = (hr + 1.0)/2.0 * 255.0
res = torch.clamp(res,0.0,255.0)
diff = (res-hr) ** 2
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
psnrs = 20. * (255. / diff).log10()
total_psnr+= psnrs.sum()
total_num+=res.shape[0]
final_psnr = total_psnr/total_num
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
epoch, final_psnr))
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configure the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
g_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supported yet.')
# self.optimizers.append(self.optimizer_g)
def train(self):
ckpt_dir = self.config["project_checkpoints"]
log_frep = self.config["log_step"]
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
batch_size = self.config["batch_size"]
style_img = self.config["style_img_path"]
# prep_weights= self.config["layersWeight"]
content_w = self.config["content_weight"]
style_w = self.config["style_weight"]
crop_size = self.config["imcrop_size"]
sample_dir = self.config["project_samples"]
#===============build framework================#
self.__init_framework__()
#===============build optimizer================#
# Optimizer
# TODO replace below lines to build your optimizer
print("build the optimizer...")
self.__setup_optimizers__()
#===============build losses===================#
# TODO replace below lines to build your losses
MSE_loss = torch.nn.MSELoss()
# set the start point for training loop
if self.config["phase"] == "finetune":
start = self.config["checkpoint_epoch"] - 1
else:
start = 0
# print("prepare the fixed labels...")
# fix_label = [i for i in range(n_class)]
# fix_label = torch.tensor(fix_label).long().cuda()
# fix_label = fix_label.view(n_class,1)
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
# Start time
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from utilities.logo_class import logo_class
logo_class.print_start_training()
start_time = time.time()
# Calculate the number of steps per epoch
step_epoch = len(self.train_loader)
step_epoch = step_epoch // batch_size
print("Total step = %d in each epoch"%step_epoch)
VGG = VGG16().cuda()
MEAN_VAL = 127.5
SCALE_VAL= 127.5
# Get Style Features
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
style_tensor = style_tensor.add(imagenet_neg_mean)
B, C, H, W = style_tensor.shape
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
style_gram = {}
for key, value in style_features.items():
style_gram[key] = Gram(value)
# step_epoch = 2
for epoch in range(start, total_epoch):
for step in range(step_epoch):
self.gen.train()
content_images = self.train_loader.next()
fake_image = self.gen(content_images)
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
style_loss = 0.0
for key, value in generated_features.items():
s_loss = MSE_loss(Gram(value), style_gram[key])
style_loss += s_loss
# backward & optimize
g_loss = content_loss* content_w + style_loss* style_w
self.g_optimizer.zero_grad()
g_loss.backward()
self.g_optimizer.step()
# Print out log info
if (step + 1) % log_freq == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# cumulative steps
cum_step = (step_epoch * epoch + step + 1)
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
print(epochinformation)
self.reporter.writeInfo(epochinformation)
if self.config["use_tensorboard"]:
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
#===============adjust learning rate============#
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
# print("Learning rate decay")
# for p in self.optimizer.param_groups:
# p['lr'] *= self.config["lr_decay"]
# print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if (epoch+1) % model_freq==0:
print("Save epoch %d model checkpoint!"%(epoch+1))
torch.save(self.gen.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.cuda.empty_cache()
print('Sample images {}_fake.jpg'.format(epoch + 1))
self.gen.eval()
with torch.no_grad():
sample = fake_image
saved_image1 = denorm(sample.cpu().data)
save_image(saved_image1,
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
+296
View File
@@ -0,0 +1,296 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 9:25:13 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from utilities.utilities import denorm, Gram, img2tensor255crop
from pretrained_weights.vgg import VGG16
class Trainer(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
# Data loader
#============build train dataloader==============#
# TODO to modify the key: "your_train_dataset" to get your train dataset path
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
#================================================#
print("Prepare the train dataloader...")
dlModulename = config["dataloader"]
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'GetLoader')
self.dataloader_class = dataloaderClass
dataloader = self.dataloader_class(self.train_dataset,
config["batch_size"],
config["imcrop_size"],
**config["dataset_params"])
self.train_loader= dataloader
#========build evaluation dataloader=============#
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
# #================================================#
# print("Prepare the evaluation dataloader...")
# dlModulename = config["eval_dataloader"]
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
# dataloaderClass = getattr(package, 'EvalDataset')
# dataloader = dataloaderClass(eval_dataset,
# config["eval_batch_size"])
# self.eval_loader= dataloader
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
# if len(dataloader)%config["eval_batch_size"]>0:
# self.eval_iter+=1
#==============build tensorboard=================#
if self.config["use_tensorboard"]:
from utilities.utilities import build_tensorboard
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
# TODO modify this function to build your models
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
if self.config["phase"] == "train":
gscript_name = "components." + model_config["g_model"]["script"]
elif self.config["phase"] == "finetune":
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.gen = gen_class(**model_config["g_model"]["module_params"])
# print and record the model structure
self.reporter.writeInfo("Generator structure:")
self.reporter.writeModel(self.gen.__str__())
# train in GPU
if self.config["cuda"] >=0:
self.gen = self.gen.cuda()
# if in finetune phase, load the pretrained checkpoint
if self.config["phase"] == "finetune":
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.gen.load_state_dict(torch.load(model_path))
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
# TODO modify this function to evaluate your model
def __evaluation__(self, epoch, step = 0):
# Evaluate the checkpoint
self.network.eval()
total_psnr = 0
total_num = 0
with torch.no_grad():
for _ in range(self.eval_iter):
hr, lr = self.eval_loader()
if self.config["cuda"] >=0:
hr = hr.cuda()
lr = lr.cuda()
hr = (hr + 1.0)/2.0 * 255.0
hr = torch.clamp(hr,0.0,255.0)
lr = (lr + 1.0)/2.0 * 255.0
lr = torch.clamp(lr,0.0,255.0)
res = self.network(lr)
# res = (res + 1.0)/2.0 * 255.0
# hr = (hr + 1.0)/2.0 * 255.0
res = torch.clamp(res,0.0,255.0)
diff = (res-hr) ** 2
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
psnrs = 20. * (255. / diff).log10()
total_psnr+= psnrs.sum()
total_num+=res.shape[0]
final_psnr = total_psnr/total_num
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
epoch, final_psnr))
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configure the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
g_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supported yet.')
# self.optimizers.append(self.optimizer_g)
def train(self):
ckpt_dir = self.config["project_checkpoints"]
log_frep = self.config["log_step"]
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
batch_size = self.config["batch_size"]
style_img = self.config["style_img_path"]
# prep_weights= self.config["layersWeight"]
content_w = self.config["content_weight"]
style_w = self.config["style_weight"]
crop_size = self.config["imcrop_size"]
sample_dir = self.config["project_samples"]
#===============build framework================#
self.__init_framework__()
#===============build optimizer================#
# Optimizer
# TODO replace below lines to build your optimizer
print("build the optimizer...")
self.__setup_optimizers__()
#===============build losses===================#
# TODO replace below lines to build your losses
MSE_loss = torch.nn.MSELoss()
# set the start point for training loop
if self.config["phase"] == "finetune":
start = self.config["checkpoint_epoch"] - 1
else:
start = 0
# print("prepare the fixed labels...")
# fix_label = [i for i in range(n_class)]
# fix_label = torch.tensor(fix_label).long().cuda()
# fix_label = fix_label.view(n_class,1)
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
# Start time
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from utilities.logo_class import logo_class
logo_class.print_start_training()
start_time = time.time()
# Calculate the number of steps per epoch
step_epoch = len(self.train_loader)
step_epoch = step_epoch // batch_size
print("Total step = %d in each epoch"%step_epoch)
VGG = VGG16().cuda()
MEAN_VAL = 127.5
SCALE_VAL= 127.5
# Get Style Features
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
style_tensor = style_tensor.add(imagenet_neg_mean)
B, C, H, W = style_tensor.shape
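# expand() broadcasts the single style image across the batch as a view (no extra copy), so the
# style features match the batch dimension of the generated features in the MSE below.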
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
style_gram = {}
for key, value in style_features.items():
style_gram[key] = Gram(value)
# step_epoch = 2
for epoch in range(start, total_epoch):
for step in range(step_epoch):
self.gen.train()
content_images = self.train_loader.next()
fake_image = self.gen(content_images)
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
style_loss = 0.0
for key, value in generated_features.items():
s_loss = MSE_loss(Gram(value), style_gram[key])
style_loss += s_loss
# backward & optimize
g_loss = content_loss* content_w + style_loss* style_w
self.g_optimizer.zero_grad()
g_loss.backward()
self.g_optimizer.step()
# Print out log info
if (step + 1) % log_freq == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# cumulative steps
cum_step = (step_epoch * epoch + step + 1)
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
print(epochinformation)
self.reporter.writeInfo(epochinformation)
if self.config["use_tensorboard"]:
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
#===============adjust learning rate============#
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
# print("Learning rate decay")
# for p in self.optimizer.param_groups:
# p['lr'] *= self.config["lr_decay"]
# print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if (epoch+1) % model_freq==0:
print("Save epoch %d model checkpoint!"%(epoch+1))
torch.save(self.gen.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.cuda.empty_cache()
print('Sample images {}_fake.jpg'.format(epoch + 1))
self.gen.eval()
with torch.no_grad():
sample = fake_image
saved_image1 = denorm(sample.cpu().data)
save_image(saved_image1,
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
+300
View File
@@ -0,0 +1,300 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 2:28:24 am
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from utilities.utilities import denorm, img2tensor255crop
from losses.SliceWassersteinDistance import SWD
from pretrained_weights.vgg import VGG16
class Trainer(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
# Data loader
#============build train dataloader==============#
# TODO to modify the key: "your_train_dataset" to get your train dataset path
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
#================================================#
print("Prepare the train dataloader...")
dlModulename = config["dataloader"]
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'GetLoader')
self.dataloader_class = dataloaderClass
dataloader = self.dataloader_class(self.train_dataset,
config["batch_size"],
config["imcrop_size"],
**config["dataset_params"])
self.train_loader= dataloader
#========build evaluation dataloader=============#
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
# #================================================#
# print("Prepare the evaluation dataloader...")
# dlModulename = config["eval_dataloader"]
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
# dataloaderClass = getattr(package, 'EvalDataset')
# dataloader = dataloaderClass(eval_dataset,
# config["eval_batch_size"])
# self.eval_loader= dataloader
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
# if len(dataloader)%config["eval_batch_size"]>0:
# self.eval_iter+=1
#==============build tensorboard=================#
if self.config["use_tensorboard"]:
from utilities.utilities import build_tensorboard
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
# TODO modify this function to build your models
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
if self.config["phase"] == "train":
gscript_name = "components." + model_config["g_model"]["script"]
elif self.config["phase"] == "finetune":
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.gen = gen_class(**model_config["g_model"]["module_params"])
# print and record the model structure
self.reporter.writeInfo("Generator structure:")
self.reporter.writeModel(self.gen.__str__())
# train in GPU
if self.config["cuda"] >=0:
self.gen = self.gen.cuda()
# if in finetune phase, load the pretrained checkpoint
if self.config["phase"] == "finetune":
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_epoch"],
self.config["checkpoint_names"]["generator_name"]))
self.gen.load_state_dict(torch.load(model_path))
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
# TODO modify this function to evaluate your model
def __evaluation__(self, epoch, step = 0):
# Evaluate the checkpoint
self.network.eval()
total_psnr = 0
total_num = 0
with torch.no_grad():
for _ in range(self.eval_iter):
hr, lr = self.eval_loader()
if self.config["cuda"] >=0:
hr = hr.cuda()
lr = lr.cuda()
hr = (hr + 1.0)/2.0 * 255.0
hr = torch.clamp(hr,0.0,255.0)
lr = (lr + 1.0)/2.0 * 255.0
lr = torch.clamp(lr,0.0,255.0)
res = self.network(lr)
# res = (res + 1.0)/2.0 * 255.0
# hr = (hr + 1.0)/2.0 * 255.0
res = torch.clamp(res,0.0,255.0)
diff = (res-hr) ** 2
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
psnrs = 20. * (255. / diff).log10()
total_psnr+= psnrs.sum()
total_num+=res.shape[0]
final_psnr = total_psnr/total_num
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
epoch, final_psnr))
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configure the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
g_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supported yet.')
# self.optimizers.append(self.optimizer_g)
def train(self):
ckpt_dir = self.config["project_checkpoints"]
log_frep = self.config["log_step"]
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
batch_size = self.config["batch_size"]
style_img = self.config["style_img_path"]
# prep_weights= self.config["layersWeight"]
content_w = self.config["content_weight"]
style_w = self.config["style_weight"]
crop_size = self.config["imcrop_size"]
swd_dim = self.config["swd_dim"]
sample_dir = self.config["project_samples"]
#===============build framework================#
self.__init_framework__()
#===============build optimizer================#
# Optimizer
# TODO replace below lines to build your optimizer
print("build the optimizer...")
self.__setup_optimizers__()
#===============build losses===================#
# TODO replace below lines to build your losses
MSE_loss = torch.nn.MSELoss()
# set the start point for training loop
if self.config["phase"] == "finetune":
start = self.config["checkpoint_epoch"] - 1
else:
start = 0
# print("prepare the fixed labels...")
# fix_label = [i for i in range(n_class)]
# fix_label = torch.tensor(fix_label).long().cuda()
# fix_label = fix_label.view(n_class,1)
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
# Start time
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from utilities.logo_class import logo_class
logo_class.print_start_training()
start_time = time.time()
# Calculate the number of steps per epoch
step_epoch = len(self.train_loader)
step_epoch = step_epoch // batch_size
print("Total step = %d in each epoch"%step_epoch)
VGG = VGG16().cuda()
MEAN_VAL = 127.5
SCALE_VAL= 127.5
# Get Style Features
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
# swd = SWD()
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
style_tensor = style_tensor.add(imagenet_neg_mean)
B, C, H, W = style_tensor.shape
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
swd_list = {}
for key, value in style_features.items():
swd_list[key] = SWD(value.shape[1],swd_dim).cuda()
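# Assumption (implementation lives in losses/SliceWassersteinDistance): each SWD module presumably
# projects its C-channel features onto swd_dim random 1-D directions, and update() (called every
# step below) resamples those projections; the style loss is then the MSE between the projected
# generated and style features.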
# step_epoch = 2
for epoch in range(start, total_epoch):
for step in range(step_epoch):
self.gen.train()
content_images = self.train_loader.next()
fake_image = self.gen(content_images)
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
style_loss = 0.0
for key, value in generated_features.items():
swd_list[key].update()
s_loss = MSE_loss(swd_list[key](value), swd_list[key](style_features[key]))
style_loss += s_loss
# backward & optimize
g_loss = content_loss* content_w + style_loss* style_w
self.g_optimizer.zero_grad()
g_loss.backward()
self.g_optimizer.step()
# Print out log info
if (step + 1) % log_freq == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# cumulative steps
cum_step = (step_epoch * epoch + step + 1)
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
print(epochinformation)
self.reporter.writeInfo(epochinformation)
if self.config["use_tensorboard"]:
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)
#===============adjust learning rate============#
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
# print("Learning rate decay")
# for p in self.optimizer.param_groups:
# p['lr'] *= self.config["lr_decay"]
# print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if (epoch+1) % model_freq==0:
print("Save epoch %d model checkpoint!"%(epoch+1))
torch.save(self.gen.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.cuda.empty_cache()
print('Sample images {}_fake.jpg'.format(epoch + 1))
self.gen.eval()
with torch.no_grad():
sample = fake_image
saved_image1 = denorm(sample.cpu().data)
save_image(saved_image1,
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
+382
View File
@@ -0,0 +1,382 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_condition_SN_multiscale.py
# Created Date: Saturday April 18th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 6th July 2021 7:36:42 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from components.Transform import Transform_block
from utilities.utilities import denorm
class Trainer(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
# Data loader
#============build train dataloader==============#
# TODO to modify the key: "your_train_dataset" to get your train dataset path
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
#================================================#
print("Prepare the train dataloader...")
dlModulename = config["dataloader"]
package = __import__("data_tools.dataloader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'GetLoader')
self.dataloader_class = dataloaderClass
# dataloader = self.dataloader_class(self.train_dataset,
# config["batch_size_list"][0],
# config["imcrop_size_list"][0],
# **config["dataset_params"])
# self.train_loader= dataloader
#========build evaluation dataloader=============#
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
# #================================================#
# print("Prepare the evaluation dataloader...")
# dlModulename = config["eval_dataloader"]
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
# dataloaderClass = getattr(package, 'EvalDataset')
# dataloader = dataloaderClass(eval_dataset,
# config["eval_batch_size"])
# self.eval_loader= dataloader
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
# if len(dataloader)%config["eval_batch_size"]>0:
# self.eval_iter+=1
#==============build tensorboard=================#
if self.config["use_tensorboard"]:
from utilities.utilities import build_tensorboard
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
# TODO modify this function to build your models
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
if self.config["phase"] == "train":
gscript_name = "components." + model_config["g_model"]["script"]
dscript_name = "components." + model_config["d_model"]["script"]
elif self.config["phase"] == "finetune":
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
dscript_name = self.config["com_base"] + model_config["d_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.gen = gen_class(**model_config["g_model"]["module_params"])
class_name = model_config["d_model"]["class_name"]
package = __import__(dscript_name, fromlist=True)
dis_class = getattr(package, class_name)
self.dis = dis_class(**model_config["d_model"]["module_params"])
# print and record the model structures
self.reporter.writeInfo("Generator structure:")
self.reporter.writeModel(self.gen.__str__())
self.reporter.writeInfo("Discriminator structure:")
self.reporter.writeModel(self.dis.__str__())
# train in GPU
if self.config["cuda"] >=0:
self.gen = self.gen.cuda()
self.dis = self.dis.cuda()
# if in finetune phase, load the pretrained checkpoint
if self.config["phase"] == "finetune":
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.gen.load_state_dict(torch.load(model_path))
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["discriminator_name"]))
self.dis.load_state_dict(torch.load(model_path))
print('Loaded the pretrained model checkpoint from epoch {}...!'.format(self.config["checkpoint_step"]))
# TODO modify this function to evaluate your model
def __evaluation__(self, epoch, step = 0):
# Evaluate the checkpoint
self.network.eval()
total_psnr = 0
total_num = 0
with torch.no_grad():
for _ in range(self.eval_iter):
hr, lr = self.eval_loader()
if self.config["cuda"] >=0:
hr = hr.cuda()
lr = lr.cuda()
hr = (hr + 1.0)/2.0 * 255.0
hr = torch.clamp(hr,0.0,255.0)
lr = (lr + 1.0)/2.0 * 255.0
lr = torch.clamp(lr,0.0,255.0)
res = self.network(lr)
# res = (res + 1.0)/2.0 * 255.0
# hr = (hr + 1.0)/2.0 * 255.0
res = torch.clamp(res,0.0,255.0)
diff = (res-hr) ** 2
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
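# per-image RMSE over (C, H, W); PSNR = 20 * log10(255 / RMSE)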
psnrs = 20. * (255. / diff).log10()
total_psnr+= psnrs.sum()
total_num+=res.shape[0]
final_psnr = total_psnr/total_num
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
epoch, final_psnr))
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configure the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
d_train_opt = self.config['d_optim_config']
g_optim_params = []
d_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
for k, v in self.dis.named_parameters():
if v.requires_grad:
d_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supported yet.')
# self.optimizers.append(self.optimizer_g)
def train(self):
ckpt_dir = self.config["project_checkpoints"]
log_frep = self.config["log_step"]
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
n_class = len(self.config["selected_style_dir"])
# prep_weights= self.config["layersWeight"]
feature_w = self.config["feature_weight"]
transform_w = self.config["transform_weight"]
d_step = self.config["d_step"]
g_step = self.config["g_step"]
batch_size_list = self.config["batch_size_list"]
switch_epoch_list = self.config["switch_epoch_list"]
imcrop_size_list = self.config["imcrop_size_list"]
sample_dir = self.config["project_samples"]
current_epoch_index = 0
#===============build framework================#
self.__init_framework__()
#===============build optimizer================#
# Optimizer
# TODO replace below lines to build your optimizer
print("build the optimizer...")
self.__setup_optimizers__()
#===============build losses===================#
# TODO replace below lines to build your losses
Transform = Transform_block().cuda()
L1_loss = torch.nn.L1Loss()
MSE_loss = torch.nn.MSELoss()
Hinge_loss = torch.nn.ReLU().cuda()
# set the start point for training loop
if self.config["phase"] == "finetune":
start = self.config["checkpoint_epoch"] - 1
else:
start = 0
output_size = self.dis.get_outputs_len()
print("prepare the fixed labels...")
fix_label = [i for i in range(n_class)]
fix_label = torch.tensor(fix_label).long().cuda()
# fix_label = fix_label.view(n_class,1)
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
# Start time
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from utilities.logo_class import logo_class
logo_class.print_start_training()
start_time = time.time()
for epoch in range(start, total_epoch):
# switch training image size
if epoch in switch_epoch_list:
print('Current epoch: {}'.format(epoch))
print('***Redefining the dataloader for progressive training.***')
print('***Current spatial size is {} and batch size is {}.***'.format(
imcrop_size_list[current_epoch_index], batch_size_list[current_epoch_index]))
if hasattr(self, "train_loader"):
    del self.train_loader
self.train_loader = self.dataloader_class(self.train_dataset,
batch_size_list[current_epoch_index],
imcrop_size_list[current_epoch_index],
**self.config["dataset_params"])
# Calculate the number of steps per epoch
step_epoch = len(self.train_loader)
step_epoch = step_epoch // (d_step + g_step)
print("Total step = %d in each epoch"%step_epoch)
current_epoch_index += 1
for step in range(step_epoch):
self.dis.train()
self.gen.train()
# ================== Train D ================== #
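# Hinge GAN loss over the multi-scale discriminator outputs:
#   L_D = E[relu(1 - D(style))] + E[relu(1 + D(photo))] + E[relu(1 + D(G(photo)))]
# Hinge_loss is nn.ReLU, so Hinge_loss(1 - d) == max(0, 1 - d).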
# Compute loss with real images
for _ in range(d_step):
content_images,style_images,label = self.train_loader.next()
label = label.long()
d_out = self.dis(style_images,label)
d_loss_real = 0
for i in range(output_size):
temp = Hinge_loss(1 - d_out[i]).mean()
d_loss_real += temp
d_loss_photo = 0
d_out = self.dis(content_images,label)
for i in range(output_size):
temp = Hinge_loss(1 + d_out[i]).mean()
d_loss_photo += temp
fake_image,_= self.gen(content_images,label)
d_out = self.dis(fake_image.detach(),label)
d_loss_fake = 0
for i in range(output_size):
temp = Hinge_loss(1 + d_out[i]).mean()
# temp *= prep_weights[i]
d_loss_fake += temp
# Backward + Optimize
d_loss = d_loss_real + d_loss_photo + d_loss_fake
self.d_optimizer.zero_grad()
d_loss.backward()
self.d_optimizer.step()
# ================== Train G ================== #
for _ in range(g_step):
content_images,_,_ = self.train_loader.next()
fake_image,real_feature = self.gen(content_images,label)
fake_feature = self.gen(fake_image, get_feature=True)
d_out = self.dis(fake_image,label.long())
g_feature_loss = L1_loss(fake_feature,real_feature)
g_transform_loss = MSE_loss(Transform(content_images), Transform(fake_image))
g_loss_fake = 0
for i in range(output_size):
temp = -d_out[i].mean()
# temp *= prep_weights[i]
g_loss_fake += temp
# backward & optimize
g_loss = g_loss_fake + g_feature_loss* feature_w + g_transform_loss* transform_w
self.g_optimizer.zero_grad()
g_loss.backward()
self.g_optimizer.step()
# Print out log info
if (step + 1) % log_frep == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# cumulative steps
cum_step = (step_epoch * epoch + step + 1)
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, d_loss_real: {:.4f}, \\\
d_loss_photo: {:.4f}, d_loss_fake: {:.4f}, g_loss: {:.4f}, g_loss_fake: {:.4f}, \\\
g_feature_loss: {:.4f}, g_transform_loss: {:.4f}".format(self.config["version"],
epoch + 1, total_epoch, elapsed, step + 1, step_epoch,
d_loss.item(), d_loss_real.item(), d_loss_photo.item(),
d_loss_fake.item(), g_loss.item(), g_loss_fake.item(),\
g_feature_loss.item(), g_transform_loss.item())
print(epochinformation)
self.reporter.writeRawInfo(epochinformation)
if self.config["use_tensorboard"]:
self.tensorboard_writer.add_scalar('data/d_loss', d_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/d_loss_real', d_loss_real.item(), cum_step)
self.tensorboard_writer.add_scalar('data/d_loss_photo', d_loss_photo.item(), cum_step)
self.tensorboard_writer.add_scalar('data/d_loss_fake', d_loss_fake.item(), cum_step)
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/g_loss_fake', g_loss_fake.item(), cum_step)
self.tensorboard_writer.add_scalar('data/g_feature_loss', g_feature_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/g_transform_loss', g_transform_loss.item(), cum_step)
#===============adjust learning rate============#
if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
print("Learning rate decay")
for p in self.optimizer.param_groups:
p['lr'] *= self.config["lr_decay"]
print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if (epoch+1) % model_freq==0:
print("Save epoch %d model checkpoint!"%(epoch+1))
torch.save(self.gen.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.save(self.dis.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["discriminator_name"])))
torch.cuda.empty_cache()
print('Sample images {}_fake.jpg'.format(epoch + 1))
self.gen.eval()
with torch.no_grad():
sample = content_images[0, :, :, :].unsqueeze(0)
saved_image1 = denorm(sample.cpu().data)
for index in range(n_class):
fake_images,_ = self.gen(sample, fix_label[index].unsqueeze(0))
saved_image1 = torch.cat((saved_image1, denorm(fake_images.cpu().data)), 0)
save_image(saved_image1,
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)), nrow=3)
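The progressive-resolution loop above rebuilds the dataloader whenever the epoch hits switch_epoch_list, stepping through batch_size_list and imcrop_size_list in lock-step; a minimal sketch of that pairing, using the values from the GAN config later in this commit:

switch_epoch_list = [0, 5, 10]
batch_size_list = [8, 4, 2]
imcrop_size_list = [256, 512, 768]

current = 0
for epoch in range(12):
    if epoch in switch_epoch_list:
        bs, crop = batch_size_list[current], imcrop_size_list[current]
        print("epoch %d: rebuild loader, batch_size=%d, crop=%d" % (epoch, bs, crop))
        current += 1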
+295
View File
@@ -0,0 +1,295 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_naiv512.py
# Created Date: Sunday January 9th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 9th January 2022 12:31:03 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import os
import time
import torch
from torchvision.utils import save_image
from utilities.utilities import denorm, Gram, img2tensor255crop
from pretrained_weights.vgg import VGG16
class Trainer(object):
def __init__(self, config, reporter):
self.config = config
# logger
self.reporter = reporter
# Data loader
#============build train dataloader==============#
# TODO to modify the key: "your_train_dataset" to get your train dataset path
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
#================================================#
print("Prepare the train dataloader...")
dlModulename = config["dataloader"]
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
dataloaderClass = getattr(package, 'GetLoader')
self.dataloader_class = dataloaderClass
dataloader = self.dataloader_class(self.train_dataset,
config["batch_size"],
config["imcrop_size"],
**config["dataset_params"])
self.train_loader= dataloader
#========build evaluation dataloader=============#
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
# #================================================#
# print("Prepare the evaluation dataloader...")
# dlModulename = config["eval_dataloader"]
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
# dataloaderClass = getattr(package, 'EvalDataset')
# dataloader = dataloaderClass(eval_dataset,
# config["eval_batch_size"])
# self.eval_loader= dataloader
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
# if len(dataloader)%config["eval_batch_size"]>0:
# self.eval_iter+=1
#==============build tensorboard=================#
if self.config["use_tensorboard"]:
from utilities.utilities import build_tensorboard
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
# TODO modify this function to build your models
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
print("build models...")
# TODO [import models here]
model_config = self.config["model_configs"]
if self.config["phase"] == "train":
gscript_name = "components." + model_config["g_model"]["script"]
elif self.config["phase"] == "finetune":
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
class_name = model_config["g_model"]["class_name"]
package = __import__(gscript_name, fromlist=True)
gen_class = getattr(package, class_name)
self.gen = gen_class(**model_config["g_model"]["module_params"])
# print and record the model structure
self.reporter.writeInfo("Generator structure:")
self.reporter.writeModel(self.gen.__str__())
# train in GPU
if self.config["cuda"] >=0:
self.gen = self.gen.cuda()
# if in finetune phase, load the pretrained checkpoint
if self.config["phase"] == "finetune":
model_path = os.path.join(self.config["project_checkpoints"],
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
self.config["checkpoint_names"]["generator_name"]))
self.gen.load_state_dict(torch.load(model_path))
print('Loaded the pretrained model checkpoint from epoch {}...!'.format(self.config["checkpoint_step"]))
# TODO modify this function to evaluate your model
def __evaluation__(self, epoch, step = 0):
# Evaluate the checkpoint
self.network.eval()
total_psnr = 0
total_num = 0
with torch.no_grad():
for _ in range(self.eval_iter):
hr, lr = self.eval_loader()
if self.config["cuda"] >=0:
hr = hr.cuda()
lr = lr.cuda()
hr = (hr + 1.0)/2.0 * 255.0
hr = torch.clamp(hr,0.0,255.0)
lr = (lr + 1.0)/2.0 * 255.0
lr = torch.clamp(lr,0.0,255.0)
res = self.network(lr)
# res = (res + 1.0)/2.0 * 255.0
# hr = (hr + 1.0)/2.0 * 255.0
res = torch.clamp(res,0.0,255.0)
diff = (res-hr) ** 2
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
psnrs = 20. * (255. / diff).log10()
total_psnr+= psnrs.sum()
total_num+=res.shape[0]
final_psnr = total_psnr/total_num
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
epoch, final_psnr))
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
# TODO modify this function to configure the optimizer of your pipeline
def __setup_optimizers__(self):
g_train_opt = self.config['g_optim_config']
g_optim_params = []
for k, v in self.gen.named_parameters():
if v.requires_grad:
g_optim_params.append(v)
else:
self.reporter.writeInfo(f'Params {k} will not be optimized.')
print(f'Params {k} will not be optimized.')
optim_type = self.config['optim_type']
if optim_type == 'Adam':
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
else:
raise NotImplementedError(
f'optimizer {optim_type} is not supported yet.')
# self.optimizers.append(self.optimizer_g)
def train(self):
ckpt_dir = self.config["project_checkpoints"]
log_frep = self.config["log_step"]
model_freq = self.config["model_save_epoch"]
total_epoch = self.config["total_epoch"]
batch_size = self.config["batch_size"]
style_img = self.config["style_img_path"]
# prep_weights= self.config["layersWeight"]
content_w = self.config["content_weight"]
style_w = self.config["style_weight"]
crop_size = self.config["imcrop_size"]
sample_dir = self.config["project_samples"]
#===============build framework================#
self.__init_framework__()
#===============build optimizer================#
# Optimizer
# TODO replace below lines to build your optimizer
print("build the optimizer...")
self.__setup_optimizers__()
#===============build losses===================#
# TODO replace below lines to build your losses
MSE_loss = torch.nn.MSELoss()
# set the start point for training loop
if self.config["phase"] == "finetune":
start = self.config["checkpoint_epoch"] - 1
else:
start = 0
# print("prepare the fixed labels...")
# fix_label = [i for i in range(n_class)]
# fix_label = torch.tensor(fix_label).long().cuda()
# fix_label = fix_label.view(n_class,1)
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
# Start time
import datetime
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
from utilities.logo_class import logo_class
logo_class.print_start_training()
start_time = time.time()
# Calculate the number of steps per epoch
step_epoch = len(self.train_loader)
step_epoch = step_epoch // batch_size
print("Total step = %d in each epoch"%step_epoch)
VGG = VGG16().cuda()
MEAN_VAL = 127.5
SCALE_VAL= 127.5
# Get Style Features
imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()
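# The generator outputs images in [-1, 1]; multiplying by SCALE_VAL (127.5) and
# adding (127.5 - mean) maps them to [0, 255] with the Caffe-style (BGR-ordered)
# ImageNet channel means subtracted, presumably the input range this VGG16 expects.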
style_tensor = img2tensor255crop(style_img,crop_size).cuda()
style_tensor = style_tensor.add(imagenet_neg_mean)
B, C, H, W = style_tensor.shape
style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
style_gram = {}
for key, value in style_features.items():
style_gram[key] = Gram(value)
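# Style targets are the per-layer Gram matrices of the style image features
# (Gatys-style second-order statistics); Gram comes from utilities.utilities,
# see the sketch after this file for the standard formulation.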
# step_epoch = 2
for epoch in range(start, total_epoch):
for step in range(step_epoch):
self.gen.train()
content_images = self.train_loader.next()
fake_image = self.gen(content_images)
generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])
style_loss = 0.0
for key, value in generated_features.items():
s_loss = MSE_loss(Gram(value), style_gram[key])
style_loss += s_loss
# backward & optimize
g_loss = content_loss* content_w + style_loss* style_w
self.g_optimizer.zero_grad()
g_loss.backward()
self.g_optimizer.step()
# Print out log info
if (step + 1) % log_frep == 0:
elapsed = time.time() - start_time
elapsed = str(datetime.timedelta(seconds=elapsed))
# cumulative steps
cum_step = (step_epoch * epoch + step + 1)
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
print(epochinformation)
self.reporter.writeInfo(epochinformation)
if self.config["use_tensorboard"]:
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
self.tensorboard_writer.add_scalar('data/style_loss', style_loss.item(), cum_step)
#===============adjust learning rate============#
# if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
# print("Learning rate decay")
# for p in self.optimizer.param_groups:
# p['lr'] *= self.config["lr_decay"]
# print("Current learning rate is %f"%p['lr'])
#===============save checkpoints================#
if (epoch+1) % model_freq==0:
print("Save epoch %d model checkpoint!"%(epoch+1))
torch.save(self.gen.state_dict(),
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
self.config["checkpoint_names"]["generator_name"])))
torch.cuda.empty_cache()
print('Sample images {}_fake.jpg'.format(epoch + 1))
self.gen.eval()
with torch.no_grad():
sample = fake_image
saved_image1 = denorm(sample.cpu().data)
save_image(saved_image1,
os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
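Gram is imported from utilities.utilities and not shown in this commit; a minimal sketch of the standard normalized Gram matrix it presumably computes:

import torch

def gram(feat):
    # feat: (B, C, H, W) -> (B, C, C) channel-correlation matrix
    b, c, h, w = feat.shape
    f = feat.reshape(b, c, h * w)
    return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)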
+83
View File
@@ -0,0 +1,83 @@
# Related scripts
train_script_name: FastNST
# models' scripts
model_configs:
g_model:
script: FastNST
class_name: Generator
module_params:
g_conv_dim: 32
g_kernel_size: 3
res_num: 6
n_class: 11
image_size: 256
window_size: 8
# Training information
total_epoch: 120
batch_size: 16
imcrop_size: 256
max2Keep: 10
# Dataset
style_img_path: "G:\\UltraHighStyleTransfer\\reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\images\\mosaic.jpg"
dataloader: place365
dataset_name: styletransfer
dataset_params:
random_seed: 1234
dataloader_workers: 8
color_jitter: Enable
color_config:
brightness: 0.05
contrast: 0.05
saturation: 0.05
hue: 0.05
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
'b/building_facade', 'c/cemetery']
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Optimizer
optim_type: Adam
g_optim_config:
lr: !!float 2e-4
betas: [0.9, 0.99]
content_weight: 2.0
style_weight: 1.0
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
# Log
log_step: 100
model_save_epoch: 1
use_tensorboard: True
checkpoint_names:
generator_name: Generator
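A config like the one above is consumed by the trainer through self.config lookups; a minimal sketch of loading it, assuming PyYAML and a hypothetical file name (the repo's actual entry point is not shown in this commit):

import yaml  # assumption: the repo loads these configs with PyYAML

with open("train_fastnst.yaml") as f:  # hypothetical file name
    config = yaml.safe_load(f)

g_cfg = config["model_configs"]["g_model"]
print(g_cfg["script"], g_cfg["module_params"]["g_conv_dim"])  # FastNST 32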
+108
View File
@@ -0,0 +1,108 @@
# Related scripts
train_script_name: FastNST_CNN
# models' scripts
model_configs:
g_model:
script: FastNST_CNN
class_name: Generator
module_params:
g_conv_dim: 32
g_kernel_size: 3
res_num: 6
n_class: 11
image_size: 256
window_size: 8
# Training information
total_epoch: 120
batch_size: 16
imcrop_size: 256
max2Keep: 10
# Dataset
style_img_path: "images/mosaic.jpg"
dataloader: place365
dataset_name: styletransfer
dataset_params:
random_seed: 1234
dataloader_workers: 8
color_jitter: Enable
color_config:
brightness: 0.05
contrast: 0.05
saturation: 0.05
hue: 0.05
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
'b/building_facade', 'c/cemetery']
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
# 'b_building_facade',
# 'c_cemetery'
# ]
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Optimizer
optim_type: Adam
g_optim_config:
lr: !!float 2e-4
betas: [0.9, 0.99]
content_weight: 2.0
style_weight: 1.0
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
# Log
log_step: 100
model_save_epoch: 1
use_tensorboard: True
checkpoint_names:
generator_name: Generator
+108
View File
@@ -0,0 +1,108 @@
# Related scripts
train_script_name: FastNST_CNN
# models' scripts
model_configs:
g_model:
script: FastNST_CNN_Resblock
class_name: Generator
module_params:
g_conv_dim: 32
g_kernel_size: 3
res_num: 6
n_class: 11
image_size: 256
window_size: 8
# Training information
total_epoch: 120
batch_size: 16
imcrop_size: 256
max2Keep: 10
# Dataset
style_img_path: "images/mosaic.jpg"
dataloader: place365
dataset_name: styletransfer
dataset_params:
random_seed: 1234
dataloader_workers: 8
color_jitter: Enable
color_config:
brightness: 0.05
contrast: 0.05
saturation: 0.05
hue: 0.05
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
'b/building_facade', 'c/cemetery']
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
# 'b_building_facade',
# 'c_cemetery'
# ]
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Optimizer
optim_type: Adam
g_optim_config:
lr: !!float 2e-4
betas: [0.9, 0.99]
content_weight: 2.0
style_weight: 1.0
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
# Log
log_step: 100
model_save_epoch: 1
use_tensorboard: True
checkpoint_names:
generator_name: Generator
+110
View File
@@ -0,0 +1,110 @@
# Related scripts
train_script_name: FastNST_Liif
# models' scripts
model_configs:
g_model:
script: FastNST_Liif
class_name: Generator
module_params:
g_conv_dim: 32
g_kernel_size: 3
res_num: 6
n_class: 11
image_size: 256
mlp_hidden_list: [32,32]
batch_size: 10
window_size: 8
# Training information
total_epoch: 120
batch_size: 10
imcrop_size: 256
max2Keep: 10
# Dataset
style_img_path: "images/mosaic.jpg"
dataloader: place365
dataset_name: styletransfer
dataset_params:
random_seed: 1234
dataloader_workers: 8
color_jitter: Enable
color_config:
brightness: 0.05
contrast: 0.05
saturation: 0.05
hue: 0.05
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
'b/building_facade', 'c/cemetery']
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
# 'b_building_facade',
# 'c_cemetery'
# ]
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Optimizer
optim_type: Adam
g_optim_config:
lr: !!float 2e-4
betas: [0.9, 0.99]
content_weight: 2.0
style_weight: 1.0
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
# Log
log_step: 100
model_save_epoch: 1
use_tensorboard: True
checkpoint_names:
generator_name: Generator
+109
View File
@@ -0,0 +1,109 @@
# Related scripts
train_script_name: FastNST_Liif
# models' scripts
model_configs:
g_model:
script: FastNST_Liif_warp
class_name: Generator
module_params:
g_conv_dim: 32
g_kernel_size: 3
res_num: 6
n_class: 11
image_size: 256
batch_size: 16
window_size: 8
# Training information
total_epoch: 120
batch_size: 16
imcrop_size: 256
max2Keep: 10
# Dataset
style_img_path: "images/mosaic.jpg"
dataloader: place365
dataset_name: styletransfer
dataset_params:
random_seed: 1234
dataloader_workers: 8
color_jitter: Enable
color_config:
brightness: 0.05
contrast: 0.05
saturation: 0.05
hue: 0.05
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
'b/building_facade', 'c/cemetery']
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
# 'b_building_facade',
# 'c_cemetery'
# ]
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Optimizer
optim_type: Adam
g_optim_config:
lr: !!float 2e-4
betas: [0.9, 0.99]
content_weight: 2.0
style_weight: 1.0
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
# Log
log_step: 100
model_save_epoch: 1
use_tensorboard: True
checkpoint_names:
generator_name: Generator
+109
View File
@@ -0,0 +1,109 @@
# Related scripts
train_script_name: FastNST_SWD
# models' scripts
model_configs:
g_model:
script: FastNST_CNN_Resblock
class_name: Generator
module_params:
g_conv_dim: 32
g_kernel_size: 3
res_num: 6
n_class: 11
image_size: 256
window_size: 8
# Training information
total_epoch: 120
batch_size: 16
imcrop_size: 256
max2Keep: 10
# Dataset
style_img_path: "images/mosaic.jpg"
dataloader: place365
dataset_name: styletransfer
dataset_params:
random_seed: 1234
dataloader_workers: 8
color_jitter: Enable
color_config:
brightness: 0.05
contrast: 0.05
saturation: 0.05
hue: 0.05
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
'b/building_facade', 'c/cemetery']
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
# 'b_building_facade',
# 'c_cemetery'
# ]
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Optimizer
optim_type: Adam
g_optim_config:
lr: !!float 2e-4
betas: [0.9, 0.99]
content_weight: 1.0
style_weight: 10.0
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
swd_dim: 32
# Log
log_step: 100
model_save_epoch: 1
use_tensorboard: True
checkpoint_names:
generator_name: Generator
+98
View File
@@ -0,0 +1,98 @@
# Related scripts
train_script_name: gan
# models' scripts
model_configs:
g_model:
script: Conditional_Generator_Noskip
class_name: Generator
module_params:
g_conv_dim: 32
g_kernel_size: 3
res_num: 8
n_class: 11
d_model:
script: Conditional_Discriminator_Projection_big
class_name: Discriminator
module_params:
d_conv_dim: 32
d_kernel_size: 5
# Training information
total_epoch: 120
batch_size_list: [8, 4, 2]
switch_epoch_list: [0, 5, 10]
imcrop_size_list: [256, 512, 768]
max2Keep: 10
movingAverage: 0.05
d_success_threshold: 0.8
d_step: 3
g_step: 1
# Dataset
dataloader: condition
dataset_name: styletransfer
dataset_params:
random_seed: 1234
dataloader_workers: 8
color_jitter: Enable
color_config:
brightness: 0.05
contrast: 0.05
saturation: 0.05
hue: 0.05
selected_style_dir: ['berthe-morisot','edvard-munch',
'ernst-ludwig-kirchner','jackson-pollock','kandinsky','monet',
'nicholas','paul-cezanne','picasso','samuel','vangogh']
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
'b/building_facade', 'c/cemetery']
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2
# Optimizer
optim_type: Adam
g_optim_config:
lr: !!float 2e-4
betas: [0.9, 0.99]
d_optim_config:
lr: !!float 2e-4
betas: [0.9, 0.99]
feature_weight: 50.0
transform_weight: 50.0
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
# Log
log_step: 1000
sampleStep: 2000
model_save_epoch: 1
useTensorboard: True
checkpoint_names:
generator_name: Generator
discriminator_name: Discriminator
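
The progressive-training fields above pair a batch size with a crop size per stage: stage i begins at switch_epoch_list[i]. A small sketch of how a training loop might resolve the active stage for a given epoch (the helper name is hypothetical; the actual training script is not shown here):

def current_stage(epoch, switch_epochs=(0, 5, 10)):
    """Return the index of the last stage whose switch epoch has been reached."""
    stage = 0
    for i, start_epoch in enumerate(switch_epochs):
        if epoch >= start_epoch:
            stage = i
    return stage

batch_size_list = [8, 4, 2]
imcrop_size_list = [256, 512, 768]
stage = current_stage(7)                 # epochs 5-9 fall in stage 1
batch_size = batch_size_list[stage]      # -> 4
crop_size = imcrop_size_list[stage]      # -> 512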
+100
View File
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: checkpoint_manager.py
# Created Date: Sunday July 12th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 27th July 2020 11:01:16 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import os
import torch
class CheckpointManager(object):
    def __init__(self):
        self.maxCkpNum = -1
        self.ckpList = []
        self.modelsDict = {} # key: model name, value: model
        self.currentEpoch = 0
    # The methods below are pass-only stubs; the module-level helper
    # functions further down provide the actual save/load functionality.
    def registerModels(self):
        pass
    def __updateCkpList__(self):
        pass
    def saveModel(self):
        pass
    def loadModel(self):
        pass
    def saveLR(self):
        pass
    def loadLR(self):
        pass
def loadPretrainedModel(checkpointStep,modelSavePath,gModel,dModel,cuda,**kwargs):
    gModel.load_state_dict(torch.load(os.path.join(
        modelSavePath, 'Epoch{}_LocalG.pth'.format(checkpointStep)),map_location=cuda))
    dModel.load_state_dict(torch.load(os.path.join(
        modelSavePath, 'Epoch{}_GlobalD.pth'.format(checkpointStep)),map_location=cuda))
    print('loaded trained models (epoch: {}) successfully!'.format(checkpointStep))
    if not kwargs:
        return
    for k,v in kwargs.items():
        v.load_state_dict(torch.load(os.path.join(
            modelSavePath, 'Epoch{}_{}.pth'.format(checkpointStep,k)),map_location=cuda))
        print("Loaded param %s"%k)
def loadPretrainedModelByDict(checkpointStep,modelSavePath,cuda,**kwargs):
    if not kwargs:
        return
    for k,v in kwargs.items():
        v.load_state_dict(torch.load(os.path.join(
            modelSavePath, 'Epoch{}_{}.pth'.format(checkpointStep,k)),map_location=cuda))
        print("Loaded param %s"%k)
def loadLR(checkpointStep,modelSavePath,dlr,glr):
    glr.load_state_dict(torch.load(os.path.join(
        modelSavePath, 'Epoch{}_LocalGlr.pth'.format(checkpointStep))))
    dlr.load_state_dict(torch.load(os.path.join(
        modelSavePath, 'Epoch{}_GlobalDlr.pth'.format(checkpointStep))))
    print("Generator learning rate:%f"%glr.get_lr()[0])
    print("Discriminator learning rate:%f"%dlr.get_lr()[0])
def saveLR(step,modelSavePath,dlr,glr):
torch.save(glr.state_dict(),os.path.join(modelSavePath, 'Epoch{}_LocalGlr.pth'.format(step + 1)))
torch.save(dlr.state_dict(),os.path.join(modelSavePath, 'Epoch{}_GlobalDlr.pth'.format(step + 1)))
print("Epoch:{} models learning rate saved!".format(step+1))
def saveModel(step,modelSavePath,gModel,dModel,**kwargs):
torch.save(gModel.state_dict(),
os.path.join(modelSavePath, 'Epoch{}_LocalG.pth'.format(step + 1)))
torch.save(dModel.state_dict(),
os.path.join(modelSavePath, 'Epoch{}_GlobalD.pth'.format(step + 1)))
print("Epoch:{} models saved!".format(step+1))
if not kwargs:
return
for k,v in kwargs.items():
torch.save(v.state_dict(),
os.path.join(modelSavePath, 'Epoch{}_{}.pth'.format(step + 1,k)))
print("Epoch:{} models param {} saved!".format(step+1,k))
def saveModelByDict(step,modelSavePath,**kwargs):
if not kwargs:
return
for k,v in kwargs.items():
torch.save(v.state_dict(),
os.path.join(modelSavePath, 'Epoch{}_{}.pth'.format(step + 1,k)))
print("Epoch:{} models param {} saved!".format(step+1,k))
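
A minimal usage sketch of the helpers above. The two stand-in modules and the save path are assumptions; note the off-by-one convention: saveModel writes Epoch{step+1}_*.pth, while the load helpers take the epoch number directly.

import os
import torch

generator = torch.nn.Conv2d(3, 3, 3, padding=1)       # stand-in modules
discriminator = torch.nn.Conv2d(3, 1, 3, padding=1)
save_path = "./train_logs/demo/checkpoints"
os.makedirs(save_path, exist_ok=True)

saveModel(9, save_path, generator, discriminator)      # writes Epoch10_LocalG.pth / Epoch10_GlobalD.pth
loadPretrainedModel(10, save_path, generator, discriminator,
                    cuda="cpu")                        # map_location for torch.load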
+22
View File
@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: figure.py
# Created Date: Tuesday October 13th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 13th October 2020 2:54:30 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import matplotlib.pyplot as plt
def plot_loss_curve(losses, save_path):
    for key in losses.keys():
        plt.plot(range(len(losses[key])), losses[key], label=key)
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.title('loss curve')
    plt.legend()
    plt.savefig(save_path)
    plt.clf()
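
Usage is a single call with a dict of loss histories; the values below are made up for illustration:

losses = {
    "g_loss": [2.0, 1.5, 1.2, 1.0],
    "d_loss": [0.9, 0.8, 0.85, 0.7],
}
plot_loss_curve(losses, "./loss_curve.png")  # one labeled line per key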
+15
View File
@@ -0,0 +1,15 @@
import json
def readConfig(path):
    with open(path,'r') as cf:
        config_str = cf.read()
        config_info = json.loads(config_str)
        # Tolerate doubly-encoded JSON: decode again if we still have a string.
        if isinstance(config_info,str):
            config_info = json.loads(config_info)
    return config_info
def writeConfig(path, info):
    with open(path, 'w') as cf:
        cf.write(json.dumps(info, indent=4))
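
A round-trip sketch for the two helpers (the file name is an assumption):

info = {"node": "192.168.0.1", "port": 22}
writeConfig("./node.json", info)
assert readConfig("./node.json") == info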
+135
View File
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: learningrate_scheduler.py
# Created Date: Tuesday January 5th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 5th January 2021 2:04:00 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
# Refer to basicSR https://github.com/xinntao/BasicSR
import math
from collections import Counter
from torch.optim.lr_scheduler import _LRScheduler
class MultiStepRestartLR(_LRScheduler):
""" MultiStep with restarts learning rate scheme.
Args:
optimizer (torch.nn.optimizer): Torch optimizer.
milestones (list): Iterations that will decrease learning rate.
gamma (float): Decrease ratio. Default: 0.1.
restarts (list): Restart iterations. Default: [0].
restart_weights (list): Restart weights at each restart iteration.
Default: [1].
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
def __init__(self,
optimizer,
milestones,
gamma=0.1,
restarts=(0,),
restart_weights=(1,),
last_epoch=-1):
self.milestones = Counter(milestones)
self.gamma = gamma
self.restarts = restarts
self.restart_weights = restart_weights
assert len(self.restarts) == len(
self.restart_weights), 'restarts and their weights do not match.'
super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch in self.restarts:
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
return [
group['initial_lr'] * weight
for group in self.optimizer.param_groups
]
if self.last_epoch not in self.milestones:
return [group['lr'] for group in self.optimizer.param_groups]
return [
group['lr'] * self.gamma**self.milestones[self.last_epoch]
for group in self.optimizer.param_groups
]
def get_position_from_periods(iteration, cumulative_period):
"""Get the position from a period list.
It will return the index of the right-closest number in the period list.
For example, the cumulative_period = [100, 200, 300, 400],
if iteration == 50, return 0;
if iteration == 210, return 2;
if iteration == 300, return 2.
Args:
iteration (int): Current iteration.
cumulative_period (list[int]): Cumulative period list.
Returns:
int: The position of the right-closest number in the period list.
"""
for i, period in enumerate(cumulative_period):
if iteration <= period:
return i
class CosineAnnealingRestartLR(_LRScheduler):
""" Cosine annealing with restarts learning rate scheme.
An example of config:
periods = [10, 10, 10, 10]
restart_weights = [1, 0.5, 0.5, 0.5]
eta_min=1e-7
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
scheduler will restart with the weights in restart_weights.
Args:
optimizer (torch.nn.optimizer): Torch optimizer.
        periods (list): Period for each cosine annealing cycle.
restart_weights (list): Restart weights at each restart iteration.
Default: [1].
        eta_min (float): The minimum lr. Default: 0.
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
def __init__(self,
optimizer,
periods,
                 restart_weights=(1,),
eta_min=0,
last_epoch=-1):
self.periods = periods
self.restart_weights = restart_weights
self.eta_min = eta_min
assert (len(self.periods) == len(self.restart_weights)
), 'periods and restart_weights should have the same length.'
self.cumulative_period = [
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
]
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
idx = get_position_from_periods(self.last_epoch,
self.cumulative_period)
current_weight = self.restart_weights[idx]
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
current_period = self.periods[idx]
return [
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
(1 + math.cos(math.pi * (
(self.last_epoch - nearest_restart) / current_period)))
for base_lr in self.base_lrs
]
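
A minimal sketch of driving CosineAnnealingRestartLR; the model and optimizer below are illustrative stand-ins:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# Four cycles of 10 iterations each; the LR restarts at iterations 10/20/30,
# scaled by the matching restart weight.
scheduler = CosineAnnealingRestartLR(optimizer,
                                     periods=[10, 10, 10, 10],
                                     restart_weights=[1, 0.5, 0.5, 0.5],
                                     eta_min=1e-7)
for _ in range(40):
    optimizer.step()
    scheduler.step()  # advance last_epoch and recompute per-group LRs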
+44
View File
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: logo_class.py
# Created Date: Tuesday June 29th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 11th October 2021 12:39:55 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
class logo_class:
@staticmethod
def print_group_logo():
logo_str = """
                 Neural Rendering Special Interest Group of SJTU
"""
print(logo_str)
@staticmethod
def print_start_training():
        # Raw string: the ASCII art below contains backslashes.
        logo_str = r"""
_____ __ __ ______ _ _
/ ___/ / /_ ____ _ _____ / /_ /_ __/_____ ____ _ (_)____ (_)____ ____ _
\__ \ / __// __ `// ___// __/ / / / ___// __ `// // __ \ / // __ \ / __ `/
___/ // /_ / /_/ // / / /_ / / / / / /_/ // // / / // // / / // /_/ /
/____/ \__/ \__,_//_/ \__/ /_/ /_/ \__,_//_//_/ /_//_//_/ /_/ \__, /
/____/
"""
print(logo_str)
if __name__=="__main__":
# logo_class.print_group_logo()
logo_class.print_start_training()
+56
View File
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Reporter.py
# Created Date: Tuesday September 24th 2019
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 4th July 2021 11:50:12 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2019 Shanghai Jiao Tong University
#############################################################
import datetime
import os
class Reporter:
def __init__(self,reportPath):
self.path = reportPath
self.withTimeStamp = False
self.index = 1
self.timeStrFormat = '%Y-%m-%d %H:%M:%S'
timeStr = datetime.datetime.strftime(datetime.datetime.now(),'%Y%m%d%H%M%S')
self.path = self.path + "-%s.log"%timeStr
if not os.path.exists(self.path):
f = open(self.path,'w')
f.close()
def writeInfo(self,strLine):
with open(self.path,'a+') as logf:
timeStr = datetime.datetime.strftime(datetime.datetime.now(),self.timeStrFormat)
logf.writelines("[%d]-[%s]-[info] %s\n"%(self.index,timeStr,strLine))
self.index += 1
def writeConfig(self,config):
with open(self.path,'a+') as logf:
for item in config.items():
text = "[%d]-[parameters] %s--%s\n"%(self.index,item[0],str(item[1]))
logf.writelines(text)
self.index +=1
def writeModel(self,modelText):
with open(self.path,'a+') as logf:
logf.writelines("[%d]-[model] %s\n"%(self.index,modelText))
self.index += 1
    def writeRawInfo(self, strLine):
        # Like writeInfo, but without a timestamp field.
        with open(self.path,'a+') as logf:
            logf.writelines("[%d]-[info] %s\n"%(self.index,strLine))
            self.index += 1
def writeTrainLog(self, epoch, step, logText):
with open(self.path,'a+') as logf:
timeStr = datetime.datetime.strftime(datetime.datetime.now(),self.timeStrFormat)
logf.writelines("[%d]-[%s]-[logInfo]-epoch[%d]-step[%d] %s\n"%(self.index,timeStr,epoch,step,logText))
self.index += 1
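
A short usage sketch (the path and messages are illustrative; Reporter appends a timestamp, so the log lands in ./report-<timestamp>.log):

reporter = Reporter("./report")
reporter.writeConfig({"lr": 2e-4, "batch_size": 8})
reporter.writeInfo("training started")
reporter.writeTrainLog(epoch=1, step=100, logText="g_loss=1.23 d_loss=0.45")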
+57
View File
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: save_heatmap.py
# Created Date: Friday January 15th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 15th January 2021 10:23:13 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import shutil
import seaborn as sns
import matplotlib.pyplot as plt
import cv2
import numpy as np
def SaveHeatmap(heatmaps, path, row=-1, dpi=72):
"""
The input tensor must be B X 1 X H X W
"""
batch_size = heatmaps.shape[0]
temp_path = ".temp/"
if not os.path.exists(temp_path):
os.makedirs(temp_path)
final_img = None
if row < 1:
col = batch_size
row = 1
else:
col = batch_size // row
        if row * col < batch_size:
            col += 1
row_i = 0
col_i = 0
for i in range(batch_size):
img_path = os.path.join(temp_path,'temp_batch_{}.png'.format(i))
        sns.heatmap(heatmaps[i,0,:,:],vmin=0,vmax=heatmaps[i,0,:,:].max(),cbar=False)
        plt.savefig(img_path, dpi=dpi, bbox_inches='tight', pad_inches=0)
        plt.clf()  # clear the figure so successive heatmaps do not draw over each other
img = cv2.imread(img_path)
if i == 0:
H,W,C = img.shape
final_img = np.zeros((H*row,W*col,C))
final_img[H*row_i:H*(row_i+1),W*col_i:W*(col_i+1),:] = img
col_i += 1
if col_i >= col:
col_i = 0
row_i += 1
    cv2.imwrite(path,final_img)
    shutil.rmtree(temp_path)  # clean up the temporary per-image heatmaps
if __name__ == "__main__":
random_map = np.random.randn(16,1,10,10)
SaveHeatmap(random_map,"./wocao.png",1)
+127
View File
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: sshupload.py
# Created Date: Tuesday September 24th 2019
# Author: Lcx
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th January 2021 2:02:12 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2019 Shanghai Jiao Tong University
#############################################################
import os
import hashlib
import paramiko
# SSH file-transfer helper class
class fileUploaderClass(object):
def __init__(self,serverIp,userName,passWd,port=22):
self.__ip__ = serverIp
self.__userName__ = userName
self.__passWd__ = passWd
self.__port__ = port
self.__ssh__ = paramiko.SSHClient()
self.__ssh__.set_missing_host_key_policy(paramiko.AutoAddPolicy())
def sshScpPut(self,localFile,remoteFile):
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
sftp = self.__ssh__.open_sftp()
remoteDir = remoteFile.split("/")
if remoteFile[0]=='/':
sftp.chdir('/')
for item in remoteDir[0:-1]:
if item == "":
continue
try:
sftp.chdir(item)
except:
sftp.mkdir(item)
sftp.chdir(item)
sftp.put(localFile,remoteDir[-1])
sftp.close()
self.__ssh__.close()
print("ssh localfile:%s remotefile:%s success"%(localFile,remoteFile))
    def sshScpGetNames(self,remoteDir):
        self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
        names = sftp.listdir(remoteDir)
        sftp.close()
        self.__ssh__.close()
        return names
    def sshScpGet(self, remoteFile, localFile, showProgress=False):
        self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
        sftp = self.__ssh__.open_sftp()
        try:
            sftp.stat(remoteFile)
            print("Remote file exists!")
        except:
            print("Remote file does not exist!")
            sftp.close()
            self.__ssh__.close()
            return False
        if showProgress:
            sftp.get(remoteFile, localFile,callback=self.__putCallBack__)
        else:
            sftp.get(remoteFile, localFile)
        sftp.close()
        self.__ssh__.close()
        return True
def __putCallBack__(self,transferred,total):
print("current transferred %3.1f percent"%(transferred/total*100),end='\r')
def sshScpGetmd5(self, file_path):
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
sftp = self.__ssh__.open_sftp()
try:
file = sftp.open(file_path, 'rb')
res = (True,hashlib.new('md5', file.read()).hexdigest())
sftp.close()
self.__ssh__.close()
return res
except:
sftp.close()
self.__ssh__.close()
return (False,None)
def sshScpRename(self, oldpath, newpath):
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
sftp = self.__ssh__.open_sftp()
sftp.rename(oldpath,newpath)
sftp.close()
self.__ssh__.close()
print("ssh oldpath:%s newpath:%s success"%(oldpath,newpath))
def sshScpDelete(self,path):
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
sftp = self.__ssh__.open_sftp()
sftp.remove(path)
sftp.close()
self.__ssh__.close()
print("ssh delete:%s success"%(path))
def sshScpDeleteDir(self,path):
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
sftp = self.__ssh__.open_sftp()
self.__rm__(sftp,path)
sftp.close()
self.__ssh__.close()
def __rm__(self,sftp,path):
try:
files = sftp.listdir(path=path)
print(files)
for f in files:
filepath = os.path.join(path, f).replace('\\','/')
self.__rm__(sftp,filepath)
sftp.rmdir(path)
print("ssh delete:%s success"%(path))
except:
print(path)
sftp.remove(path)
print("ssh delete:%s success"%(path))
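
A minimal usage sketch (host, credentials, and paths are placeholders):

uploader = fileUploaderClass("192.168.0.10", "user", "password")
# Upload creates any missing directories along the remote path.
uploader.sshScpPut("./local/model.pth", "/remote/ckpt/model.pth")
names = uploader.sshScpGetNames("/remote/ckpt")   # list the remote directory
ok = uploader.sshScpGet("/remote/ckpt/model.pth", "./model.pth",
                        showProgress=True)        # False if the file is missing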
+146
View File
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: transfer_checkpoint.py
# Created Date: Wednesday February 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 4th February 2021 1:27:09 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
import os
import numpy as np
import scipy.io as io
class RepSRPlain_pixel(nn.Module):
    """Plain inference-time super-resolution network.
    A stack of 3x3 convolutions followed by two nearest-neighbor x2
    upsampling stages; the multi-branch re-parameterized (RepSR-style)
    training checkpoint is fused into this plain network below.
    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 32.
        num_layer (int): Unused; kept for interface compatibility.
        upsampling (int): Upsampling scale factor. Default: 4.
    """
def __init__(self,
num_in_ch,
num_out_ch,
num_feat=32,
num_layer = 3,
upsampling=4):
super(RepSRPlain_pixel, self).__init__()
self.scale = upsampling
self.ssqu = upsampling ** 2
self.rep1 = nn.Conv2d(num_in_ch, num_feat,3,1,1)
self.rep2 = nn.Conv2d(num_feat, num_feat*2,3,1,1)
self.rep3 = nn.Conv2d(num_feat*2, num_feat*2,3,1,1)
self.rep4 = nn.Conv2d(num_feat*2, num_feat*2,3,1,1)
self.rep5 = nn.Conv2d(num_feat*2, num_feat*2,3,1,1)
self.rep6 = nn.Conv2d(num_feat*2, num_feat,3,1,1)
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.activator = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# self.activator = nn.ReLU(inplace=True)
# default_init_weights([self.conv_up1,self.conv_up2,self.conv_hr,self.conv_last], 0.1)
def forward(self, x):
f_d = self.activator(self.rep1(x))
f_d = self.activator(self.rep2(f_d))
f_d = self.activator(self.rep3(f_d))
f_d = self.activator(self.rep4(f_d))
f_d = self.activator(self.rep5(f_d))
f_d = self.activator(self.rep6(f_d))
feat = self.activator(
self.conv_up1(F.interpolate(f_d, scale_factor=2, mode='nearest')))
feat = self.activator(
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
out = self.conv_last(self.activator(self.conv_hr(feat)))
return out
def create_identity_conv(dim,kernel_size=3):
zeros = torch.zeros((dim,dim,kernel_size,kernel_size)).cuda()
for i_dim in range(dim):
zeros[i_dim,i_dim,kernel_size//2,kernel_size//2] = 1.0
return zeros
def fill_conv_kernel(in_tensor,kernel_size=3):
shape = in_tensor.shape
zeros = torch.zeros(shape[0],shape[1],kernel_size,kernel_size).cuda()
for i_dim in range(shape[0]):
zeros[i_dim,:,kernel_size//2,kernel_size//2] = in_tensor[i_dim,:,0,0]
return zeros
if __name__ == "__main__":
os.environ["CUDA_VISIBLE_DEVICES"] = str(0)
base_path = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\"
version = "repsr_pixel_0"
epoch = 73
path_ckp= os.path.join(base_path,version,"checkpoints\\epoch%d_RepSR.pth"%epoch)
path_plain_ckp= os.path.join(base_path,version,"checkpoints\\epoch%d_RepSR_Plain.pth"%epoch)
network = RepSRPlain_pixel(3,
3,
64,
3,
4
)
network = network.cuda()
wocao = network.state_dict()
# for data_key in wocao.keys():
# print(data_key)
# print(wocao[data_key].shape)
wocao_cpk = torch.load(path_ckp)
# for data_key in wocao_cpk.keys():
# print(data_key)
# print(wocao_cpk[data_key].shape)
name_list = ["rep1","rep2","rep3","rep4","rep5","rep6"]
other_list = ["conv_up1","conv_up2","conv_hr","conv_last"]
for i_name in name_list:
temp= wocao_cpk[i_name+".conv3.weight"] + fill_conv_kernel(wocao_cpk[i_name+".conv1x1.weight"])
wocao[i_name+".weight"] = temp
temp= wocao_cpk[i_name+".conv3.bias"] + wocao_cpk[i_name+".conv1x1.bias"]
wocao[i_name+".bias"] = temp
if wocao_cpk[i_name+".conv3.weight"].shape[0] == wocao_cpk[i_name+".conv3.weight"].shape[1]:
print("include identity")
temp = wocao[i_name+".weight"] + create_identity_conv(wocao_cpk[i_name+".conv3.weight"].shape[0])
wocao[i_name+".weight"] = temp
for i_name in other_list:
wocao[i_name+".weight"] = wocao_cpk[i_name+".weight"]
wocao[i_name+".bias"] = wocao_cpk[i_name+".bias"]
torch.save(wocao,path_plain_ckp)
# wocao = torch.load(path_plain_ckp)
# for data_key in wocao.keys():
# result1 = wocao[data_key].cpu().numpy()
# # np.savetxt(i_name+"_conv3_weight.txt",result1)
# str_temp = ("%s"%data_key).replace(".","_")
# io.savemat(str_temp+".mat",{str_temp:result1})
# for data_key in wocao.keys():
# print(data_key)
# print(wocao[data_key].shape)
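
The fusion above relies on convolution being linear: a 3x3 branch, a 1x1 branch zero-padded to 3x3 (fill_conv_kernel), and an identity shortcut (create_identity_conv) sum into one plain 3x3 kernel. A quick equivalence check under that assumption (the branch weights are random stand-ins, not the actual training network; a CUDA device is required since the helpers allocate on the GPU):

import torch
import torch.nn.functional as F

dim = 8
x = torch.randn(1, dim, 16, 16).cuda()
w3 = torch.randn(dim, dim, 3, 3).cuda()
w1 = torch.randn(dim, dim, 1, 1).cuda()

# Multi-branch output: 3x3 conv + 1x1 conv + identity shortcut.
y_branches = F.conv2d(x, w3, padding=1) + F.conv2d(x, w1) + x
# Fused output: one 3x3 conv with the merged kernel.
w_fused = w3 + fill_conv_kernel(w1) + create_identity_conv(dim)
y_fused = F.conv2d(x, w_fused, padding=1)
print(torch.allclose(y_branches, y_fused, atol=1e-5))  # expected: True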
+335
View File
@@ -0,0 +1,335 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: utilities.py
# Created Date: Monday April 6th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 2:18:05 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import cv2
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
# Gram Matrix
def Gram(tensor: torch.Tensor):
B, C, H, W = tensor.shape
x = tensor.view(B, C, H*W)
x_t = x.transpose(1, 2)
return torch.bmm(x, x_t) / (C*H*W)
def build_tensorboard(summary_path):
from tensorboardX import SummaryWriter
# from logger import Logger
# self.logger = Logger(self.log_path)
writer = SummaryWriter(log_dir=summary_path)
return writer
def denorm(x):
out = (x + 1) / 2
return out.clamp_(0, 1)
def tensor2img(img_tensor):
"""
Input image tensor shape must be [B C H W]
the return image numpy array shape is [B H W C]
"""
res = img_tensor.numpy()
res = (res + 1) / 2
res = np.clip(res, 0.0, 1.0)
res = res * 255
res = res.transpose((0,2,3,1))
return res
def img2tensor255(path, max_size=None):
    image = Image.open(path)
    # Rescale the image
    if max_size is None:
        itot_t = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
    else:
        W, H = image.size  # PIL images expose (width, height), not .shape
        image_size = tuple(int((float(max_size) / max(H, W)) * x) for x in (H, W))
        itot_t = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
    # Convert image to tensor
    tensor = itot_t(image)
    # Add the batch_size dimension
    tensor = tensor.unsqueeze(dim=0)
    return tensor
def img2tensor255crop(path, crop_size=256):
image = Image.open(path)
# Rescale the image
itot_t = transforms.Compose([
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.mul(255))
])
# Convert image to tensor
tensor = itot_t(image)
# Add the batch_size dimension
tensor = tensor.unsqueeze(dim=0)
return tensor
# def img2tensor255(path, crop_size=None):
# """
# Input image tensor shape must be [B C H W]
# the return image numpy array shape is [B H W C]
# """
# img = cv2.imread(path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float)
# img = torch.from_numpy(img).transpose((2,0,1)).unsqueeze(0)
# return img
# Note: despite its name, this is identical to tensor2img above
# (tensor in, numpy image out); apparently kept for existing call sites.
def img2tensor1(img_tensor):
    """
    Input image tensor shape must be [B C H W];
    the returned numpy array shape is [B H W C].
    """
    res = img_tensor.numpy()
    res = (res + 1) / 2
    res = np.clip(res, 0.0, 1.0)
    res = res * 255
    res = res.transpose((0,2,3,1))
    return res
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    conversion functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
(ndarray): The converted image with type of np.float32 and range of
[0, 1].
"""
img_type = img.dtype
img = img.astype(np.float32)
if img_type == np.float32:
pass
elif img_type == np.uint8:
img /= 255.
else:
raise TypeError('The img type should be np.float32 or np.uint8, '
f'but got {img_type}')
return img
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
images will be converted to np.uint8 type with range [0, 255]. If
`dst_type` is np.float32, it converts the image to np.float32 type with
range [0, 1].
    It is mainly used for post-processing images in colorspace conversion
    functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The image to be converted with np.float32 type and
range [0, 255].
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
converts the image to np.uint8 type with range [0, 255]. If
dst_type is np.float32, it converts the image to np.float32 type
with range [0, 1].
Returns:
(ndarray): The converted image with desired type and range.
"""
if dst_type not in (np.uint8, np.float32):
raise TypeError('The dst_type should be np.float32 or np.uint8, '
f'but got {dst_type}')
if dst_type == np.uint8:
img = img.round()
else:
img /= 255.
return img.astype(dst_type)
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
    if y_only:
        # Uses the RGB-order coefficients (the BGR-order variant would be
        # [24.966, 128.553, 65.481]); inputs are assumed RGB here despite the name.
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def to_y_channel(img):
"""Change to Y channel of YCbCr.
Args:
img (ndarray): Images with range [0, 255].
Returns:
(ndarray): Images with range [0, 255] (float type) without round.
"""
img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
return img * 255.
def calculate_psnr(img1,
img2,
# crop_border=0,
test_y_channel=True):
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        test_y_channel (bool): Test on Y channel of YCbCr. Default: True.
    Returns:
        float: psnr result.
    """
# if crop_border != 0:
# img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
# img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
if test_y_channel:
img1 = to_y_channel(img1)
img2 = to_y_channel(img2)
mse = np.mean((img1 - img2)**2)
if mse == 0:
return float('inf')
return 20. * np.log10(255. / np.sqrt(mse))
def _ssim(img1, img2):
"""Calculate SSIM (structural similarity) for one channel images.
It is called by func:`calculate_ssim`.
Args:
img1 (ndarray): Images with range [0, 255] with order 'HWC'.
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
Returns:
float: ssim result.
"""
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) *
(2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def calculate_ssim(img1,
img2,
test_y_channel=True):
"""Calculate SSIM (structural similarity).
Ref:
Image quality assessment: From error visibility to structural similarity
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        test_y_channel (bool): Test on Y channel of YCbCr. Default: True.
    Returns:
        float: ssim result.
    """
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
if test_y_channel:
img1 = to_y_channel(img1)
img2 = to_y_channel(img2)
ssims = []
for i in range(img1.shape[2]):
ssims.append(_ssim(img1[..., i], img2[..., i]))
return np.array(ssims).mean()
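
A quick self-check of the two metrics above on synthetic uint8 images:

import numpy as np

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
noisy = np.clip(img.astype(np.float64) + np.random.randn(64, 64, 3) * 5, 0, 255)

print(calculate_psnr(img, noisy))          # finite value for noisy inputs
print(calculate_ssim(img, noisy))          # in (0, 1]; 1.0 for identical inputs
print(calculate_psnr(img, img.copy()))     # -> inf, since the MSE is zero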
+29
View File
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Config_from_yaml.py
# Created Date: Monday February 17th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 28th February 2020 4:30:01 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
import yaml
def getConfigYaml(yaml_file):
with open(yaml_file, 'r') as config_file:
        try:
            config_dict = yaml.load(config_file, Loader=yaml.FullLoader)
            return config_dict
        except yaml.YAMLError:
            print('INVALID YAML file format. Please provide a valid yaml file.')
            exit(-1)
if __name__ == "__main__":
a= getConfigYaml("./train_256.yaml")
sys_state = {}
for item in a.items():
sys_state[item[0]] = item[1]