init
This commit is contained in:
@@ -112,3 +112,9 @@ dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
/train_logs
|
||||
/test_logs
|
||||
/GUI
|
||||
/benchmark
|
||||
/reference
|
||||
@@ -0,0 +1,924 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: GUI copy 2.py
|
||||
# Created Date: Wednesday December 22nd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 10th January 2022 1:47:55 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import tkinter
|
||||
try:
|
||||
import paramiko
|
||||
except:
|
||||
from pip._internal import main
|
||||
main(['install', 'paramiko'])
|
||||
import paramiko
|
||||
|
||||
import threading
|
||||
import tkinter as tk
|
||||
import tkinter.ttk as ttk
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
|
||||
|
||||
#############################################################
|
||||
# Predefined functions
|
||||
#############################################################
|
||||
|
||||
def read_config(path):
    """Load a JSON configuration file.

    Some legacy configs were double-encoded (a JSON string that itself
    contains JSON), so if the first decode yields a str we decode once more.

    Args:
        path: path of the JSON file to read.

    Returns:
        The decoded configuration object (usually a dict or list).
    """
    with open(path, 'r') as cf:
        # json.load streams straight from the file object (idiomatic form
        # of read() + json.loads()).
        config = json.load(cf)
    # Legacy double-encoded files decode to a str first; unwrap once more.
    if isinstance(config, str):
        config = json.loads(config)
    return config
|
||||
|
||||
def write_config(path, info):
    """Serialize *info* as pretty-printed JSON to *path* (overwriting it).

    Args:
        path: destination file path.
        info: any JSON-serializable object.
    """
    with open(path, 'w') as cf:
        # json.dump writes directly to the file; the original called
        # writelines() on a str, which works but iterates per character.
        json.dump(info, cf, indent=4)
|
||||
|
||||
class fileUploaderClass(object):
    """Thin paramiko wrapper for one-shot SSH/SFTP operations.

    Every public method opens a fresh connection, performs one operation
    and closes the connection again, so an instance holds no open network
    resources between calls.
    """

    def __init__(self, serverIp, userName, passWd, port=22):
        """Store connection parameters; nothing is opened yet.

        Args:
            serverIp: remote host address.
            userName: SSH login name.
            passWd: SSH password.
            port: SSH port (default 22).
        """
        self.__ip__ = serverIp
        self.__userName__ = userName
        self.__passWd__ = passWd
        self.__port__ = port
        self.__ssh__ = paramiko.SSHClient()
        # Auto-accept unknown host keys (convenient, but not MITM-safe).
        self.__ssh__.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    def __open_sftp__(self):
        """Connect the SSH client and return a fresh SFTP session.

        Fixes a leak in the original code, which created one SFTP channel
        via SFTPClient.from_transport() and immediately abandoned it for a
        second open_sftp() channel; a single channel is sufficient.
        """
        self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
        return self.__ssh__.open_sftp()

    def __chdir_make_parents__(self, sftp, remoteDir):
        """chdir through remoteDir's parent components, creating missing dirs.

        Args:
            sftp: open SFTP session (its cwd is mutated).
            remoteDir: remote path already split on "/"; the last element
                is the file name and is skipped.
        """
        for item in remoteDir[0:-1]:
            if item == "":
                continue
            try:
                sftp.chdir(item)
            except IOError:
                # chdir failed -> directory does not exist yet.
                sftp.mkdir(item)
                sftp.chdir(item)

    def sshScpPut(self, localFile, remoteFile):
        """Upload one local file, creating remote parent dirs as needed."""
        sftp = self.__open_sftp__()
        remoteDir = remoteFile.split("/")
        if remoteFile[0] == '/':
            sftp.chdir('/')
        self.__chdir_make_parents__(sftp, remoteDir)
        sftp.put(localFile, remoteDir[-1])
        sftp.close()
        self.__ssh__.close()
        print("[To %s]:%s remotefile:%s success"%(self.__ip__,localFile,remoteFile))

    def sshScpPuts(self, localFiles, remoteFiles):
        """Upload a batch of files (parallel lists of local/remote paths).

        NOTE: the SFTP cwd persists between iterations, so relative remote
        paths after the first are resolved from the previous file's
        directory (preserved from the original implementation).
        """
        sftp = self.__open_sftp__()
        for i_dir in range(len(remoteFiles)):
            remoteDir = remoteFiles[i_dir].split("/")
            if remoteFiles[i_dir][0] == '/':
                sftp.chdir('/')
            self.__chdir_make_parents__(sftp, remoteDir)
            sftp.put(localFiles[i_dir], remoteDir[-1])
            print("[To %s]:%s remotefile:%s success"%(self.__ip__,localFiles[i_dir],remoteFiles[i_dir]))
        sftp.close()
        self.__ssh__.close()

    def sshExec(self, cmd):
        """Run *cmd* on the remote host and return its stdout as text.

        Returns None when the connection or command fails (the error is
        printed, matching the original best-effort behaviour).
        """
        try:
            self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
            _, stdout, _ = self.__ssh__.exec_command(cmd)
            results = stdout.read().strip().decode('utf-8')
            return results
        except Exception as e:
            print(e)
        finally:
            # close() is idempotent, so this is safe on every path.
            self.__ssh__.close()

    def sshScpGetNames(self, remoteDir):
        """List *remoteDir*, returning {name: {"t": mtime, "p": full path}}."""
        sftp = self.__open_sftp__()
        roots = {}
        for item in sftp.listdir(remoteDir):
            attrs = sftp.stat(remoteDir + "/" + item)
            roots[item] = {
                "t": attrs.st_mtime,
                "p": remoteDir + "/" + item
            }
        sftp.close()
        self.__ssh__.close()
        return roots

    def sshScpGetRNames(self, remoteDir):
        """Identical to sshScpGetNames; kept as a separate name because
        callers (e.g. PullLog) use it. The original body was a
        byte-for-byte duplicate."""
        return self.sshScpGetNames(remoteDir)

    def sshScpGet(self, remoteFile, localFile, showProgress=False):
        """Download one remote file, optionally printing transfer progress."""
        sftp = self.__open_sftp__()
        callback = self.__putCallBack__ if showProgress else None
        sftp.get(remoteFile, localFile, callback=callback)
        sftp.close()
        self.__ssh__.close()

    def sshScpGetFiles(self, remoteFiles, localFiles, showProgress=False):
        """Download a batch of files (parallel lists of remote/local paths)."""
        sftp = self.__open_sftp__()
        callback = self.__putCallBack__ if showProgress else None
        for i in range(len(remoteFiles)):
            sftp.get(remoteFiles[i], localFiles[i], callback=callback)
            print("Get %s success!"%(remoteFiles[i]))
        sftp.close()
        self.__ssh__.close()

    def sshScpGetDir(self, remoteDir, localDir, showProgress=False):
        """Download every entry directly inside *remoteDir* into *localDir*.

        Non-recursive: assumes remoteDir contains only regular files.
        """
        sftp = self.__open_sftp__()
        callback = self.__putCallBack__ if showProgress else None
        for name in sftp.listdir(remoteDir):
            i_remote_file = Path(remoteDir, name).as_posix()
            local_file = Path(localDir, name)
            sftp.get(i_remote_file, local_file, callback=callback)
        sftp.close()
        self.__ssh__.close()

    def __putCallBack__(self, transferred, total):
        # paramiko transfer-progress hook: bytes done / total bytes.
        print("current transferred %.1f percent"%(transferred/total*100))

    def sshScpRename(self, oldpath, newpath):
        """Rename/move a remote path."""
        sftp = self.__open_sftp__()
        sftp.rename(oldpath, newpath)
        sftp.close()
        self.__ssh__.close()
        print("ssh oldpath:%s newpath:%s success"%(oldpath,newpath))

    def sshScpDelete(self, path):
        """Remove a single remote file."""
        sftp = self.__open_sftp__()
        sftp.remove(path)
        sftp.close()
        self.__ssh__.close()
        print("ssh delete:%s success"%(path))

    def sshScpDeleteDir(self, path):
        """Recursively delete a remote directory tree."""
        sftp = self.__open_sftp__()
        self.__rm__(sftp, path)
        sftp.close()
        self.__ssh__.close()

    def __rm__(self, sftp, path):
        """Recursive delete helper: try *path* as a directory, fall back to file.

        listdir() raises IOError when *path* is a regular file, in which
        case it is removed directly (EAFP instead of an extra stat).
        """
        try:
            files = sftp.listdir(path=path)
            print(files)
            for f in files:
                # Remote paths always use forward slashes.
                filepath = os.path.join(path, f).replace('\\','/')
                self.__rm__(sftp, filepath)
            sftp.rmdir(path)
            print("ssh delete:%s success"%(path))
        except IOError:
            print(path)
            sftp.remove(path)
            print("ssh delete:%s success"%(path))
|
||||
|
||||
class TextRedirector(object):
    """File-like object that appends writes to a (read-only) Tk Text widget.

    Assign an instance to sys.stdout so print() output appears in the GUI.
    """

    def __init__(self, widget, tag="stdout"):
        # widget: tkinter Text widget; tag: text tag applied to inserts.
        self.widget = widget
        self.tag = tag

    def write(self, text):
        """Append *text* to the widget and scroll to the end.

        (Parameter renamed from `str`, which shadowed the builtin.)
        """
        # The widget is kept disabled so users cannot type into the log;
        # enable it just long enough to insert.
        self.widget.configure(state="normal")
        self.widget.insert("end", text, (self.tag,))
        self.widget.configure(state="disabled")
        self.widget.see(tk.END)

    def flush(self):
        """No-op; required for the file-like API (e.g. print(flush=True))."""
        pass
|
||||
|
||||
#############################################################
|
||||
# Main Class
|
||||
#############################################################
|
||||
|
||||
class Application(tk.Frame):
|
||||
|
||||
    # ---- class-level state shared by all Application instances ----
    tab_info = []                  # tab metadata; see NOTE in SynchronizeAll about its use
    tab_body = None
    current_index = 0
    gui_root = "GUI/"              # folder holding all GUI-side config files
    machine_json = gui_root + "machines.json"    # list of remote machine entries
    filesynlogroot = "file_sync/"
    filesynlogroot = gui_root + filesynlogroot   # rebinds to "GUI/file_sync/"
    ignore_json = gui_root + "guiignore.json"    # sync include/exclude rules
    machine_list = []              # parsed machines.json (list of dicts)
    machine_dict = {}              # ip -> machine config entry

    # Default template written to guiignore.json on first use (IgnoreConfig).
    ignore_text={
        "white_list":{
            "extension":["py",
                "yaml"
            ],
            "file":[],
            "path":[]
        },
        "black_list":{
            "extension":[
                "png",
                "yaml"
            ],
            "file":[],
            "path":["save/", "GUI/",]
        }
    }
    # Default template written to env/env.json on first use (EnvConfig).
    env_text={
        "train_log_root":"./train_logs",
        "test_log_root":"./test_logs",
        "systemLog":"./system/system_log.log",
        "dataset_paths":{
            "train_dataset_root":"",
            "val_dataset_root":"",
            "test_dataset_root":""
        },
        "train_config_path":"./train_yamls",
        "train_scripts_path":"./train_scripts",
        "test_scripts_path":"./test_scripts",
        "config_json_name":"model_config.json"
    }
    # Template for one entry of machines.json.
    machine_text = {
        "ip": "0.0.0.0",
        "user": "username",
        "port": 22,
        "passwd": "12345678",
        "path": "/path/to/remote_host",
        "ckp_path":"save",
        "logfilename": "filestate_machine0.json"
    }
    current_log = {}               # last {log_name: {"t": mtime, "p": path}} listing fetched
|
||||
|
||||
|
||||
def __init__(self, master=None):
|
||||
tk.Frame.__init__(self, master,bg='black')
|
||||
# self.font_size = 16
|
||||
self.font_list = ("Times New Roman",14)
|
||||
self.padx = 5
|
||||
self.pady = 5
|
||||
if not Path(self.gui_root).exists():
|
||||
Path(self.gui_root).mkdir(parents=True)
|
||||
self.window_init()
|
||||
|
||||
def __label_text__(self, usr, root):
|
||||
return "User Name: %s\nWorkspace: %s"%(usr, root)
|
||||
|
||||
    def window_init(self):
        """Build the whole window: machine selector, action buttons,
        log/checkpoint comboboxes and the stdout log pane.

        Also seeds machines.json with a template entry on first run and
        redirects sys.stdout into the GUI text widget.
        """
        cwd = os.getcwd()
        self.master.title('File Synchronize - %s'%cwd)
        # self.master.iconbitmap('./utilities/_logo.ico')
        self.master.geometry("{}x{}".format(640, 800))
        font_list = self.font_list
        try:
            self.machines = read_config(self.machine_json)
        except:
            # First run (or broken file): write a template machines.json.
            self.machine_list = [self.machine_text,]
            write_config(self.machine_json,self.machine_list)
            # subprocess.call("start %s"%self.machine_json, shell=True)
        # --- row 1: machine combobox + Update / Machines buttons ---
        list_frame = tk.Frame(self.master)
        list_frame.pack(fill="both", padx=5,pady=5)
        list_frame.columnconfigure(0, weight=1)
        list_frame.columnconfigure(1, weight=1)
        list_frame.columnconfigure(2, weight=1)

        self.mac_var = tkinter.StringVar()

        self.list_com = ttk.Combobox(list_frame, textvariable=self.mac_var)
        self.list_com.grid(row=0,column=0,sticky=tk.EW)

        open_button = tk.Button(list_frame, text = "Update",
            font=font_list, command = self.Machines_Update, bg='#F4A460', fg='#F5F5F5')
        open_button.grid(row=0,column=1,sticky=tk.EW)

        open_button = tk.Button(list_frame, text = "Machines",
            font=font_list, command = self.MachineConfig, bg='#F4A460', fg='#F5F5F5')
        open_button.grid(row=0,column=2,sticky=tk.EW)

        # --- row 2: label showing user/workspace of the selected machine ---
        self.mac_text = tk.StringVar()
        mac_label = tk.Label(self.master, textvariable=self.mac_text,font=self.font_list,justify="left")
        mac_label.pack(fill="both", padx=5,pady=5)
        self.mac_text.set(self.list_com.get())
        self.machines_update()
        def xFunc(event):
            # Selecting a machine updates the label and refreshes log/ckpt lists.
            ip = self.list_com.get()
            cur_mac = self.machine_dict[ip]
            str_temp= self.__label_text__(cur_mac["user"],cur_mac["path"])
            self.mac_text.set(str_temp)
            self.update_log_task()
            self.update_ckpt_task()
        self.list_com.bind("<<ComboboxSelected>>",xFunc)
        # --- row 3: file synchronization buttons ---
        run_frame = tk.Frame(self.master)
        run_frame.pack(fill="both", padx=5,pady=5)
        run_frame.columnconfigure(0, weight=1)
        run_frame.columnconfigure(1, weight=1)
        run_test_button = tk.Button(run_frame, text = "Synch Files",
            font=font_list, command = self.Synchronize, bg='#006400', fg='#FF0000')
        run_test_button.grid(row=0,column=0,sticky=tk.EW)

        open_button = tk.Button(run_frame, text = "Synch To All",
            font=font_list, command = self.SynchronizeAll, bg='#F4A460', fg='#F5F5F5')
        open_button.grid(row=0,column=1,sticky=tk.EW)

        # --- row 4: ssh console / pull-log buttons ---
        ssh_frame = tk.Frame(self.master)
        ssh_frame.pack(fill="both", padx=5,pady=5)
        ssh_frame.columnconfigure(0, weight=1)
        ssh_frame.columnconfigure(1, weight=1)
        ssh_button = tk.Button(ssh_frame, text = "Open SSH",
            font=font_list, command = self.OpenSSH, bg='#990033', fg='#F5F5F5')
        ssh_button.grid(row=0,column=0,sticky=tk.EW)

        ssh_button = tk.Button(ssh_frame, text = "Pull Log",
            font=font_list, command = self.PullLog, bg='#990033', fg='#F5F5F5')
        ssh_button.grid(row=0,column=1,sticky=tk.EW)

        # --- row 5: local helpers (cmd / explorer / nvidia-smi) ---
        config_frame = tk.Frame(self.master)
        config_frame.pack(fill="both", padx=5,pady=5)
        config_frame.columnconfigure(0, weight=1)
        config_frame.columnconfigure(1, weight=1)
        config_frame.columnconfigure(2, weight=1)

        cmd_btn = tk.Button(config_frame, text = "Open CMD",
            font=font_list, command = self.OpenCMD, bg='#0033FF', fg='#F5F5F5')
        cmd_btn.grid(row=0,column=0,sticky=tk.EW)

        cwd_btn = tk.Button(config_frame, text = "CWD",
            font=font_list, command = self.CWD, bg='#0033FF', fg='#F5F5F5')
        cwd_btn.grid(row=0,column=1,sticky=tk.EW)

        gpu_btn = tk.Button(config_frame, text = "GPU Usage",
            font=font_list, command = self.GPUUsage, bg='#0033FF', fg='#F5F5F5')
        gpu_btn.grid(row=0,column=2,sticky=tk.EW)

        # --- row 6: config editors (ignore rules / env file) ---
        config_frame = tk.Frame(self.master)
        config_frame.pack(fill="both", padx=5,pady=5)
        config_frame.columnconfigure(0, weight=1)
        config_frame.columnconfigure(1, weight=1)
        # config_frame.columnconfigure(2, weight=1)

        machine_btn = tk.Button(config_frame, text = "Ignore Conf",
            font=font_list, command = self.IgnoreConfig, bg='#660099', fg='#F5F5F5')
        machine_btn.grid(row=0,column=0,sticky=tk.EW)

        machine_btn2 = tk.Button(config_frame, text = "Env Conf",
            font=font_list, command = self.EnvConfig, bg='#660099', fg='#F5F5F5')
        machine_btn2.grid(row=0,column=1,sticky=tk.EW)

        # --- row 7: log selector + fresh/pull ---
        log_frame = tk.Frame(self.master)
        log_frame.pack(fill="both", padx=5,pady=5)
        log_frame.columnconfigure(0, weight=1)
        log_frame.columnconfigure(1, weight=1)
        log_frame.columnconfigure(2, weight=1)

        self.log_var = tkinter.StringVar()

        self.log_com = ttk.Combobox(log_frame, textvariable=self.log_var)
        self.log_com.grid(row=0,column=0,sticky=tk.EW)
        def select_log(event):
            # Changing the log refreshes the checkpoint combobox.
            self.update_ckpt_task()
        self.log_com.bind("<<ComboboxSelected>>",select_log)

        log_update_button = tk.Button(log_frame, text = "Fresh",
            font=font_list, command = self.UpdateLog, bg='#F4A460', fg='#F5F5F5')
        log_update_button.grid(row=0,column=1,sticky=tk.EW)

        log_update_button = tk.Button(log_frame, text = "Pull Log",
            font=font_list, command = self.PullLog, bg='#F4A460', fg='#F5F5F5')
        log_update_button.grid(row=0,column=2,sticky=tk.EW)

        # --- row 8: checkpoint selector + fresh/test ---
        test_frame = tk.Frame(self.master)
        test_frame.pack(fill="both", padx=5,pady=5)
        test_frame.columnconfigure(0, weight=1)
        test_frame.columnconfigure(1, weight=1)
        test_frame.columnconfigure(2, weight=1)

        self.test_var = tkinter.StringVar()

        self.test_com = ttk.Combobox(test_frame, textvariable=self.test_var)
        self.test_com.grid(row=0,column=0,sticky=tk.EW)

        test_update_button = tk.Button(test_frame, text = "Fresh CKPT",
            font=font_list, command = self.UpdateCKPT, bg='#F4A460', fg='#F5F5F5')
        test_update_button.grid(row=0,column=1,sticky=tk.EW)

        test_update_button = tk.Button(test_frame, text = "Test",
            font=font_list, command = self.Test, bg='#F4A460', fg='#F5F5F5')
        test_update_button.grid(row=0,column=2,sticky=tk.EW)

        # --- disabled tensorboard UI kept for reference ---
        # NOTE(review): self.tensorlog_str is only created here; while this
        # stays commented out, download_tblogs() cannot work — confirm.
        # #################################################################################################
        # tbtext_frame = tk.Frame(self.master)
        # tbtext_frame.pack(fill="both", padx=5,pady=5)
        # tbtext_frame.columnconfigure(0, weight=5)

        # self.tensorlog_str = tk.StringVar()
        # tb_text = tk.Entry(tbtext_frame,font=font_list, textvariable=self.tensorlog_str )
        # tb_text.grid(row=0,column=0,sticky=tk.EW)
        # config_tb = tk.Button(tbtext_frame, text = "Config",
        #     font=font_list, command = self.OpenConfig, bg='#003472', fg='#F5F5F5')
        # config_tb.grid(row=0,column=1,sticky=tk.EW)
        # #################################################################################################

        # #################################################################################################
        # tb_frame = tk.Frame(self.master)
        # tb_frame.pack(fill="both", padx=5,pady=5)
        # tb_frame.columnconfigure(1, weight=1)
        # tb_frame.columnconfigure(0, weight=1)
        # open_tb = tk.Button(tb_frame, text = "Open Tensorboard",
        #     font=font_list, command = self.OpenTensorboard, bg='#003472', fg='#F5F5F5')
        # open_tb.grid(row=0,column=0,sticky=tk.EW)
        # download_tb = tk.Button(tb_frame, text = "Update Tensorboard Logs",
        #     font=font_list, command = self.DownloadTBLogs, bg='#003472', fg='#F5F5F5')
        # download_tb.grid(row=0,column=1,sticky=tk.EW)

        # #################################################################################################

        # --- bottom: log pane; print() output is mirrored here ---
        text = tk.Text(self.master, wrap="word")
        text.pack(fill="both",expand="yes", padx=5,pady=5)

        sys.stdout = TextRedirector(text, "stdout")

        # NOTE(review): self.on_closing is not defined in the visible part
        # of the file — presumably defined further down; confirm.
        self.master.protocol("WM_DELETE_WINDOW", self.on_closing)
|
||||
|
||||
# def __scaning_logs__(self):
|
||||
def UpdateCKPT(self):
|
||||
thread_update = threading.Thread(target=self.update_ckpt_task)
|
||||
thread_update.start()
|
||||
|
||||
def update_ckpt_task(self):
|
||||
ip = self.list_com.get()
|
||||
log = self.log_com.get()
|
||||
cur_mac = self.machine_dict[ip]
|
||||
files = Path('.',cur_mac["ckp_path"], log)
|
||||
files = files.glob('*.pth')
|
||||
all_files = []
|
||||
for one_file in files:
|
||||
all_files.append(one_file.name)
|
||||
self.test_com["value"] =all_files
|
||||
self.test_com.current(0)
|
||||
|
||||
    def Test(self):
        """Run test.py against the selected log/checkpoint in a new console.

        Spawned on a worker thread so the Tk mainloop stays responsive.
        Windows-only: relies on `start cmd /k` and a conda env named "base".
        """
        def test_task():
            log = self.log_com.get()
            ckpt = self.test_com.get()
            cwd = os.getcwd()
            # Relative checkpoint path: <log>/<ckpt>.
            files = str(Path(log, ckpt))
            print(files)
            # NOTE: the backslash continuation is inside the string literal,
            # so the leading whitespace of the next line is part of the
            # command — do not re-indent.
            subprocess.check_call("start cmd /k \"cd /d %s && conda activate base \
&& python test.py --model %s\""%(cwd, files), shell=True)
        thread_update = threading.Thread(target=test_task)
        thread_update.start()
|
||||
|
||||
def Machines_Update(self):
|
||||
self.update_log_task()
|
||||
thread_update = threading.Thread(target=self.machines_update)
|
||||
thread_update.start()
|
||||
|
||||
def machines_update(self):
|
||||
self.machine_list = read_config(self.machine_json)
|
||||
ip_list = []
|
||||
for item in self.machine_list:
|
||||
self.machine_dict[item["ip"]] = item
|
||||
ip_list.append(item["ip"])
|
||||
self.list_com["value"] = ip_list
|
||||
self.list_com.current(0)
|
||||
ip = self.list_com.get()
|
||||
cur_mac = self.machine_dict[ip]
|
||||
str_temp= self.__label_text__(cur_mac["user"],cur_mac["path"])
|
||||
self.mac_text.set(str_temp)
|
||||
print("Machine list update success!")
|
||||
|
||||
def connection(self):
|
||||
ip = self.list_com.get()
|
||||
cur_mac = self.machine_dict[ip]
|
||||
ssh_ip = cur_mac["ip"]
|
||||
ssh_username = cur_mac["user"]
|
||||
ssh_passwd = cur_mac["passwd"]
|
||||
ssh_port = int(cur_mac["port"])
|
||||
print(ssh_ip)
|
||||
if ip.lower() == "local" or ip.lower() == "localhost":
|
||||
print("localhost no need to connect!")
|
||||
return [], cur_mac
|
||||
remotemachine = fileUploaderClass(ssh_ip,ssh_username,ssh_passwd,ssh_port)
|
||||
return remotemachine, cur_mac
|
||||
|
||||
def __decode_filestr__(self, filestr):
|
||||
cells = filestr.split("\n")
|
||||
print(cells)
|
||||
|
||||
    def update_log_task(self):
        """Refresh the log-name combobox from <path>/<ckp_path> on the machine.

        For the localhost entry the directory is scanned on local disk;
        otherwise the listing (with mtimes) is fetched over SFTP. The raw
        listing is cached in self.current_log for PullLog.
        """
        remotemachine,mac = self.connection()
        # Remote paths always use forward slashes.
        remote_path = os.path.join(mac["path"],mac["ckp_path"]).replace("\\", "/")
        if remotemachine == []:
            # Local machine: list entries directly; no mtime/path needed.
            files = Path(remote_path).glob("*/")
            first_level = {}
            for one_file in files:
                first_level[one_file.name] = {
                    "t":"",
                    "p":""
                }
        else:
            first_level = remotemachine.sshScpGetNames(remote_path)
        logs = []
        for k,v in first_level.items():
            logs.append(k)
        logs = sorted(logs)
        self.log_com["value"] =logs
        # NOTE(review): current(0) raises TclError when no logs are found.
        self.log_com.current(0)
        self.current_log = first_level
        self.update_ckpt_task()
|
||||
|
||||
def UpdateLog(self):
|
||||
thread_update = threading.Thread(target=self.update_log_task)
|
||||
thread_update.start()
|
||||
|
||||
    def PullLog(self):
        """Download new/updated checkpoint files of the selected log.

        Compares remote mtimes (sshScpGetRNames) against the local copies
        under <ckp_path>/<log>, fetches only missing/stale files, then
        refreshes the checkpoint combobox. Runs on a worker thread.
        """
        def pull_log_task():
            remotemachine,mac = self.connection()
            log = self.log_com.get()
            remote_path = self.current_log[log]["p"]
            if remotemachine == []:
                # Localhost: nothing to pull.
                return
            all_level = remotemachine.sshScpGetRNames(remote_path)
            file_need_download = []
            local_position = []
            local_dir = Path("./",mac["ckp_path"],log)
            if not local_dir.exists():
                # NOTE(review): mkdir() without parents=True assumes that
                # <ckp_path> itself already exists locally — confirm.
                local_dir.mkdir()

            for k,v in all_level.items():
                local_file = Path("./",mac["ckp_path"],log,k)
                if local_file.exists():
                    # Re-download only when the remote copy is newer.
                    if int(local_file.stat().st_mtime) < v["t"]:
                        file_need_download.append(v["p"])
                        local_position.append(str(local_file))
                        # print(int(local_file.stat().st_mtime))
                        # print(v["t"])
                else:
                    file_need_download.append(v["p"])
                    local_position.append(str(local_file))
            if len(file_need_download) > 0 :
                remotemachine.sshScpGetFiles(file_need_download, local_position)

            else:
                print("No file need to pull......")
            self.update_ckpt_task()

        thread_update = threading.Thread(target=pull_log_task)
        thread_update.start()
|
||||
|
||||
def OpenCMD(self):
|
||||
def open_cmd_task():
|
||||
subprocess.call("start cmd", shell=True)
|
||||
thread_update = threading.Thread(target=open_cmd_task)
|
||||
thread_update.start()
|
||||
|
||||
def CWD(self):
|
||||
def open_cmd_task():
|
||||
cwd = os.getcwd()
|
||||
subprocess.call("explorer "+cwd, shell=False)
|
||||
thread_update = threading.Thread(target=open_cmd_task)
|
||||
thread_update.start()
|
||||
|
||||
    def OpenSSH(self):
        """Open a new console running `ssh user@host -p port` for the machine.

        NOTE(review): for the local/localhost entry this only prints a
        message but does NOT return, so it still spawns ssh — confirm
        whether an early return was intended.
        """
        def open_ssh_task():
            ip = self.list_com.get()
            if ip.lower() == "local" or ip.lower() == "localhost":
                print("localhost no need to connect!")
            cur_mac = self.machine_dict[ip]
            ssh_ip = cur_mac["ip"]
            ssh_username = cur_mac["user"]
            # Password cannot be passed to ssh on the command line; the
            # user types it in the spawned console (experiments below).
            ssh_passwd = cur_mac["passwd"]
            ssh_port = cur_mac["port"]
            # subprocess.call("start cmd", shell=True)
            subprocess.call("start cmd /k ssh %s@%s -p %s"%(ssh_username, ssh_ip, ssh_port), shell=True)
            # Earlier attempts at piping the password in, kept for reference:
            # subprocess.call("start echo %s"%(ssh_passwd), shell=True)
            # p = Popen("cp -rf a/* b/", shell=True, stdout=PIPE, stderr=PIPE)
            # proc = subprocess.Popen("ssh %s@%s -p %s"%(ssh_username, ssh_ip, ssh_port),
            #     stdin=subprocess.PIPE, stdout=subprocess.PIPE,
            #     stderr=subprocess.PIPE,creationflags =subprocess.CREATE_NEW_CONSOLE)
            # # out, err = proc.communicate(ssh_passwd.encode("utf-8"))
            # proc.stdin.write(ssh_passwd.encode('utf-8'))
            # print(out.decode('utf-8'))

        thread_update = threading.Thread(target=open_ssh_task)
        thread_update.start()
|
||||
|
||||
def GPUUsage(self):
|
||||
def gpu_usage_task():
|
||||
remotemachine,_ = self.connection()
|
||||
results = remotemachine.sshExec("nvidia-smi")
|
||||
print(results)
|
||||
|
||||
thread_update = threading.Thread(target=gpu_usage_task)
|
||||
thread_update.start()
|
||||
|
||||
    def IgnoreConfig(self):
        """Open guiignore.json in the system editor, seeding a default first.

        Runs on a worker thread; Windows-only (`start`).
        """
        def ignore_config_task():
            if not os.path.exists(self.ignore_json):
                print("guiignore.json file does not exist...")

                if not os.path.exists(self.gui_root):
                    os.makedirs(self.gui_root)
                # Seed with the class-level template.
                write_config(self.ignore_json,self.ignore_text)
            subprocess.call("start %s"%self.ignore_json, shell=True)
        thread_update = threading.Thread(target=ignore_config_task)
        thread_update.start()
|
||||
|
||||
    def EnvConfig(self):
        """Open env/env.json in the system editor, seeding a default first.

        Runs on a worker thread; Windows-only (`start`).
        """
        def env_config_task():
            root_dir = os.getcwd()
            logs_dir = os.path.join(root_dir,"env","env.json")
            if not os.path.exists(logs_dir):
                print("env.json file does not exist...")

                if not os.path.exists(os.path.join(root_dir,"env")):
                    os.makedirs(os.path.join(root_dir,"env"))
                # Seed with the class-level template.
                write_config(logs_dir,self.env_text)
            subprocess.call("start env/env.json", shell=True)

        thread_update = threading.Thread(target=env_config_task)
        thread_update.start()
|
||||
|
||||
def MachineConfig(self):
|
||||
def machine_config_task():
|
||||
subprocess.call("start %s"%self.machine_json, shell=True)
|
||||
thread_update = threading.Thread(target=machine_config_task)
|
||||
thread_update.start()
|
||||
|
||||
    def OpenConfig(self):
        """Open env/logs_position.json (tensorboard log map), seeding a default.

        Each entry maps a log name to the machine and root path it lives on.
        Runs on a worker thread; Windows-only (`start`).
        """
        def open_config_task():
            root_dir = os.getcwd()
            logs_dir = os.path.join(root_dir,"env","logs_position.json")
            if not os.path.exists(logs_dir):
                print("logs configuration file does not exist...")
                # Default template written on first use.
                positions={
                    "template":{
                        "root_path":"./",
                        "machine_name":"localhost",
                    }
                }
                if not os.path.exists(os.path.join(root_dir,"env")):
                    os.makedirs(os.path.join(root_dir,"env"))
                write_config(logs_dir,positions)
            subprocess.call("start env/logs_position.json", shell=True)
            # time.sleep(5)
            # subprocess.call("start http://localhost:6006/", shell=True)
        thread_update = threading.Thread(target=open_config_task)
        thread_update.start()
|
||||
|
||||
def OpenTensorboard(self):
|
||||
thread_update = threading.Thread(target=self.open_tensorboard_task)
|
||||
thread_update.start()
|
||||
|
||||
    def open_tensorboard_task(self):
        """Sync TB logs, start tensorboard on ./train_logs, open the browser.

        Windows-only (`start`); the 5s sleep gives tensorboard time to come
        up before the browser is pointed at it.
        """
        self.download_tblogs()
        root_dir = os.getcwd()
        logs_dir = os.path.join(root_dir,"train_logs")
        subprocess.call("start cmd /k tensorboard --logdir=\"%s\""%(logs_dir), shell=True)
        time.sleep(5)
        subprocess.call("start http://localhost:6006/", shell=True)
|
||||
|
||||
def DownloadTBLogs(self):
|
||||
thread_update = threading.Thread(target=self.download_tblogs)
|
||||
thread_update.start()
|
||||
|
||||
    def download_tblogs(self):
        """Download the tensorboard `summary` folder of each monitored log.

        Log names come from the semicolon-separated entry box; each log is
        resolved to a machine via env/logs_position.json and
        env/machine_config.json, then its summary dir is fetched over SFTP.

        NOTE(review): self.tensorlog_str is only created in the
        commented-out UI section of window_init, so calling this as-is
        raises AttributeError — confirm against the full file.
        """
        tb_monitor_logs = self.tensorlog_str.get()
        tb_monitor_logs = tb_monitor_logs.split(";")
        root_dir = os.getcwd()
        mach_dir = os.path.join(root_dir,"env","machine_config.json")
        machines = read_config(mach_dir)
        logs_dir = os.path.join(root_dir,"env","logs_position.json")
        tb_logs = read_config(logs_dir)

        for i_logs in tb_monitor_logs:
            try:
                mac_name = tb_logs[i_logs]["machine_name"]
                i_mac = machines[mac_name]
                i_mac["log_name"] = i_logs
                # mac_list.append(i_mac)
                # NOTE(review): key "usrname" here vs "user" elsewhere —
                # machine_config.json apparently uses a different schema.
                remotemachine = fileUploaderClass(i_mac["ip"],i_mac["usrname"],i_mac["passwd"],i_mac["port"])
                path_temp = Path(i_mac["root"],"train_logs",i_logs,"summary").as_posix()
                local_dir = Path(root_dir,"train_logs",i_logs,"summary")
                if not Path(local_dir).exists():
                    Path(local_dir).mkdir(parents=True)
                remotemachine.sshScpGetDir(path_temp,local_dir)
                print("%s log files download successful!"%i_logs)
            except Exception as e:
                # Best-effort: one broken log entry must not stop the rest.
                print(e)
|
||||
|
||||
def Synchronize(self):
|
||||
def update():
|
||||
self.update_action()
|
||||
thread_update = threading.Thread(target=update)
|
||||
thread_update.start()
|
||||
|
||||
    def SynchronizeAll(self):
        """Synchronize files to every configured machine, one after another.

        NOTE(review): update_action() is defined without a machine-index
        parameter in the visible code, and tab_info is initialised as a
        list (so tab_info["configs"] would raise TypeError) — this path
        looks broken; confirm against the full file before relying on it.
        """
        def update_all():
            for i_mach in range(len(self.tab_info["configs"])):
                self.update_action(i_mach)
        thread_update = threading.Thread(target=update_all)
        thread_update.start()
|
||||
|
||||
def update_action(self, mach_idx=None):
    """Synchronize changed project files to a remote machine over SSH/SCP.

    Scans the working tree with the white/black lists from
    ``./GUI/guiignore.json``, compares file mtimes against the per-machine
    state log, uploads the changed files, and rewrites the state log only
    after a successful upload.

    Args:
        mach_idx: optional machine index, accepted for compatibility with
            ``SynchronizeAll`` (which previously crashed with TypeError
            because this method took no argument).  The target machine is
            still the one selected in the combo box; per-index selection is
            not implemented yet.  TODO(review): wire ``mach_idx`` up.
    """
    last_state = {}
    changed_files = []

    ip = self.list_com.get()
    cur_mac = self.machine_dict[ip]
    if ip.lower() in ("local", "localhost"):
        print("localhost no need to update!")
        return
    ssh_ip = cur_mac["ip"]
    ssh_username = cur_mac["user"]
    ssh_passwd = cur_mac["passwd"]
    ssh_port = cur_mac["port"]
    root_path = cur_mac["path"]

    log_path = os.path.join(self.filesynlogroot, cur_mac["logfilename"])

    if not Path(self.filesynlogroot).exists():
        Path(self.filesynlogroot).mkdir(parents=True)
    elif Path(log_path).exists():
        # Previous mtime snapshot for this machine.
        with open(log_path, 'r') as cf:
            last_state = json.loads(cf.read())

    all_files = []
    # Scan files: keep white-listed extensions, drop black-listed paths/files.
    file_filter = read_config("./GUI/guiignore.json")
    white_ext = file_filter["white_list"]["extension"]
    black_path = file_filter["black_list"]["path"]
    black_file = file_filter["black_list"]["file"]

    for item in white_ext:
        if item == "":
            print("something error in the white list")
            continue
        all_files.extend(Path('.').rglob('*.%s' % item))
        # Remove anything that lives under a black-listed directory.
        for i_dir in black_path:
            for one_file in Path('.', i_dir).rglob('*.%s' % item):
                all_files.remove(one_file)
    for item in black_file:
        try:
            all_files.remove(Path('.', item))
        except ValueError:
            # list.remove raises ValueError when the entry was never scanned.
            print("%s does not exist!" % item)

    # Keep every file whose mtime differs from (or is missing in) the log.
    # str(item) replaces the private Path._str attribute used previously.
    for item in all_files:
        key = str(item)
        mtime = item.stat().st_mtime
        if key not in last_state or last_state[key] != mtime:
            changed_files.append(key)
        last_state[key] = mtime

    print("[To %s]" % ssh_ip, changed_files)

    localfiles = []
    remotefiles = []
    for item in changed_files:
        localfiles.append(item)
        remotefiles.append(Path(root_path, item).as_posix())

    try:
        remotemachine = fileUploaderClass(ssh_ip, ssh_username, ssh_passwd, ssh_port)
        remotemachine.sshScpPuts(localfiles, remotefiles)
        # Persist the new mtime state only after a successful upload.
        with open(log_path, 'w') as cf:
            cf.writelines(json.dumps(last_state, indent=4))
    except Exception as e:
        print(e)
        print("File Synchronize Failed!")
|
||||
|
||||
# def __save_config__(self):
|
||||
|
||||
# previous_info = read_config(self.log_path)
|
||||
|
||||
# for i in range(len(self.tab_info["names"])):
|
||||
|
||||
# databind = self.tab_info["databind"][i]
|
||||
|
||||
# data_aquire = {
|
||||
# "name": self.tab_info["names"][i],
|
||||
# "remote_ip": databind["remote_ip"].get(),
|
||||
# "remote_user": databind["remote_user"].get(),
|
||||
# "remote_port": databind["remote_port"].get(),
|
||||
# "remote_passwd":databind["remote_passwd"].get(),
|
||||
# "remote_path": databind["remote_path"].get(),
|
||||
# "logfilename": "filestate_%s.json"%self.tab_info["names"][i]
|
||||
# }
|
||||
# if self.tab_info["names"][i] in previous_info["names"]:
|
||||
# location = previous_info["names"].index(self.tab_info["names"][i])
|
||||
# previous_info["configs"][location] = data_aquire
|
||||
|
||||
# else:
|
||||
# previous_info["names"].append(self.tab_info["names"][i])
|
||||
# previous_info["configs"].append(data_aquire)
|
||||
|
||||
# previous_info["databind"] = []
|
||||
# write_config(self.log_path,previous_info)
|
||||
|
||||
def on_closing(self):
    """Window-close handler: destroy the Tk root and end the application."""
    # self.__save_config__()
    self.master.destroy()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: build the GUI application and run the Tk event loop.
    app = Application()
    app.mainloop()
|
||||
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Conditional_Discriminator copy.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 29th June 2021 4:26:33 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torch.nn import utils
|
||||
|
||||
class Discriminator(nn.Module):
    """Multi-scale conditional (projection) discriminator.

    Three strided, spectrally-normalized conv stages; after each stage an
    auxiliary head pools the feature map and emits one real/fake score,
    conditioned on the class label via a projection embedding.

    Args:
        chn: base channel width.
        k_size: conv kernel size for the main blocks.
        n_class: number of condition classes for the embeddings.
    """
    def __init__(self, chn=32, k_size=3, n_class=3):
        super().__init__()
        slop = 0.2            # LeakyReLU negative slope
        enable_bias = True

        # stage 1: downsample to 1/4 resolution
        self.block1 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size= k_size, stride = 2, padding=2,bias= enable_bias)),
            nn.LeakyReLU(slop),
            utils.spectral_norm(nn.Conv2d(in_channels = chn, out_channels = chn * 2 , kernel_size= k_size, stride = 2,padding=2, bias= enable_bias)), # 1/4
            nn.LeakyReLU(slop)
        )
        self.aux_classfier1 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels = chn * 2 , out_channels = chn , kernel_size= 5, bias=enable_bias)),
            nn.LeakyReLU(slop),
            nn.AdaptiveAvgPool2d(1),
        )
        self.embed1 = utils.spectral_norm(nn.Embedding(n_class, chn))
        self.linear1= utils.spectral_norm(nn.Linear(chn, 1))

        # stage 2: downsample to 1/16 resolution
        self.block2 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels = chn * 2 , out_channels = chn * 4 , kernel_size= k_size, stride = 2, padding=2, bias= enable_bias)),# 1/8
            nn.LeakyReLU(slop),
            utils.spectral_norm(nn.Conv2d(in_channels = chn * 4, out_channels = chn * 8 , kernel_size= k_size, stride = 2,padding=2, bias= enable_bias)),# 1/16
            nn.LeakyReLU(slop)
        )
        self.aux_classfier2 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels = chn * 8 , out_channels = chn , kernel_size= 5, bias= enable_bias)),
            nn.LeakyReLU(slop),
            nn.AdaptiveAvgPool2d(1),
        )
        self.embed2 = utils.spectral_norm(nn.Embedding(n_class, chn))
        self.linear2= utils.spectral_norm(nn.Linear(chn, 1))

        # stage 3: downsample to 1/64 resolution
        self.block3 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels = chn * 8 , out_channels = chn * 8 , kernel_size= k_size, stride = 2,padding=3, bias= enable_bias)),# 1/32
            nn.LeakyReLU(slop),
            utils.spectral_norm(nn.Conv2d(in_channels = chn * 8, out_channels = chn * 16 , kernel_size= k_size, stride = 2,padding=3, bias= enable_bias)),# 1/64
            nn.LeakyReLU(slop)
        )
        self.aux_classfier3 = nn.Sequential(
            utils.spectral_norm(nn.Conv2d(in_channels = chn * 16 , out_channels = chn, kernel_size= 5, bias= enable_bias)),
            nn.LeakyReLU(slop),
            nn.AdaptiveAvgPool2d(1),
        )
        self.embed3 = utils.spectral_norm(nn.Embedding(n_class, chn))
        self.linear3= utils.spectral_norm(nn.Linear(chn, 1))
        self.__weights_init__()

    def __weights_init__(self):
        """Xavier-initialize conv/linear weights (zero bias) and embeddings."""
        print("Init weights")
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                # Explicit guard instead of a bare except around zeros_:
                # bias may legitimately be disabled on some layers.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
                else:
                    print("No bias found!")
            if isinstance(m, nn.Embedding):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, input, condition):
        """Return a list of three per-scale logits for (input, condition)."""
        h = self.block1(input)
        prep1 = self.aux_classfier1(h)
        prep1 = prep1.view(prep1.size()[0], -1)
        y1 = self.embed1(condition)
        # Projection trick: inner product of class embedding and features.
        y1 = torch.sum(y1 * prep1, dim=1, keepdim=True)
        prep1 = self.linear1(prep1) + y1

        h = self.block2(h)
        prep2 = self.aux_classfier2(h)
        prep2 = prep2.view(prep2.size()[0], -1)
        y2 = self.embed2(condition)
        y2 = torch.sum(y2 * prep2, dim=1, keepdim=True)
        prep2 = self.linear2(prep2) + y2

        h = self.block3(h)
        prep3 = self.aux_classfier3(h)
        prep3 = prep3.view(prep3.size()[0], -1)
        y3 = self.embed3(condition)
        y3 = torch.sum(y3 * prep3, dim=1, keepdim=True)
        prep3 = self.linear3(prep3) + y3

        out_prep = [prep1,prep2,prep3]
        return out_prep

    def get_outputs_len(self):
        """Number of discriminator outputs (one Linear head per scale)."""
        num = 0
        for m in self.modules():
            if isinstance(m,nn.Linear):
                num+=1
        return num
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: instantiate on GPU and print a per-layer summary.
    # NOTE(review): requires CUDA and the third-party `torchsummary` package.
    wocao = Discriminator().cuda()
    from torchsummary import summary
    summary(wocao, input_size=(3, 512, 512))
|
||||
@@ -0,0 +1,114 @@
|
||||
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Conditional_Generator_tanh.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 6th July 2021 1:16:46 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
from components.ResBlock import ResBlock
|
||||
from components.DeConv import DeConv
|
||||
from components.Conditional_ResBlock_ModulaConv import Conditional_ResBlock
|
||||
|
||||
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator with a
    class-conditional modulated residual block at the bottleneck.

    Attribute names and registration order are part of the checkpoint
    interface (state_dict keys) — do not rename or reorder modules.
    """
    def __init__(
                self,
                chn=32,           # base channel width
                k_size=3,         # conv kernel size
                res_num = 5,      # total residual blocks (1 is conditional)
                class_num = 3,    # number of condition classes
                **kwargs):
        super().__init__()
        padding_size = int((k_size -1)/2)  # NOTE: unused below (padding is hard-coded to 1)
        self.resblock_list = []
        self.n_class = class_num
        # Four stride-2 convs: output feature map is 1/16 of input resolution.
        self.encoder1 = nn.Sequential(
            # nn.InstanceNorm2d(3, affine=True),
            # nn.ReflectionPad2d(padding_size),
            nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size= k_size, stride=1, padding=1, bias= False),
            nn.InstanceNorm2d(chn, affine=True, momentum=0),
            # nn.ReLU(),
            nn.LeakyReLU(),
            # nn.ReflectionPad2d(padding_size),
            nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size= k_size, stride=2, padding=1,bias =False), #
            nn.InstanceNorm2d(chn*2, affine=True, momentum=0),
            # nn.ReLU(),
            nn.LeakyReLU(),
            # nn.ReflectionPad2d(padding_size),
            nn.Conv2d(in_channels = chn*2, out_channels = chn * 4, kernel_size= k_size, stride=2, padding=1,bias =False),
            nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
            # nn.ReLU(),
            nn.LeakyReLU(),
            # nn.ReflectionPad2d(padding_size),
            nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size= k_size, stride=2, padding=1,bias =False),
            nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
            # nn.ReLU(),
            nn.LeakyReLU(),
            # # nn.ReflectionPad2d(padding_size),
            nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size= k_size, stride=2, padding=1,bias =False),
            nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
            # nn.ReLU(),
            nn.LeakyReLU()
        )

        # Bottleneck: (res_num - 1) plain residual blocks plus one
        # class-conditional modulated block applied first in forward().
        res_size = chn * 8
        for _ in range(res_num-1):
            self.resblock_list += [ResBlock(res_size,k_size),]
        self.resblocks = nn.Sequential(*self.resblock_list)
        self.conditional_res = Conditional_ResBlock(res_size, k_size, class_num)
        # Mirror of the encoder: four nearest-neighbour upsampling DeConvs.
        self.decoder1 = nn.Sequential(
            DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size= k_size),
            nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
            # nn.ReLU(),
            nn.LeakyReLU(),
            DeConv(in_channels = chn * 8, out_channels = chn *4, kernel_size= k_size),
            nn.InstanceNorm2d(chn *4, affine=True, momentum=0),
            # nn.ReLU(),
            nn.LeakyReLU(),
            DeConv(in_channels = chn * 4, out_channels = chn * 2 , kernel_size= k_size),
            nn.InstanceNorm2d(chn*2, affine=True, momentum=0),
            # nn.ReLU(),
            nn.LeakyReLU(),
            DeConv(in_channels = chn *2, out_channels = chn, kernel_size= k_size),
            nn.InstanceNorm2d(chn, affine=True, momentum=0),
            # nn.ReLU(),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels = chn, out_channels =3, kernel_size= k_size, stride=1, padding=1,bias =True)
            # nn.Tanh()
        )

        self.__weights_init__()

    def __weights_init__(self):
        # Xavier-init only the encoder convs; decoder DeConvs self-initialize.
        for layer in self.encoder1:
            if isinstance(layer,nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight)

        # for layer in self.encoder2:
        #     if isinstance(layer,nn.Conv2d):
        #         nn.init.xavier_uniform_(layer.weight)

    def forward(self, input, condition=None, get_feature = False):
        """Translate `input` under `condition`; returns (output, encoder feature).

        When ``get_feature`` is True, returns only the encoder feature map.
        NOTE(review): ``condition`` must not be None on the full path — it is
        passed straight into the conditional block's embedding lookup.
        """
        feature = self.encoder1(input)
        if get_feature:
            return feature
        out = self.conditional_res(feature, condition)
        out = self.resblocks(out)
        # n, _,h,w = out.size()
        # attr = condition.view((n, self.n_class, 1, 1)).expand((n, self.n_class, h, w))
        # out = torch.cat([out, attr], dim=1)
        out = self.decoder1(out)
        return out,feature
|
||||
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Conditional_ResBlock_v2.py
|
||||
# Created Date: Tuesday June 29th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 29th June 2021 3:59:44 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
# -*- coding:utf-8 -*-
|
||||
###################################################################
|
||||
### @FilePath: \ASMegaGAN\components\Conditional_ResBlock_v2.py
|
||||
### @Author: Ziang Liu
|
||||
### @Date: 2021-06-28 21:30:17
|
||||
### @LastEditors: Ziang Liu
|
||||
### @LastEditTime: 2021-06-28 21:46:24
|
||||
### @Copyright (C) 2021 SJTU. All rights reserved.
|
||||
###################################################################
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
# from ops.Conditional_BN import Conditional_BN
|
||||
# from components.Adain import Adain
|
||||
|
||||
class Conv2DMod(nn.Module):
    """Modulated 2D convolution (StyleGAN2-style weight (de)modulation).

    Each sample's style vector ``y`` scales the shared kernel per input
    channel; optional demodulation rescales the result to unit norm.
    The per-sample convolution is realized as one grouped conv.

    NOTE(review): ``stride`` and ``dilation`` are stored but never passed to
    ``F.conv2d`` — the convolution always runs with defaults; confirm intent.
    """

    def __init__(self, in_channels, out_channels, kernel, demod=True, stride=1, dilation=1, eps = 1e-8, **kwargs):
        super().__init__()
        self.filters = out_channels
        self.demod = demod
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.weight = nn.Parameter(torch.randn((out_channels, in_channels, kernel, kernel)))
        self.eps = eps

        pad = (kernel - 1) // 2
        self.same_padding = nn.ReplicationPad2d(pad)
        nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x, y):
        batch, _, height, width = x.shape

        # Modulate: broadcast the style over output channels and kernel dims.
        mod = y[:, None, :, None, None] + 1
        w = self.weight[None] * mod

        if self.demod:
            # Demodulate: normalize each output filter to unit L2 norm.
            scale = torch.rsqrt(w.pow(2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            w = w * scale

        # Fold the batch into the channel axis and use grouped convolution so
        # every sample is convolved with its own modulated kernels.
        flat = x.reshape(1, -1, height, width)
        w = w.reshape(batch * self.filters, *w.shape[2:])
        flat = self.same_padding(flat)
        out = F.conv2d(flat, w, groups=batch)
        return out.reshape(-1, self.filters, height, width)
|
||||
|
||||
class Conditional_ResBlock(nn.Module):
    """Residual block whose two convolutions are modulated by per-class
    style vectors looked up from embedding tables."""

    def __init__(self, in_channel, k_size = 3, n_class = 2, stride=1):
        super().__init__()
        # One style embedding per modulated convolution.
        self.embed1 = nn.Embedding(n_class, in_channel)
        self.embed2 = nn.Embedding(n_class, in_channel)
        self.conv1 = Conv2DMod(in_channels = in_channel , out_channels = in_channel, kernel= k_size, stride=stride)
        self.conv2 = Conv2DMod(in_channels = in_channel , out_channels = in_channel, kernel= k_size, stride=stride)

    def forward(self, input, condition):
        identity = input
        out = self.conv1(identity, self.embed1(condition))
        out = self.conv2(out, self.embed2(condition))
        # Residual connection around both modulated convs.
        return out + identity
|
||||
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class DeConv(nn.Module):
    """Nearest-neighbour upsampling followed by a same-padded, bias-free
    convolution — a checkerboard-free alternative to ConvTranspose2d."""

    def __init__(self, in_channels, out_channels, kernel_size = 3, upsampl_scale = 2):
        super().__init__()
        self.upsampling = nn.UpsamplingNearest2d(scale_factor=upsampl_scale)
        pad = (kernel_size - 1) // 2
        # self.same_padding = nn.ReflectionPad2d(padding_size)
        self.conv = nn.Conv2d(in_channels = in_channels ,padding=pad, out_channels = out_channels , kernel_size= kernel_size, bias= False)
        self.__weights_init__()

    def __weights_init__(self):
        # Xavier keeps activation variance stable through the conv.
        nn.init.xavier_uniform_(self.conv.weight)

    def forward(self, input):
        # h = self.same_padding(h)
        return self.conv(self.upsampling(input))
|
||||
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Conditional_Generator_gpt_LN_encoder copy.py
|
||||
# Created Date: Saturday October 9th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 26th October 2021 3:25:47 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
from components.DeConv import DeConv
|
||||
from components.network_swin import SwinTransformerBlock, PatchEmbed, PatchUnEmbed
|
||||
|
||||
class ImageLN(nn.Module):
    """LayerNorm over the channel dimension of an NCHW image tensor.

    Permutes to NHWC so nn.LayerNorm normalizes the trailing (channel)
    axis, then permutes back to NCHW.
    """

    def __init__(self, dim) -> None:
        super().__init__()
        self.layer = nn.LayerNorm(dim)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.layer(channels_last)
        return normed.permute(0, 3, 1, 2)
|
||||
|
||||
class Generator(nn.Module):
    """Conv encoder + Swin-Transformer bottleneck + DeConv decoder generator.

    All hyper-parameters arrive via **kwargs; calling ``Generator()`` with no
    arguments raises KeyError.  Module attribute names/order are part of the
    checkpoint interface — do not rename or reorder.
    """
    def __init__(
                self,
                **kwargs
                ):
        super().__init__()

        # Required configuration keys (KeyError if missing).
        chn = kwargs["g_conv_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        class_num = kwargs["n_class"]            # NOTE: read but unused below
        window_size = kwargs["window_size"]
        image_size = kwargs["image_size"]

        padding_size = int((k_size -1)/2)        # NOTE: unused (padding hard-coded)

        self.resblock_list = []
        embed_dim = 96
        window_size = 8                          # NOTE: overrides the kwargs value
        num_heads = 8
        mlp_ratio = 2.
        norm_layer = nn.LayerNorm
        qk_scale = None
        qkv_bias = True
        self.patch_norm = True
        self.lnnorm = norm_layer(embed_dim)

        # Two stride-2 convs: feature map is 1/4 resolution, embed_dim channels.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False),
            ImageLN(chn),
            nn.ReLU(),
            nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), #
            ImageLN(chn * 2),
            nn.ReLU(),
            nn.Conv2d(in_channels = chn*2, out_channels = embed_dim, kernel_size=k_size, stride=2, padding=1,bias =False),
            ImageLN(embed_dim),
            nn.ReLU(),
        )

        # self.encoder2 = nn.Sequential(

        #     nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
        #     ImageLN(chn * 8),
        #     nn.LeakyReLU(),
        #     nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
        #     ImageLN(chn * 8),
        #     nn.LeakyReLU(),
        #     nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
        #     ImageLN(chn * 8),
        #     nn.LeakyReLU()
        # )

        self.fea_size = (image_size//4, image_size//4)
        # self.conditional_GPT = GPT_Spatial(2, res_dim, res_num, class_num)

        # build blocks: alternating regular / shifted-window attention.
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=embed_dim, input_resolution=self.fea_size,
                            num_heads=num_heads, window_size=window_size,
                            shift_size=0 if (i % 2 == 0) else window_size // 2,
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                            drop=0.0, attn_drop=0.0,
                            drop_path=0.1,
                            norm_layer=norm_layer)
            for i in range(res_num)])

        # Two DeConv upsampling stages restore full resolution, then RGB conv.
        self.decoder = nn.Sequential(
            # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
            # nn.LeakyReLU(),
            # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
            # nn.LeakyReLU(),
            # DeConv(in_channels = chn * 8, out_channels = chn *4, kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
            # nn.LeakyReLU(),
            DeConv(in_channels = embed_dim, out_channels = chn * 2 , kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
            ImageLN(chn * 2),
            nn.ReLU(),
            DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size),
            ImageLN(chn),
            nn.ReLU(),
            nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
        )

        # 1x1 "patches": embed/unembed just flatten/unflatten the feature map.
        self.patch_embed = PatchEmbed(
            img_size=self.fea_size[0], patch_size=1, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        self.patch_unembed = PatchUnEmbed(
            img_size=self.fea_size[0], patch_size=1, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # self.__weights_init__()

    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, input):
        """Encode, run the Swin blocks on the tokenized feature map, decode."""
        x2 = self.encoder(input)
        x2 = self.patch_embed(x2)
        for blk in self.blocks:
            x2 = blk(x2,self.fea_size)
        x2 = self.lnnorm(x2)
        x2 = self.patch_unembed(x2,self.fea_size)
        out = self.decoder(x2)
        return out
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test.  The Generator reads every hyper-parameter from **kwargs,
    # so the previous bare `Generator()` call raised KeyError; supply a
    # minimal configuration instead.
    window_size = 8
    height = 1024
    width = 1024
    model = Generator(
        g_conv_dim=32,
        g_kernel_size=3,
        res_num=4,
        n_class=3,
        window_size=window_size,
        image_size=height,
    )
    print(model)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
|
||||
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Conditional_Generator_gpt_LN_encoder copy.py
|
||||
# Created Date: Saturday October 9th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 11th October 2021 5:22:22 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
from components.ResBlock import ResBlock
|
||||
from components.DeConv import DeConv
|
||||
|
||||
class ImageLN(nn.Module):
    """Channel-wise LayerNorm for NCHW feature maps (normalizes dim 1)."""

    def __init__(self, dim) -> None:
        super().__init__()
        self.layer = nn.LayerNorm(dim)

    def forward(self, x):
        # NCHW -> NHWC, normalize the trailing axis, then restore NCHW.
        nhwc = torch.permute(x, (0, 2, 3, 1))
        return torch.permute(self.layer(nhwc), (0, 3, 1, 2))
|
||||
|
||||
class Generator(nn.Module):
    """Plain conv encoder / DeConv decoder generator (no bottleneck blocks).

    All hyper-parameters arrive via **kwargs; calling ``Generator()`` with
    no arguments raises KeyError.  Several configured values (class_num,
    window_size, the LayerNorm members) are set up but unused in forward().
    Module attribute names/order are checkpoint interface — do not change.
    """
    def __init__(
                self,
                **kwargs
                ):
        super().__init__()

        # Required configuration keys (KeyError if missing).
        chn = kwargs["g_conv_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]              # NOTE: read but unused below
        class_num = kwargs["n_class"]            # NOTE: read but unused below
        window_size = kwargs["window_size"]
        image_size = kwargs["image_size"]        # NOTE: read but unused below

        padding_size = int((k_size -1)/2)        # NOTE: unused (padding hard-coded)

        self.resblock_list = []
        embed_dim = 96
        window_size = 8                          # NOTE: overrides the kwargs value
        num_heads = 8
        mlp_ratio = 2.
        norm_layer = nn.LayerNorm
        qk_scale = None
        qkv_bias = True
        self.patch_norm = True
        self.lnnorm = norm_layer(embed_dim)      # registered but unused in forward

        # Two stride-2 convs: feature map is 1/4 resolution, embed_dim channels.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False),
            nn.InstanceNorm2d(chn),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), #
            nn.InstanceNorm2d(chn * 2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels = chn*2, out_channels = embed_dim, kernel_size=k_size, stride=2, padding=1,bias =False),
            nn.InstanceNorm2d(embed_dim),
            nn.LeakyReLU(),
        )

        # self.encoder2 = nn.Sequential(

        #     nn.Conv2d(in_channels = chn*4 , out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
        #     ImageLN(chn * 8),
        #     nn.LeakyReLU(),
        #     nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
        #     ImageLN(chn * 8),
        #     nn.LeakyReLU(),
        #     nn.Conv2d(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size, stride=2, padding=1,bias =False),
        #     ImageLN(chn * 8),
        #     nn.LeakyReLU()
        # )
        # Two DeConv upsampling stages restore full resolution, then RGB conv.
        self.decoder = nn.Sequential(
            # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
            # nn.LeakyReLU(),
            # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
            # nn.LeakyReLU(),
            # DeConv(in_channels = chn * 8, out_channels = chn *4, kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
            # nn.LeakyReLU(),
            DeConv(in_channels = embed_dim, out_channels = chn * 2 , kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
            nn.InstanceNorm2d(chn * 2),
            nn.LeakyReLU(),
            DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size),
            nn.InstanceNorm2d(chn),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
        )


        # self.__weights_init__()

    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, input):
        """Straight encode/decode pass — no bottleneck, no conditioning."""
        x2 = self.encoder(input)
        out = self.decoder(x2)
        return out
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test.  The Generator reads every hyper-parameter from **kwargs,
    # so the previous bare `Generator()` call raised KeyError; supply a
    # minimal configuration instead.
    window_size = 8
    height = 1024
    width = 1024
    model = Generator(
        g_conv_dim=32,
        g_kernel_size=3,
        res_num=4,
        n_class=3,
        window_size=window_size,
        image_size=height,
    )
    print(model)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
|
||||
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Conditional_Generator_gpt_LN_encoder copy.py
|
||||
# Created Date: Saturday October 9th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 7:35:08 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
from components.ResBlock import ResBlock
|
||||
from components.DeConv import DeConv
|
||||
|
||||
class ImageLN(nn.Module):
    """LayerNorm applied across the channel axis of NCHW tensors."""

    def __init__(self, dim) -> None:
        super().__init__()
        self.layer = nn.LayerNorm(dim)

    def forward(self, x):
        # Move channels last so LayerNorm sees them, then move them back.
        normed = self.layer(x.movedim(1, -1))
        return normed.movedim(-1, 1)
|
||||
|
||||
class Generator(nn.Module):
    """Conv encoder / residual bottleneck / DeConv decoder generator.

    Hyper-parameters arrive via **kwargs (g_conv_dim, g_kernel_size,
    res_num); calling ``Generator()`` with no arguments raises KeyError.
    Module attribute names/order are checkpoint interface — do not change.
    """
    def __init__(
                self,
                **kwargs
                ):
        super().__init__()

        # Required configuration keys (KeyError if missing).
        chn = kwargs["g_conv_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]

        padding_size = int((k_size -1)/2)        # NOTE: unused (padding hard-coded)

        self.resblock_list = []

        # Three stride-2 convs: feature map is 1/8 resolution, chn*4 channels.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels = 3 , out_channels = chn , kernel_size=k_size, stride=1, padding=1, bias= False),
            nn.InstanceNorm2d(chn, affine=True, momentum=0),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels = chn , out_channels = chn*2, kernel_size=k_size, stride=2, padding=1,bias =False), #
            nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels = chn*2, out_channels = chn*4, kernel_size=k_size, stride=2, padding=1,bias =False),
            nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels = chn*4 , out_channels = chn * 4, kernel_size=k_size, stride=2, padding=1,bias =False),
            nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
            nn.LeakyReLU(),
        )
        # res_num plain residual blocks at the bottleneck resolution.
        for _ in range(res_num):
            self.resblock_list += [ResBlock(chn * 4,k_size),]
        self.resblocks = nn.Sequential(*self.resblock_list)
        # Three DeConv upsampling stages restore full resolution, then RGB conv.
        self.decoder = nn.Sequential(
            # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
            # nn.LeakyReLU(),
            # DeConv(in_channels = chn * 8, out_channels = chn * 8, kernel_size=k_size),
            # nn.InstanceNorm2d(chn * 8, affine=True, momentum=0),
            # nn.LeakyReLU(),
            DeConv(in_channels = chn * 4, out_channels = chn *2, kernel_size=k_size),
            nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
            nn.LeakyReLU(),
            DeConv(in_channels = chn * 2, out_channels = chn * 2 , kernel_size=k_size),
            nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
            nn.ReLU(),
            DeConv(in_channels = chn *2, out_channels = chn, kernel_size=k_size),
            nn.InstanceNorm2d(chn, affine=True, momentum=0),
            nn.ReLU(),
            nn.Conv2d(in_channels = chn, out_channels =3, kernel_size=k_size, stride=1, padding=1,bias =True)
        )


        # self.__weights_init__()

    # def __weights_init__(self):
    #     for layer in self.encoder:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    #     for layer in self.encoder2:
    #         if isinstance(layer,nn.Conv2d):
    #             nn.init.xavier_uniform_(layer.weight)

    def forward(self, input):
        """Encode, refine through the residual blocks, decode to RGB."""
        x2 = self.encoder(input)
        x2 = self.resblocks(x2)
        out = self.decoder(x2)
        return out
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test.  The Generator reads its hyper-parameters from **kwargs,
    # so the previous bare `Generator()` call raised KeyError; supply the
    # required keys instead.
    height = 1024
    width = 1024
    model = Generator(
        g_conv_dim=32,
        g_kernel_size=3,
        res_num=4,
    )
    print(model)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
|
||||
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: FastNST_Liif.py
|
||||
# Created Date: Thursday October 14th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 2:39:09 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
from components.ResBlock import ResBlock
|
||||
from components.DeConv import DeConv
|
||||
from components.Liif import LIIF
|
||||
|
||||
class ImageLN(nn.Module):
    """LayerNorm applied across the channel dimension of an NCHW tensor."""

    def __init__(self, dim) -> None:
        super().__init__()
        self.layer = nn.LayerNorm(dim)

    def forward(self, x):
        # NCHW -> NHWC so LayerNorm normalizes channels, then back to NCHW.
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.layer(channels_last)
        return normed.permute(0, 3, 1, 2)
|
||||
|
||||
class Generator(nn.Module):
    """Conv encoder + residual trunk + DeConv decoder, finished by a LIIF
    implicit-function upsampler that maps half-resolution features back to
    a full-resolution RGB image.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        # Required hyper-parameters (KeyError if any is missing).
        chn = kwargs["g_conv_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        class_num = kwargs["n_class"]        # read but currently unused
        window_size = kwargs["window_size"]  # read but overwritten below
        image_size = kwargs["image_size"]
        batch_size = kwargs["batch_size"]
        mlp_hidden_list = kwargs["mlp_hidden_list"]

        padding_size = int((k_size - 1) / 2)  # currently unused

        # Leftover transformer-style settings; only embed_dim is consumed.
        embed_dim = 96
        window_size = 8
        num_heads = 8
        mlp_ratio = 2.
        norm_layer = nn.LayerNorm
        qk_scale = None
        qkv_bias = True
        self.patch_norm = True
        self.lnnorm = norm_layer(embed_dim)

        # 3 -> chn -> 2*chn -> 4*chn, three stride-2 convs (8x downsample).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=chn, kernel_size=k_size, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(chn),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=chn, out_channels=chn * 2, kernel_size=k_size, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(chn * 2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=chn * 2, out_channels=chn * 4, kernel_size=k_size, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(chn * 4),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=chn * 4, out_channels=chn * 4, kernel_size=k_size, stride=2, padding=1, bias=False),
            ImageLN(chn * 4),
            nn.LeakyReLU(),
        )

        self.resblock_list = [ResBlock(chn * 4, k_size) for _ in range(res_num)]
        self.resblocks = nn.Sequential(*self.resblock_list)

        # 4*chn -> 2*chn -> chn with two learned upsampling stages.
        self.decoder = nn.Sequential(
            DeConv(in_channels=chn * 4, out_channels=chn * 2, kernel_size=k_size),
            nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
            nn.LeakyReLU(),
            DeConv(in_channels=chn * 2, out_channels=chn, kernel_size=k_size),
            nn.InstanceNorm2d(chn),
            nn.LeakyReLU()
        )

        # LIIF maps chn-channel half-resolution features to full-resolution RGB.
        self.upsample = LIIF(chn, 3, mlp_hidden_list)
        self.upsample.gen_coord(
            (batch_size, chn, image_size // 2, image_size // 2),
            (image_size, image_size))

    def forward(self, input):
        """input: NCHW image batch -> full-resolution RGB via the LIIF head."""
        feat = self.encoder(input)
        feat = self.resblocks(feat)
        decoded = self.decoder(feat)
        return self.upsample(decoded)
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the generator and push one 1024x1024 RGB image through.
    upscale = 4      # unused here
    window_size = 8  # unused here
    height = 1024
    width = 1024
    # NOTE(review): Generator.__init__ appears to read required keys from
    # **kwargs, so a bare Generator() call likely raises KeyError -- confirm.
    model = Generator()
    print(model)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
|
||||
@@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: FastNST_Liif.py
|
||||
# Created Date: Thursday October 14th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 4:33:51 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
from components.ResBlock import ResBlock
|
||||
from components.DeConv import DeConv
|
||||
from components.Liif_conv import LIIF
|
||||
|
||||
class ImageLN(nn.Module):
    """LayerNorm applied across the channel dimension of an NCHW tensor."""

    def __init__(self, dim) -> None:
        super().__init__()
        self.layer = nn.LayerNorm(dim)

    def forward(self, x):
        # NCHW -> NHWC so LayerNorm normalizes channels, then back to NCHW.
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.layer(channels_last)
        return normed.permute(0, 3, 1, 2)
|
||||
|
||||
class Generator(nn.Module):
    """Conv encoder + residual trunk + single DeConv stage, followed by two
    stacked LIIF (conv variant) upsamplers and an output conv to RGB.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        # Required hyper-parameters (KeyError if any is missing).
        chn = kwargs["g_conv_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        class_num = kwargs["n_class"]        # read but currently unused
        window_size = kwargs["window_size"]  # read but overwritten below
        image_size = kwargs["image_size"]
        batch_size = kwargs["batch_size"]

        padding_size = int((k_size - 1) / 2)  # currently unused

        # Leftover transformer-style settings; only embed_dim is consumed.
        embed_dim = 96
        window_size = 8
        num_heads = 8
        mlp_ratio = 2.
        norm_layer = nn.LayerNorm
        qk_scale = None
        qkv_bias = True
        self.patch_norm = True
        self.lnnorm = norm_layer(embed_dim)

        # 3 -> chn -> 2*chn -> 4*chn, three stride-2 convs (8x downsample).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=chn, kernel_size=k_size, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(chn, affine=True, momentum=0),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=chn, out_channels=chn * 2, kernel_size=k_size, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=chn * 2, out_channels=chn * 4, kernel_size=k_size, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=chn * 4, out_channels=chn * 4, kernel_size=k_size, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(chn * 4, affine=True, momentum=0),
            nn.LeakyReLU(),
        )

        self.resblock_list = [ResBlock(chn * 4, k_size) for _ in range(res_num)]
        self.resblocks = nn.Sequential(*self.resblock_list)

        # Single learned upsampling stage: 4*chn -> 2*chn.
        self.decoder = nn.Sequential(
            DeConv(in_channels=chn * 4, out_channels=chn * 2, kernel_size=k_size),
            nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
            nn.LeakyReLU(),
        )

        # Two chained LIIF stages: quarter -> half, then half -> full resolution.
        self.upsample1 = LIIF(chn * 2, chn)
        self.upsample1.gen_coord(
            (batch_size, chn, image_size // 4, image_size // 4),
            (image_size // 2, image_size // 2))

        self.upsample2 = LIIF(chn, chn)
        self.upsample2.gen_coord(
            (batch_size, chn, image_size // 2, image_size // 2),
            (image_size, image_size))
        self.out_conv = nn.Conv2d(in_channels=chn, out_channels=3, kernel_size=k_size, stride=1, padding=1, bias=True)

    def forward(self, input):
        """input: NCHW image batch -> full-resolution RGB output."""
        feat = self.encoder(input)
        feat = self.resblocks(feat)
        out = self.decoder(feat)
        out = self.upsample1(out)
        out = self.upsample2(out)
        return self.out_conv(out)
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the generator and push one 1024x1024 RGB image through.
    upscale = 4      # unused here
    window_size = 8  # unused here
    height = 1024
    width = 1024
    # NOTE(review): Generator.__init__ appears to read required keys from
    # **kwargs, so a bare Generator() call likely raises KeyError -- confirm.
    model = Generator()
    print(model)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
|
||||
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: FastNST_Liif.py
|
||||
# Created Date: Thursday October 14th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 8:47:28 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
from components.ResBlock import ResBlock
|
||||
from components.DeConv import DeConv
|
||||
from components.Liif_invo import LIIF
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Fully LIIF-based generator (involution variant): a shallow conv
    tokenizer, two LIIF downsampling stages, a residual trunk, two LIIF
    upsampling stages, and a small conv decoder back to RGB.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        # Required hyper-parameters (KeyError if any is missing).
        chn = kwargs["g_conv_dim"]
        k_size = kwargs["g_kernel_size"]
        res_num = kwargs["res_num"]
        class_num = kwargs["n_class"]        # read but currently unused
        window_size = kwargs["window_size"]  # read but currently unused
        image_size = kwargs["image_size"]
        batch_size = kwargs["batch_size"]

        padding_size = int((k_size - 1) / 2)  # currently unused

        # Leftover transformer-style settings; currently unused.
        embed_dim = 96
        norm_layer = nn.LayerNorm

        # Shallow tokenizer: 3 -> chn -> 2*chn with one stride-2 conv.
        self.img_token = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=chn, kernel_size=k_size, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(chn, affine=True, momentum=0),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=chn, out_channels=chn * 2, kernel_size=k_size, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(chn * 2, affine=True, momentum=0),
            nn.LeakyReLU(),
        )

        # Track the working resolution as the stages shrink/grow it.
        image_size = image_size // 2
        self.downsample1 = LIIF(chn * 2, chn * 4)
        self.downsample1.gen_coord(
            (batch_size, chn, image_size, image_size),
            (image_size // 2, image_size // 2))
        image_size = image_size // 2
        self.downsample2 = LIIF(chn * 4, chn * 4)
        self.downsample2.gen_coord(
            (batch_size, chn, image_size, image_size),
            (image_size // 2, image_size // 2))

        self.resblock_list = [ResBlock(chn * 4, k_size) for _ in range(res_num)]
        self.resblocks = nn.Sequential(*self.resblock_list)

        image_size = image_size // 2
        self.upsample1 = LIIF(chn * 4, chn * 4)
        self.upsample1.gen_coord(
            (batch_size, chn, image_size, image_size),
            (image_size * 2, image_size * 2))
        image_size = image_size * 2
        self.upsample2 = LIIF(chn * 4, chn * 2)
        self.upsample2.gen_coord(
            (batch_size, chn, image_size, image_size),
            (image_size * 2, image_size * 2))

        # Final learned upsample + projection to RGB.
        self.decoder = nn.Sequential(
            DeConv(in_channels=chn * 2, out_channels=chn, kernel_size=k_size),
            nn.InstanceNorm2d(chn, affine=True, momentum=0),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=chn, out_channels=3, kernel_size=k_size, stride=1, padding=1, bias=True)
        )

    def forward(self, input):
        """input: NCHW image batch -> full-resolution RGB output."""
        out = self.img_token(input)
        out = self.downsample1(out)
        out = self.downsample2(out)
        out = self.resblocks(out)

        out = self.upsample1(out)
        out = self.upsample2(out)
        return self.decoder(out)
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the generator and push one 1024x1024 RGB image through.
    upscale = 4      # unused here
    window_size = 8  # unused here
    height = 1024
    width = 1024
    # NOTE(review): Generator.__init__ appears to read required keys from
    # **kwargs, so a bare Generator() call likely raises KeyError -- confirm.
    model = Generator()
    print(model)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
|
||||
@@ -0,0 +1,303 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Involution.py
|
||||
# Created Date: Tuesday July 20th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 20th July 2021 10:35:52 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
from torch.nn.modules.utils import _pair
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
|
||||
from collections import namedtuple
|
||||
import cupy
|
||||
from string import Template
|
||||
|
||||
|
||||
Stream = namedtuple('Stream', ['ptr'])
|
||||
|
||||
|
||||
def Dtype(t):
    """Return the C scalar type name used in the kernel templates for tensor *t*.

    Fixed/generalized: the original tested isinstance against
    torch.cuda.FloatTensor / torch.cuda.DoubleTensor only, and silently
    returned None for everything else (including CPU tensors), which later
    produced a broken "None" substitution in the kernel source. Keying on
    t.dtype covers both CPU and CUDA tensors, and unsupported dtypes now
    fail loudly.
    """
    if t.dtype == torch.float32:
        return 'float'
    if t.dtype == torch.float64:
        return 'double'
    raise TypeError("unsupported tensor dtype for involution kernels: %s" % t.dtype)
|
||||
|
||||
|
||||
@cupy._util.memoize(for_each_device=True)
def load_kernel(kernel_name, code, **kwargs):
    """Render *code* with the given template substitutions, compile it, and
    return the named kernel function. Memoized per CUDA device by cupy.
    """
    rendered = Template(code).substitute(**kwargs)
    module = cupy.cuda.compile_with_cache(rendered)
    return module.get_function(kernel_name)
|
||||
|
||||
|
||||
# Threads per block for every involution kernel launch below.
CUDA_NUM_THREADS = 1024

# C preamble shared by all kernels: a grid-stride loop over `n` work items
# (each thread starts at its global index and strides by the total thread
# count).  NOTE(review): in-string indentation was reconstructed; the macro
# text itself is unchanged.
kernel_loop = '''
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
      i < (n); \
      i += blockDim.x * gridDim.x)
'''
|
||||
|
||||
|
||||
def GET_BLOCKS(N):
    """Number of CUDA blocks needed to cover N work items at
    CUDA_NUM_THREADS threads per block (ceiling division).
    """
    full, remainder = divmod(N, CUDA_NUM_THREADS)
    return full + (1 if remainder else 0)
|
||||
|
||||
|
||||
_involution_kernel = kernel_loop + '''
|
||||
extern "C"
|
||||
__global__ void involution_forward_kernel(
|
||||
const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) {
|
||||
CUDA_KERNEL_LOOP(index, ${nthreads}) {
|
||||
const int n = index / ${channels} / ${top_height} / ${top_width};
|
||||
const int c = (index / ${top_height} / ${top_width}) % ${channels};
|
||||
const int h = (index / ${top_width}) % ${top_height};
|
||||
const int w = index % ${top_width};
|
||||
const int g = c / (${channels} / ${groups});
|
||||
${Dtype} value = 0;
|
||||
#pragma unroll
|
||||
for (int kh = 0; kh < ${kernel_h}; ++kh) {
|
||||
#pragma unroll
|
||||
for (int kw = 0; kw < ${kernel_w}; ++kw) {
|
||||
const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
|
||||
const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
|
||||
if ((h_in >= 0) && (h_in < ${bottom_height})
|
||||
&& (w_in >= 0) && (w_in < ${bottom_width})) {
|
||||
const int offset = ((n * ${channels} + c) * ${bottom_height} + h_in)
|
||||
* ${bottom_width} + w_in;
|
||||
const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h)
|
||||
* ${top_width} + w;
|
||||
value += weight_data[offset_weight] * bottom_data[offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
top_data[index] = value;
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
|
||||
_involution_kernel_backward_grad_input = kernel_loop + '''
|
||||
extern "C"
|
||||
__global__ void involution_backward_grad_input_kernel(
|
||||
const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* const bottom_diff) {
|
||||
CUDA_KERNEL_LOOP(index, ${nthreads}) {
|
||||
const int n = index / ${channels} / ${bottom_height} / ${bottom_width};
|
||||
const int c = (index / ${bottom_height} / ${bottom_width}) % ${channels};
|
||||
const int h = (index / ${bottom_width}) % ${bottom_height};
|
||||
const int w = index % ${bottom_width};
|
||||
const int g = c / (${channels} / ${groups});
|
||||
${Dtype} value = 0;
|
||||
#pragma unroll
|
||||
for (int kh = 0; kh < ${kernel_h}; ++kh) {
|
||||
#pragma unroll
|
||||
for (int kw = 0; kw < ${kernel_w}; ++kw) {
|
||||
const int h_out_s = h + ${pad_h} - kh * ${dilation_h};
|
||||
const int w_out_s = w + ${pad_w} - kw * ${dilation_w};
|
||||
if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) {
|
||||
const int h_out = h_out_s / ${stride_h};
|
||||
const int w_out = w_out_s / ${stride_w};
|
||||
if ((h_out >= 0) && (h_out < ${top_height})
|
||||
&& (w_out >= 0) && (w_out < ${top_width})) {
|
||||
const int offset = ((n * ${channels} + c) * ${top_height} + h_out)
|
||||
* ${top_width} + w_out;
|
||||
const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h_out)
|
||||
* ${top_width} + w_out;
|
||||
value += weight_data[offset_weight] * top_diff[offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
bottom_diff[index] = value;
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
|
||||
_involution_kernel_backward_grad_weight = kernel_loop + '''
|
||||
extern "C"
|
||||
__global__ void involution_backward_grad_weight_kernel(
|
||||
const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* const buffer_data) {
|
||||
CUDA_KERNEL_LOOP(index, ${nthreads}) {
|
||||
const int h = (index / ${top_width}) % ${top_height};
|
||||
const int w = index % ${top_width};
|
||||
const int kh = (index / ${kernel_w} / ${top_height} / ${top_width})
|
||||
% ${kernel_h};
|
||||
const int kw = (index / ${top_height} / ${top_width}) % ${kernel_w};
|
||||
const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h};
|
||||
const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w};
|
||||
if ((h_in >= 0) && (h_in < ${bottom_height})
|
||||
&& (w_in >= 0) && (w_in < ${bottom_width})) {
|
||||
const int g = (index / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${groups};
|
||||
const int n = (index / ${groups} / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${num};
|
||||
${Dtype} value = 0;
|
||||
#pragma unroll
|
||||
for (int c = g * (${channels} / ${groups}); c < (g + 1) * (${channels} / ${groups}); ++c) {
|
||||
const int top_offset = ((n * ${channels} + c) * ${top_height} + h)
|
||||
* ${top_width} + w;
|
||||
const int bottom_offset = ((n * ${channels} + c) * ${bottom_height} + h_in)
|
||||
* ${bottom_width} + w_in;
|
||||
value += top_diff[top_offset] * bottom_data[bottom_offset];
|
||||
}
|
||||
buffer_data[index] = value;
|
||||
} else {
|
||||
buffer_data[index] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
|
||||
class _involution(Function):
    """Autograd Function wrapping the raw cupy-compiled involution kernels.

    CUDA-only: forward/backward hand raw data pointers to the templated
    kernels above, so all tensors must live on the GPU.
    """

    @staticmethod
    def forward(ctx, input, weight, stride, padding, dilation):
        # input: 4-D (N, C, H, W); weight: 6-D per-position kernels
        # (N, groups, kernel_h, kernel_w, out_h, out_w).  stride/padding/
        # dilation arrive as 2-tuples (see _involution_cuda).
        assert input.dim() == 4 and input.is_cuda
        assert weight.dim() == 6 and weight.is_cuda
        batch_size, channels, height, width = input.size()
        kernel_h, kernel_w = weight.size()[2:4]
        # Standard convolution output-size arithmetic.
        output_h = int((height + 2 * padding[0] - (dilation[0] * (kernel_h - 1) + 1)) / stride[0] + 1)
        output_w = int((width + 2 * padding[1] - (dilation[1] * (kernel_w - 1) + 1)) / stride[1] + 1)

        # Uninitialized output buffer; every element is written by the kernel.
        output = input.new(batch_size, channels, output_h, output_w)
        n = output.numel()

        with torch.cuda.device_of(input):
            # Substitute all template parameters, compile (memoized), launch.
            f = load_kernel('involution_forward_kernel', _involution_kernel, Dtype=Dtype(input), nthreads=n,
                            num=batch_size, channels=channels, groups=weight.size()[1],
                            bottom_height=height, bottom_width=width,
                            top_height=output_h, top_width=output_w,
                            kernel_h=kernel_h, kernel_w=kernel_w,
                            stride_h=stride[0], stride_w=stride[1],
                            dilation_h=dilation[0], dilation_w=dilation[1],
                            pad_h=padding[0], pad_w=padding[1])
            f(block=(CUDA_NUM_THREADS, 1, 1),
              grid=(GET_BLOCKS(n), 1, 1),
              args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        ctx.save_for_backward(input, weight)
        ctx.stride, ctx.padding, ctx.dilation = stride, padding, dilation
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Raw pointers are passed to the kernels, so the incoming gradient
        # must be CUDA-resident and contiguous.
        assert grad_output.is_cuda and grad_output.is_contiguous()
        input, weight = ctx.saved_tensors
        stride, padding, dilation = ctx.stride, ctx.padding, ctx.dilation

        batch_size, channels, height, width = input.size()
        kernel_h, kernel_w = weight.size()[2:4]
        output_h, output_w = grad_output.size()[2:]

        grad_input, grad_weight = None, None

        # Template substitutions shared by both backward kernels
        # (nthreads is filled in per-launch below).
        opt = dict(Dtype=Dtype(grad_output),
                   num=batch_size, channels=channels, groups=weight.size()[1],
                   bottom_height=height, bottom_width=width,
                   top_height=output_h, top_width=output_w,
                   kernel_h=kernel_h, kernel_w=kernel_w,
                   stride_h=stride[0], stride_w=stride[1],
                   dilation_h=dilation[0], dilation_w=dilation[1],
                   pad_h=padding[0], pad_w=padding[1])

        with torch.cuda.device_of(input):
            if ctx.needs_input_grad[0]:
                grad_input = input.new(input.size())

                n = grad_input.numel()
                opt['nthreads'] = n

                f = load_kernel('involution_backward_grad_input_kernel',
                                _involution_kernel_backward_grad_input, **opt)
                f(block=(CUDA_NUM_THREADS, 1, 1),
                  grid=(GET_BLOCKS(n), 1, 1),
                  args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

            if ctx.needs_input_grad[1]:
                grad_weight = weight.new(weight.size())

                n = grad_weight.numel()
                opt['nthreads'] = n

                f = load_kernel('involution_backward_grad_weight_kernel',
                                _involution_kernel_backward_grad_weight, **opt)
                f(block=(CUDA_NUM_THREADS, 1, 1),
                  grid=(GET_BLOCKS(n), 1, 1),
                  args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        # Trailing Nones correspond to stride, padding, dilation (non-tensors).
        return grad_input, grad_weight, None, None, None
|
||||
|
||||
|
||||
def _involution_cuda(input, weight, bias=None, stride=1, padding=0, dilation=1):
    """ involution kernel
    """
    # The weight's spatial size must already match the strided output size.
    assert input.size(0) == weight.size(0)
    assert input.size(-2) // stride == weight.size(-2)
    assert input.size(-1) // stride == weight.size(-1)
    if not input.is_cuda:
        # CPU fallback was never implemented for this op.
        raise NotImplementedError
    out = _involution.apply(input, weight, _pair(stride), _pair(padding), _pair(dilation))
    if bias is not None:
        out += bias.view(1, -1, 1, 1)
    return out
|
||||
|
||||
|
||||
class involution(nn.Module):
    """Involution layer: per-pixel, per-group kernels generated from the
    input itself and applied by the custom CUDA op.

    Args:
        channels: input (= output) channel count; divided into groups of 8.
        kernel_size: spatial size of each generated kernel.
        stride: output stride; when > 1 the kernel-generating branch first
            average-pools the input so the kernels match the output grid.
    """

    def __init__(self,
                 channels,
                 kernel_size,
                 stride):
        super(involution, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.channels = channels
        reduction_ratio = 4
        self.group_channels = 8
        self.groups = self.channels // self.group_channels
        # Bottleneck branch predicting kernel_size^2 weights per spatial
        # position for each channel group.
        self.seblock = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels // reduction_ratio, kernel_size=1),
            nn.InstanceNorm2d(channels // reduction_ratio, affine=True, momentum=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels // reduction_ratio, out_channels=kernel_size**2 * self.groups, kernel_size=1)
        )
        if stride > 1:
            self.avgpool = nn.AvgPool2d(stride, stride)

    def forward(self, x):
        # Fix: generate kernels at the *output* resolution. The original code
        # never used self.avgpool, so stride > 1 produced a weight map at the
        # input resolution and tripped the spatial asserts in _involution_cuda.
        kernel_input = x if self.stride == 1 else self.avgpool(x)
        weight = self.seblock(kernel_input)
        b, c, h, w = weight.shape
        weight = weight.view(b, self.groups, self.kernel_size, self.kernel_size, h, w)
        out = _involution_cuda(x, weight, stride=self.stride, padding=(self.kernel_size - 1) // 2)
        return out
|
||||
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Liif.py
|
||||
# Created Date: Monday October 18th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 10:27:09 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    Args:
        shape: grid size per dimension, e.g. (H, W).
        ranges: optional per-dimension (v0, v1) value range; defaults to
            [-1, 1] for every dimension.
        flatten: if True, return an (N, ndim) tensor; otherwise keep the
            grid shape with a trailing ndim axis.

    Fixed: removed a leftover debug print that spammed stdout on every call,
    and passed indexing='ij' explicitly to torch.meshgrid (same layout as the
    historical default, without the deprecation warning).
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        # Half cell width: coordinates sit at cell centers, not edges.
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs, indexing='ij'), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
|
||||
|
||||
class MLP(nn.Module):
    """Plain fully-connected network: Linear+ReLU hidden layers, linear head.

    Accepts inputs with any number of leading dimensions; the last dimension
    is treated as the feature dimension.
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        modules = []
        width = in_dim
        for hidden_width in hidden_list:
            modules.extend((nn.Linear(width, hidden_width), nn.ReLU()))
            width = hidden_width
        modules.append(nn.Linear(width, out_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        # Flatten leading dims, run the stack, then restore the leading shape.
        leading_shape = x.shape[:-1]
        out = self.layers(x.view(-1, x.shape[-1]))
        return out.view(*leading_shape, -1)
|
||||
|
||||
class LIIF(nn.Module):
|
||||
|
||||
def __init__(self, mlp_in_dim, mlp_out_dim, mlp_hidden_list):
    """Build the LIIF decoding MLP.

    The MLP consumes each query's 3x3 neighbourhood of features (x9),
    plus an attached relative coordinate (2) and cell size (2).
    """
    super().__init__()
    # 9 unfolded feature vectors + relative coord (2) + cell shape (2).
    imnet_in_dim = mlp_in_dim * 9 + 2 + 2
    self.imnet = MLP(imnet_in_dim, mlp_out_dim, mlp_hidden_list).cuda()
|
||||
|
||||
def gen_coord(self, in_shape, output_size):
|
||||
|
||||
self.vx_lst = [-1, 1]
|
||||
self.vy_lst = [-1, 1]
|
||||
eps_shift = 1e-6
|
||||
self.image_size=output_size
|
||||
|
||||
# field radius (global: [-1, 1])
|
||||
rx = 2 / in_shape[-2] / 2
|
||||
ry = 2 / in_shape[-1] / 2
|
||||
|
||||
coord = make_coord(output_size,flatten=False) \
|
||||
.expand(in_shape[0],output_size[0],output_size[1],2) \
|
||||
.view(in_shape[0],output_size[0]*output_size[1],2)
|
||||
|
||||
cell = torch.ones_like(coord)
|
||||
cell[:, :, 0] *= 2 / coord.shape[-2]
|
||||
cell[:, :, 1] *= 2 / coord.shape[-1]
|
||||
|
||||
feat_coord = make_coord(in_shape[-2:], flatten=False) \
|
||||
.permute(2, 0, 1) \
|
||||
.unsqueeze(0).expand(in_shape[0], 2, *in_shape[-2:])
|
||||
|
||||
areas = []
|
||||
|
||||
self.rel_coord = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
|
||||
self.rel_cell = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
|
||||
self.coord_ = torch.zeros((2,2,in_shape[0],output_size[0]*output_size[1],2))
|
||||
for vx in self.vx_lst:
|
||||
for vy in self.vy_lst:
|
||||
self.coord_[(vx+1)//2,(vy+1)//2,:, :, :] = coord.clone()
|
||||
self.coord_[(vx+1)//2,(vy+1)//2,:, :, 0] += vx * rx + eps_shift
|
||||
self.coord_[(vx+1)//2,(vy+1)//2,:, :, 1] += vy * ry + eps_shift
|
||||
self.coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
|
||||
q_coord = F.grid_sample(
|
||||
feat_coord, self.coord_[(vx+1)//2,(vy+1)//2,:, :, :].flip(-1).unsqueeze(1),
|
||||
mode='nearest', align_corners=False)[:, :, 0, :] \
|
||||
.permute(0, 2, 1)
|
||||
self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, :] = coord - q_coord
|
||||
self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2]
|
||||
self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1]
|
||||
|
||||
self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, :] = cell.clone()
|
||||
self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 0] *= in_shape[-2]
|
||||
self.rel_cell[(vx+1)//2,(vy+1)//2,:, :, 1] *= in_shape[-1]
|
||||
area = torch.abs(self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 0] * self.rel_coord[(vx+1)//2,(vy+1)//2,:, :, 1])
|
||||
areas.append(area + 1e-9)
|
||||
tot_area = torch.stack(areas).sum(dim=0)
|
||||
t = areas[0]; areas[0] = areas[3]; areas[3] = t
|
||||
t = areas[1]; areas[1] = areas[2]; areas[2] = t
|
||||
self.area_weights = []
|
||||
for item in areas:
|
||||
self.area_weights.append((item / tot_area).unsqueeze(-1).cuda())
|
||||
|
||||
self.rel_coord = self.rel_coord.cuda()
|
||||
self.rel_cell = self.rel_cell.cuda()
|
||||
self.coord_ = self.coord_.cuda()
|
||||
|
||||
def forward(self, feat):
|
||||
# B K*K*Cin H W
|
||||
feat = F.unfold(feat, 3, padding=1).view(
|
||||
feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
|
||||
|
||||
preds = []
|
||||
for vx in [0,1]:
|
||||
for vy in [0,1]:
|
||||
q_feat = F.grid_sample(
|
||||
feat, self.coord_[vx,vy,:,:,:].flip(-1).unsqueeze(1),
|
||||
mode='nearest', align_corners=False)[:, :, 0, :] \
|
||||
.permute(0, 2, 1)
|
||||
inp = torch.cat([q_feat, self.rel_coord[vx,vy,:,:,:], self.rel_cell[vx,vy,:,:,:]], dim=-1)
|
||||
|
||||
bs, q = self.coord_[0,0,:,:,:].shape[:2]
|
||||
pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
|
||||
# print("pred shape: ",pred.shape)
|
||||
preds.append(pred)
|
||||
ret = 0
|
||||
for pred, area in zip(preds, self.area_weights):
|
||||
ret = ret + pred * area
|
||||
|
||||
return ret.permute(0, 2, 1).view(-1,3,self.image_size[0],self.image_size[1])
|
||||
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Liif.py
|
||||
# Created Date: Monday October 18th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 4:26:26 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    Args:
        shape: iterable of grid sizes, one entry per dimension.
        ranges: optional per-dimension ``(v0, v1)`` ranges; defaults to
            ``[-1, 1]`` everywhere.
        flatten: if True, return ``(prod(shape), ndim)``; otherwise keep
            the grid layout ``(*shape, ndim)``.

    Returns:
        Float tensor of grid-center coordinates.
    """
    # Fix: dropped the per-dimension debug print that ran on every call.
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        # Half-cell radius: centers are offset r from the range edges.
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
|
||||
|
||||
class MLP(nn.Module):
    """Fully connected stack: Linear+ReLU for every width in
    ``hidden_list``, then a final Linear down to ``out_dim``.

    The last input axis must be ``in_dim``; leading axes are arbitrary.
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        widths = [in_dim] + list(hidden_list)
        body = []
        for prev_w, cur_w in zip(widths, widths[1:]):
            body.append(nn.Linear(prev_w, cur_w))
            body.append(nn.ReLU())
        body.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.Sequential(*body)

    def forward(self, x):
        # Fold leading dims into a single batch axis and unfold afterwards.
        batch_shape = x.shape[:-1]
        flat_out = self.layers(x.view(-1, x.shape[-1]))
        return flat_out.view(*batch_shape, -1)
|
||||
|
||||
class LIIF(nn.Module):
    """Simplified LIIF head: bilinearly resamples a feature map at a target
    pixel grid, then decodes with a small 3x3 conv network.

    ``gen_coord`` must be called once per input/output geometry before
    ``forward``; it caches the sampling grid (on the GPU).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # 3x3 conv decoder applied after resampling.
        self.imnet = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=out_dim, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_dim, affine=True, momentum=0),
            nn.LeakyReLU(),
        )

    def gen_coord(self, in_shape, output_size):
        """Precompute the (B, H_out, W_out, 2) sampling grid in [-1, 1]
        for input shape ``in_shape`` (B, C, H_in, W_in)."""
        self.vx_lst = [-1, 1]
        self.vy_lst = [-1, 1]
        self.image_size = output_size

        grid = make_coord(output_size, flatten=False)
        self.coord = grid.expand(in_shape[0], output_size[0], output_size[1], 2)
        # NOTE(review): moved to GPU unconditionally — assumes CUDA is available.
        self.coord = self.coord.cuda()

    def forward(self, feat):
        """Resample ``feat`` at the cached grid and decode it."""
        resampled = F.grid_sample(feat, self.coord,
                                  mode='bilinear', align_corners=False)
        return self.imnet(resampled)
|
||||
@@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Liif.py
|
||||
# Created Date: Monday October 18th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 8:25:18 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from components.Involution import involution
|
||||
|
||||
|
||||
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    Args:
        shape: iterable of grid sizes, one entry per dimension.
        ranges: optional per-dimension ``(v0, v1)`` ranges; defaults to
            ``[-1, 1]`` everywhere.
        flatten: if True, return ``(prod(shape), ndim)``; otherwise keep
            the grid layout ``(*shape, ndim)``.

    Returns:
        Float tensor of grid-center coordinates.
    """
    # Fix: removed the stray per-dimension debug print.
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        # Half-cell radius: centers are offset r from the range edges.
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
|
||||
|
||||
class MLP(nn.Module):
    """Multi-layer perceptron built from ``hidden_list`` widths.

    Each hidden width contributes a Linear followed by ReLU; a last Linear
    maps to ``out_dim``. Works on inputs with any number of leading axes.
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        stages = []
        fan_in = in_dim
        for fan_out in hidden_list:
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU()]
            fan_in = fan_out
        stages += [nn.Linear(fan_in, out_dim)]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        # Run on a flattened batch, then restore the leading shape.
        prefix = x.shape[:-1]
        y = self.layers(x.view(-1, x.shape[-1]))
        return y.view(*prefix, -1)
|
||||
|
||||
class LIIF(nn.Module):
    """LIIF head variant: 1x1 channel projection, bilinear resampling at the
    target pixel grid, then an involution-based decoder.

    ``gen_coord`` must be called once per input/output geometry before
    ``forward``; it caches the sampling grid (on the GPU).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Project channels before resampling so the decoder sees out_dim.
        self.conv1x1 = nn.Conv2d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        # Involution decoder (kernel 5, stride 1) with norm and activation.
        self.imnet = nn.Sequential(
            involution(out_dim, 5, 1),
            nn.InstanceNorm2d(out_dim, affine=True, momentum=0),
            nn.LeakyReLU(),
        )

    def gen_coord(self, in_shape, output_size):
        """Precompute the (B, H_out, W_out, 2) sampling grid in [-1, 1]
        for input shape ``in_shape`` (B, C, H_in, W_in)."""
        self.vx_lst = [-1, 1]
        self.vy_lst = [-1, 1]
        self.image_size = output_size

        base_grid = make_coord(output_size, flatten=False)
        self.coord = base_grid.expand(in_shape[0], output_size[0], output_size[1], 2)
        # NOTE(review): moved to GPU unconditionally — assumes CUDA is available.
        self.coord = self.coord.cuda()

    def forward(self, feat):
        """Project, resample at the cached grid, and decode."""
        projected = self.conv1x1(feat)
        resampled = F.grid_sample(projected, self.coord,
                                  mode='bilinear', align_corners=False)
        return self.imnet(resampled)
|
||||
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: ResBlock.py
|
||||
# Created Date: Monday July 5th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 5th July 2021 12:18:18 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
from torch import nn
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block: two reflection-padded conv + instance-norm stages
    with an identity skip connection. Spatial size is preserved for odd
    ``k_size`` with ``stride`` 1."""

    def __init__(self, in_channel, k_size = 3, stride=1):
        super().__init__()
        pad = (k_size - 1) // 2

        def stage():
            # One pad -> conv -> norm unit; channel count is unchanged.
            return [
                nn.ReflectionPad2d(pad),
                nn.Conv2d(in_channels=in_channel, out_channels=in_channel,
                          kernel_size=k_size, stride=stride, bias=False),
                nn.InstanceNorm2d(in_channel, affine=True, momentum=0),
            ]

        self.block = nn.Sequential(*stage(), *stage())
        self.__weights_init__()

    def __weights_init__(self):
        # Xavier-initialize every conv in the block.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_uniform_(module.weight)

    def forward(self, input):
        # Identity skip around the two conv stages.
        return input + self.block(input)
|
||||
@@ -0,0 +1,14 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class Transform_block(nn.Module):
    """Stride-1 average-pool smoothing of an image tensor.

    NOTE(review): with the even default ``k_size=10``, padding
    ``(k_size-1)//2 = 4`` shrinks H and W by 1; odd kernel sizes preserve
    the spatial shape — confirm the even default is intended.
    """

    def __init__(self, k_size = 10):
        super().__init__()
        pad = int((k_size - 1) / 2)
        self.pool = nn.AvgPool2d(k_size, stride=1, padding=pad)

    def forward(self, input_image):
        return self.pool(input_image)
|
||||
@@ -0,0 +1,854 @@
|
||||
# -----------------------------------------------------------------------------------
|
||||
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
||||
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
||||
# -----------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    """Transformer feed-forward block: fc1 -> act -> drop -> fc2 -> drop.

    ``hidden_features`` and ``out_features`` default to ``in_features``
    when not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
||||
|
||||
|
||||
def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C); H and W must be divisible by window_size.
        window_size (int): window side length.

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    grid = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Bring the two window-index axes together, then flatten them into batch.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, C)
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
    """
    Inverse of window_partition: reassemble windows into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): window side length.
        H (int): height of the target image.
        W (int): width of the target image.

    Returns:
        x: (B, H, W, C)
    """
    # Recover the batch size from the window count.
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    tiled = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return tiled.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)
        # Fix: removed a stray bare `nn.init` expression statement that was
        # a dead no-op left above the trunc_normal_ call.
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Add the learned relative position bias, gathered per token pair.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window shift mask over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        # NOTE(review): standard Swin uses timm's DropPath (stochastic depth)
        # here; nn.Dropout zeroes individual elements instead of whole
        # residual branches — confirm this substitution is intentional
        # (DropPath is imported at the top of the file but unused).
        self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # The shift mask only exists for shifted (SW-MSA) blocks; it is
        # registered as a buffer so it follows .to(device)/.cuda().
        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        """Build the (nW, ws*ws, ws*ws) additive attention mask that keeps
        shifted-window attention from mixing tokens that were not spatially
        adjacent before the cyclic shift."""
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        # Label each of the 3x3 shift regions with a distinct id.
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # Pairs from different regions get -100 (≈ -inf after softmax).
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        """x: (B, H*W, C) token sequence; x_size: (H, W) of the current
        feature map (may differ from input_resolution at test time)."""
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        else:
            # Resolution changed: rebuild the mask for this size on the fly.
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # Residual connections around attention and MLP.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        """Analytic FLOP count for one block at ``input_resolution``."""
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Downsamples a token grid 2x in each spatial direction by concatenating
    every 2x2 patch along channels (C -> 4C), normalizing, and linearly
    reducing to 2C.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # Gather the four pixels of each 2x2 patch onto the channel axis
        # in the order: top-left, bottom-left, top-right, bottom-right.
        quadrants = [x[:, r::2, c::2, :] for r, c in ((0, 0), (1, 0), (0, 1), (1, 1))]
        merged = torch.cat(quadrants, -1)          # B H/2 W/2 4*C
        merged = merged.view(B, -1, 4 * C)         # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks `depth` SwinTransformerBlocks, alternating between regular
    (shift 0) and shifted windows, optionally followed by a downsample layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Normalize drop_path to one rate per block.
        per_block_drop_path = (
            drop_path if isinstance(drop_path, list) else [drop_path] * depth
        )

        # Even-indexed blocks use regular windows, odd-indexed blocks shifted ones.
        blocks = []
        for i in range(depth):
            blocks.append(SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution,
                num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop, attn_drop=attn_drop,
                drop_path=per_block_drop_path[i],
                norm_layer=norm_layer))
        self.blocks = nn.ModuleList(blocks)

        # Optional patch merging (downsample) layer at the end of the stage.
        self.downsample = None if downsample is None else downsample(
            input_resolution, dim=dim, norm_layer=norm_layer)

    def forward(self, x, x_size):
        for blk in self.blocks:
            # Gradient checkpointing trades compute for memory when enabled.
            x = checkpoint.checkpoint(blk, x, x_size) if self.use_checkpoint else blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        total = sum(blk.flops() for blk in self.blocks)
        if self.downsample is not None:
            total += self.downsample.flops()
        return total
|
||||
|
||||
|
||||
class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    A BasicLayer followed by a convolution, with a residual connection
    around the whole group: x + conv(unembed(layer(x))).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv'):
        super(RSTB, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop, attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # Bottleneck 3x3 -> 1x1 -> 3x3 to save parameters and memory.
            self.conv = nn.Sequential(
                nn.Conv2d(dim, dim // 4, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim, 3, 1, 1))

        # in_chans=0 is kept from the original; PatchEmbed/PatchUnEmbed only
        # store it and never use it for computation here.
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        residual = x
        out = self.residual_group(x, x_size)
        out = self.patch_unembed(out, x_size)  # tokens -> feature map
        out = self.conv(out)
        out = self.patch_embed(out)            # feature map -> tokens
        return out + residual

    def flops(self):
        H, W = self.input_resolution
        total = self.residual_group.flops()
        total += H * W * self.dim * self.dim * 9  # the 3x3 conv
        total += self.patch_embed.flops()
        total += self.patch_unembed.flops()
        return total
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Flattens a (B, C, H, W) feature map into (B, H*W, C) tokens, with an
    optional normalization layer.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # Token-grid resolution and total token count.
        self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        tokens = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            tokens = self.norm(tokens)
        return tokens

    def flops(self):
        H, W = self.img_size
        # Only the optional norm costs anything; the reshape is free.
        return H * W * self.embed_dim if self.norm is not None else 0
|
||||
|
||||
|
||||
class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Reshapes (B, H*W, C) tokens back into a (B, C, H, W) feature map.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # Token-grid resolution and total token count.
        self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        batch = x.shape[0]
        # (B, H*W, C) -> (B, C, H, W) using the spatial size supplied by the caller.
        return x.transpose(1, 2).view(batch, self.embed_dim, x_size[0], x_size[1])

    def flops(self):
        # Pure reshaping; no arithmetic.
        return 0
|
||||
|
||||
|
||||
class Upsample(nn.Sequential):
    """Upsample module.

    Builds a pixel-shuffle upsampling head: n stages of (3x3 conv ->
    PixelShuffle(2)) for scale 2^n, or a single (conv -> PixelShuffle(3))
    stage for scale 3.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        layers = []
        if (scale & (scale - 1)) == 0:
            # scale = 2^n: one conv + shuffle per factor of two
            for _ in range(int(math.log(scale, 2))):
                layers.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                layers.append(nn.PixelShuffle(2))
        elif scale == 3:
            layers.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            layers.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*layers)
|
||||
|
||||
|
||||
class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)

    Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        # Plain attributes may be set before Sequential.__init__ (they are
        # not Modules/Parameters, so nn.Module.__setattr__ passes them through).
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        stages = [
            nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1),
            nn.PixelShuffle(scale),
        ]
        super(UpsampleOneStep, self).__init__(*stages)

    def flops(self):
        # Single 3x3 conv over the input resolution (factor kept from original).
        H, W = self.input_resolution
        return H * W * self.num_feat * 3 * 9
|
||||
|
||||
|
||||
class SwinIR(nn.Module):
    r""" SwinIR
        A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.

    Pipeline: shallow conv feature extraction -> stack of Residual Swin
    Transformer Blocks (RSTB) with a global residual connection -> one of
    four reconstruction heads selected by `upsampler`.

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
    """

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 **kwargs):
        # NOTE: depths/num_heads use mutable list defaults; they are only read
        # here, never mutated, so this is benign.
        super(SwinIR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64  # channel width of the reconstruction head
        self.img_range = img_range
        if in_chans == 3:
            # fixed per-channel RGB mean used to center inputs in forward()
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################### 2, deep feature extraction ######################################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding (optional, off by default)
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: per-block rates increase linearly up to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         # slice of the global decay schedule for this stage
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection
                         )
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory (bottleneck 3x3 -> 1x1 -> 3x3)
            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

        #####################################################################################################
        ################################ 3, high quality image reconstruction ################################
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                            (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            assert self.upscale == 4, 'only support x4 now.'
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal init for Linear weights, zero bias; unit LayerNorm
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names excluded from weight decay by the optimizer setup
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # parameter-name keywords excluded from weight decay
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        """Deep feature extraction: embed -> RSTB stack -> norm -> unembed.

        x: (B, C, H, W) feature map; the output has the same shape.
        """
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        # center and scale the input; inverted at the end of forward()
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x  # global residual
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR: two nearest-neighbour 2x steps with convs in between
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)  # predict the residual over the input

        x = x / self.img_range + self.mean

        return x

    def flops(self):
        flops = 0
        H, W = self.patches_resolution
        flops += H * W * 3 * self.embed_dim * 9  # conv_first (assumes 3 input channels)
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        # NOTE(review): this term is kept from the original; the factor 3 does
        # not obviously correspond to the 9*E*E cost of conv_after_body — confirm.
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        # NOTE(review): self.upsample only exists for the 'pixelshuffle'/
        # 'pixelshuffledirect' heads; flops() raises AttributeError otherwise.
        flops += self.upsample.flops()
        return flops
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build a lightweight-SR SwinIR and run one forward pass.
    upscale = 4
    window_size = 8
    # Pad the (1024/upscale, 720/upscale) low-res size up to a multiple of window_size.
    height = (1024 // upscale // window_size + 1) * window_size
    width = (720 // upscale // window_size + 1) * window_size
    # NOTE(review): the model is built with upscale=2 while `upscale` above is 4;
    # this mismatch is kept from the original — confirm which factor is intended.
    model = SwinIR(upscale=2, img_size=(height, width),
                   window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
                   embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
    print(model)
    print(height, width, model.flops() / 1e9)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
|
||||
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: warp_invo.py
|
||||
# Created Date: Tuesday October 19th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 11:27:13 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
from torch import nn
|
||||
from components.Involution import involution
|
||||
|
||||
|
||||
class DeConv(nn.Module):
    """Upsampling block: 1x1 conv -> nearest-neighbour upsample -> involution.

    Args:
        in_channels / out_channels: channel counts for the 1x1 projection.
        kernel_size: kept for interface compatibility; only used to compute an
            (currently unused) padding size.
        upsampl_scale: nearest-neighbour upsampling factor.
        padding: "reflect" or "zero"; both currently select the same involution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampl_scale=2, padding="reflect"):
        super().__init__()
        self.upsampling = nn.UpsamplingNearest2d(scale_factor=upsampl_scale)
        # NOTE: computed but unused — kept from the original implementation.
        padding_size = int((kernel_size - 1) / 2)
        self.conv1x1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        # Both padding modes currently map to the same involution operator.
        if padding.lower() == "reflect":
            self.conv = involution(out_channels, 5, 1)
        elif padding.lower() == "zero":
            self.conv = involution(out_channels, 5, 1)

    def forward(self, input):
        out = self.conv1x1(input)
        out = self.upsampling(out)
        return self.conv(out)
|
||||
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: StyleResize.py
|
||||
# Created Date: Friday April 17th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 18th April 2020 1:39:53 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
from PIL import Image
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
class StyleResize(object):
    """Rescale a PIL image so its size lies in a workable range.

    - If the larger side exceeds 1800 px, rescale so the smaller side is 1800 px.
    - If the larger side is below 800 px, upscale so the smaller side is 800 px,
      capping the factor at 4x (beyond that the image is forced to 800x800).
    """

    def __call__(self, images):
        # BUG FIX: PIL's Image.size is (width, height). The original unpacked it
        # as (height, width), so F.resize(images, (h, w)) swapped the output
        # dimensions for every non-square image.
        tw, th = images.size
        if max(th, tw) > 1800:
            # NOTE(review): this branch scales the *smaller* side to 1800, so the
            # larger side stays above 1800 — behavior kept from the original;
            # confirm whether 1800. / max(th, tw) was intended.
            alpha = 1800. / float(min(th, tw))
            h = int(th * alpha)
            w = int(tw * alpha)
            images = F.resize(images, (h, w))
        if max(th, tw) < 800:
            # Resize the smallest side of the image to 800px
            alpha = 800. / float(min(th, tw))
            if alpha < 4.:
                h = int(th * alpha)
                w = int(tw * alpha)
                images = F.resize(images, (h, w))
            else:
                # Extreme upscales are clamped to a fixed square.
                images = F.resize(images, (800, 800))
        return images

    def __repr__(self):
        return self.__class__.__name__ + '()'
|
||||
@@ -0,0 +1,269 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: data_loader.py
|
||||
# Created Date: Saturday April 4th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 5th January 2021 2:12:29 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from torch.utils import data
|
||||
import torchvision.datasets as dsets
|
||||
from torchvision import transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
class StyleResize(object):
    """Rescale a PIL image so its size lies in a workable range.

    - If the larger side exceeds 1800 px, rescale so the smaller side is 1800 px.
    - If the larger side is below 800 px, upscale so the smaller side is 800 px,
      capping the factor at 4x (beyond that the image is forced to 800x800).
    """

    def __call__(self, images):
        # BUG FIX: PIL's Image.size is (width, height). The original unpacked it
        # as (height, width), so F.resize(images, (h, w)) swapped the output
        # dimensions for every non-square image.
        tw, th = images.size
        if max(th, tw) > 1800:
            # NOTE(review): this branch scales the *smaller* side to 1800, so the
            # larger side stays above 1800 — behavior kept from the original;
            # confirm whether 1800. / max(th, tw) was intended.
            alpha = 1800. / float(min(th, tw))
            h = int(th * alpha)
            w = int(tw * alpha)
            images = F.resize(images, (h, w))
        if max(th, tw) < 800:
            # Resize the smallest side of the image to 800px
            alpha = 800. / float(min(th, tw))
            if alpha < 4.:
                h = int(th * alpha)
                w = int(tw * alpha)
                images = F.resize(images, (h, w))
            else:
                # Extreme upscales are clamped to a fixed square.
                images = F.resize(images, (800, 800))
        return images

    def __repr__(self):
        return self.__class__.__name__ + '()'
|
||||
|
||||
class DataPrefetcher():
    """Wraps a DataLoader and prefetches the next (content, style, label)
    batch onto the GPU using a side CUDA stream, so host-to-device copies
    overlap with compute on the default stream. Requires CUDA."""

    def __init__(self, loader):
        self.loader = loader
        self.dataiter = iter(loader)
        # dedicated stream for the asynchronous .cuda(non_blocking=True) copies
        self.stream = torch.cuda.Stream()
        # self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        # self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.__preload__()

    def __preload__(self):
        """Fetch the next batch (restarting the iterator at epoch end) and
        start its device transfer on the side stream."""
        try:
            self.content, self.style, self.label = next(self.dataiter)
        except StopIteration:
            # loader exhausted: silently start a new epoch
            self.dataiter = iter(self.loader)
            self.content, self.style, self.label = next(self.dataiter)

        with torch.cuda.stream(self.stream):
            # non_blocking copies are queued on self.stream; next() synchronizes
            self.content = self.content.cuda(non_blocking=True)
            self.style = self.style.cuda(non_blocking=True)
            self.label = self.label.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            #     self.next_input = self.next_input.float()
            #     self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the prefetched batch and immediately begin prefetching the next one."""
        # make the default stream wait until the async copies have finished
        torch.cuda.current_stream().wait_stream(self.stream)
        content = self.content
        style = self.style
        label = self.label
        self.__preload__()
        return content, style, label

    def __len__(self):
        """Return the number of batches in the wrapped loader."""
        return len(self.loader)
|
||||
|
||||
class TotalDataset(data.Dataset):
    """Dataset class for the Artworks dataset and content dataset.

    Each item pairs one content image (by index) with one *randomly chosen*
    style image and that style's integer class label.
    """

    def __init__(self, content_image_dir, style_image_dir,
                 selectedContent, selectedStyle,
                 content_transform, style_transform,
                 subffix='jpg', random_seed=1234):
        """Initialize and preprocess the content/style dataset.

        Args:
            content_image_dir: root directory of content images.
            style_image_dir: root directory of artwork (style) images.
            selectedContent: sub-directories of content_image_dir to scan.
            selectedStyle: style class sub-directories to scan; their order
                defines the integer labels.
            content_transform / style_transform: transforms applied per image.
            subffix: image file extension to scan for ("subffix" name kept
                from the original; presumably "suffix").
            random_seed: seed for the deterministic shuffle of both file lists.
        """
        self.content_image_dir = content_image_dir
        self.style_image_dir = style_image_dir
        self.content_transform = content_transform
        self.style_transform = style_transform
        self.selectedContent = selectedContent
        self.selectedStyle = selectedStyle
        self.subffix = subffix
        self.content_dataset = []  # list of content image paths
        self.art_dataset = []      # list of [style image path, class label]
        self.random_seed = random_seed
        self.__preprocess__()
        self.num_images = len(self.content_dataset)
        self.art_num = len(self.art_dataset)

    def __preprocess__(self):
        """Scan the directories, build both file lists, then shuffle them
        deterministically with self.random_seed."""
        print("processing content images...")
        for dir_item in self.selectedContent:
            join_path = Path(self.content_image_dir, dir_item)
            if join_path.exists():
                print("processing %s" % dir_item)
                images = join_path.glob('*.%s' % (self.subffix))
                for item in images:
                    self.content_dataset.append(item)
            else:
                # missing directories are skipped with a warning, not an error
                print("%s dir does not exist!" % dir_item)
        label_index = 0
        print("processing style images...")
        for class_item in self.selectedStyle:
            images = Path(self.style_image_dir).glob('%s/*.%s' % (class_item, self.subffix))
            for item in images:
                self.art_dataset.append([item, label_index])
            # one integer label per style class, in selectedStyle order
            label_index += 1
        random.seed(self.random_seed)
        random.shuffle(self.content_dataset)
        random.shuffle(self.art_dataset)
        print('Finished preprocessing the Art Works dataset, total image number: %d...' % len(self.art_dataset))
        print('Finished preprocessing the Content dataset, total image number: %d...' % len(self.content_dataset))

    def __getitem__(self, index):
        """Return (content, style, label) for the indexed content image plus a
        uniformly random style image (independent of `index`)."""
        filename = self.content_dataset[index]
        image = Image.open(filename)
        content = self.content_transform(image)
        art_index = random.randint(0, self.art_num - 1)
        filename, label = self.art_dataset[art_index]
        image = Image.open(filename)
        style = self.style_transform(image)
        return content, style, label

    def __len__(self):
        """Return the number of content images."""
        return self.num_images
|
||||
|
||||
def denorm(x):
    """Map a tensor from the [-1, 1] training range back to [0, 1]."""
    return ((x + 1) / 2).clamp_(0, 1)
|
||||
|
||||
def GetLoader(s_image_dir, c_image_dir,
              style_selected_dir, content_selected_dir,
              crop_size=178, batch_size=16, num_workers=8,
              colorJitterEnable=True, colorConfig={"brightness":0.05,"contrast":0.05,"saturation":0.05,"hue":0.05}):
    """Build and return a prefetching data loader over content/style images.

    Style images get reflect-padded random crops; content images plain crops.
    Passing colorConfig=None disables color jitter even when
    colorJitterEnable is True. (The dict default is never mutated.)
    """
    s_transforms = [
        T.Resize(768),
        T.RandomCrop(crop_size, pad_if_needed=True, padding_mode='reflect'),
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
    ]
    c_transforms = [
        T.Resize(768),
        T.RandomCrop(crop_size),
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
    ]

    if colorJitterEnable:
        if colorConfig is not None:
            print("Enable color jitter!")
            brightness = colorConfig["brightness"]
            contrast = colorConfig["contrast"]
            saturation = colorConfig["saturation"]
            # hue jitter is symmetric around zero
            hue = (-colorConfig["hue"], colorConfig["hue"])
            s_transforms.append(T.ColorJitter(brightness=brightness, contrast=contrast,
                                              saturation=saturation, hue=hue))
            c_transforms.append(T.ColorJitter(brightness=brightness, contrast=contrast,
                                              saturation=saturation, hue=hue))

    for seq in (s_transforms, c_transforms):
        seq.append(T.ToTensor())
        seq.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

    content_dataset = TotalDataset(c_image_dir, s_image_dir,
                                   content_selected_dir, style_selected_dir,
                                   T.Compose(c_transforms), T.Compose(s_transforms))
    content_data_loader = data.DataLoader(dataset=content_dataset, batch_size=batch_size,
                                          drop_last=True, shuffle=True,
                                          num_workers=num_workers, pin_memory=True)
    return DataPrefetcher(content_data_loader)
|
||||
|
||||
def GetValiDataTensors(
    image_dir=None,
    selected_imgs=None,
    crop_size=178,
    mean=(0.5, 0.5, 0.5),
    std=(0.5, 0.5, 0.5)
):
    """Load validation images and return them as a list of 1xCxHxW CUDA tensors.

    Args:
        image_dir (str | None): directory containing the images. If None,
            ``selected_imgs`` must hold full image paths.
        selected_imgs (list[str] | None): image names relative to
            ``image_dir`` (or full paths when ``image_dir`` is None). When
            empty or None, every ``*.jpg`` and ``*.png`` file directly under
            ``image_dir`` is used instead.
        crop_size (int): side length of the random crop.
        mean (tuple): per-channel normalization mean.
        std (tuple): per-channel normalization std.

    Returns:
        list[torch.Tensor]: one normalized [1, C, H, W] tensor per image,
        already moved to the GPU.
    """
    transforms = T.Compose([
        T.Resize(768),
        T.RandomCrop(crop_size, pad_if_needed=True, padding_mode='reflect'),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])

    print("Start to read validation data......")
    if selected_imgs:
        # Explicit selection: names are relative to image_dir unless no
        # directory was given.
        if image_dir is None:
            img_paths = list(selected_imgs)
        else:
            img_paths = [os.path.join(image_dir, s_img) for s_img in selected_imgs]
    else:
        # No explicit selection: take every jpg/png in the directory.
        # BUG FIX: glob already returns paths prefixed with image_dir, so
        # they must not be joined with image_dir a second time (the original
        # produced doubled paths such as "dir/dir/x.jpg" for relative dirs).
        img_paths = glob.glob(os.path.join(image_dir, '*.jpg'))
        img_paths += glob.glob(os.path.join(image_dir, '*.png'))

    result_img = []
    for img_path in img_paths:
        temp_img = Image.open(img_path)
        result_img.append(transforms(temp_img).cuda().unsqueeze(0))
    print("Finish to read validation data......")
    print("Total validation images: %d" % len(result_img))
    return result_img
|
||||
|
||||
def ScanAbnormalImg(image_dir, selected_imgs, subffix="jpg"):
    """Scan the dataset and delete every image that is not in RGB mode.

    Args:
        image_dir (str): root directory of the dataset.
        selected_imgs (list[str]): sub-directories of ``image_dir`` to scan.
        subffix (str): image file extension to look for (generalized from the
            previously hard-coded "jpg"). Default: "jpg".

    Side effects:
        Permanently removes (``os.remove``) each matched image whose PIL mode
        is not "RGB" (e.g. grayscale "L", palette "P", or "RGBA").
    """
    print("processing images...")
    for dir_item in selected_imgs:
        join_path = Path(image_dir, dir_item)
        if not join_path.exists():
            print("%s dir does not exist!" % dir_item)
            continue
        print("processing %s" % dir_item)
        for item in join_path.glob('*.%s' % (subffix)):
            temp = Image.open(item)
            # Non-RGB images break 3-channel training pipelines, so drop them.
            if temp.mode != "RGB":
                print(temp.mode)
                print("Found one abnormal image!")
                print(item)
                os.remove(str(item))
|
||||
@@ -0,0 +1,253 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: data_loader_modify.py
|
||||
# Created Date: Saturday April 4th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 4th July 2021 11:12:42 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from torch.utils import data
|
||||
import torchvision.datasets as dsets
|
||||
from torchvision import transforms as T
|
||||
from data_tools.StyleResize import StyleResize
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
class data_prefetcher():
    """Wrap a DataLoader and prefetch the next (content, style, label) batch
    onto the GPU on a side CUDA stream while the current batch is in use.

    The iterator restarts automatically when the underlying loader is
    exhausted, so ``next()`` can be called indefinitely.
    """

    def __init__(self, loader):
        self.loader = loader
        self.dataiter = iter(loader)
        self.stream = torch.cuda.Stream()
        # CONSISTENCY FIX: the sibling prefetcher in this repo exposes the
        # batch count through __len__; mirror that here.
        self.num_images = len(loader)
        self.preload()

    def preload(self):
        """Fetch the next batch and start its async host->device copy."""
        try:
            self.content, self.style, self.label = next(self.dataiter)
        except StopIteration:
            # End of epoch: restart the loader transparently.
            self.dataiter = iter(self.loader)
            self.content, self.style, self.label = next(self.dataiter)

        with torch.cuda.stream(self.stream):
            # Asynchronous copies; next() synchronizes before handing them out.
            self.content = self.content.cuda(non_blocking=True)
            self.style = self.style.cuda(non_blocking=True)
            self.label = self.label.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched batch and immediately start loading the next."""
        # Ensure the async copies issued in preload() have finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        content = self.content
        style = self.style
        label = self.label
        self.preload()
        return content, style, label

    def __len__(self):
        """Number of batches in one epoch of the wrapped loader."""
        return self.num_images
|
||||
|
||||
class TotalDataset(data.Dataset):
    """Dataset pairing content images with the Artworks (style) dataset.

    Each item is (content, style, label): one indexed content image plus a
    randomly drawn artwork image and that artwork's style-class label.
    """

    def __init__(self, content_image_dir,style_image_dir,
                selectedContent,selectedStyle,
                content_transform,style_transform,
                subffix='jpg', random_seed=1234):
        """Initialize and preprocess the content and artwork datasets.

        Args:
            content_image_dir (str): root directory of the content images.
            style_image_dir (str): root directory of the artwork images.
            selectedContent (list[str]): content sub-directories to index.
            selectedStyle (list[str]): style-class sub-directories to index;
                their order defines the integer labels.
            content_transform: transform applied to each content image.
            style_transform: transform applied to each style image.
            subffix (str): image extension to glob for. Default: 'jpg'.
            random_seed (int): seed used to shuffle both file lists.
        """
        self.content_image_dir = content_image_dir
        self.style_image_dir = style_image_dir
        self.content_transform = content_transform
        self.style_transform = style_transform
        self.selectedContent = selectedContent
        self.selectedStyle = selectedStyle
        self.subffix = subffix
        self.content_dataset = []
        self.art_dataset = []
        self.random_seed = random_seed
        self.preprocess()
        # Dataset length is driven by the content images; a style image is
        # sampled randomly per item in __getitem__.
        self.num_images = len(self.content_dataset)
        self.art_num = len(self.art_dataset)

    def preprocess(self):
        """Index content and artwork image paths, then shuffle both lists."""
        print("processing content images...")
        for dir_item in self.selectedContent:
            join_path = Path(self.content_image_dir,dir_item)
            if join_path.exists():
                print("processing %s"%dir_item,end='\r')
                images = join_path.glob('*.%s'%(self.subffix))
                for item in images:
                    self.content_dataset.append(item)
            else:
                print("%s dir does not exist!"%dir_item,end='\r')
        # Each style class receives an integer label in selectedStyle order.
        label_index = 0
        print("processing style images...")
        for class_item in self.selectedStyle:
            images = Path(self.style_image_dir).glob('%s/*.%s'%(class_item, self.subffix))
            for item in images:
                self.art_dataset.append([item, label_index])
            label_index += 1
        # Deterministic shuffle: same seed -> same ordering across runs.
        random.seed(self.random_seed)
        random.shuffle(self.content_dataset)
        random.shuffle(self.art_dataset)
        print('Finished preprocessing the Art Works dataset, total image number: %d...'%len(self.art_dataset))
        print('Finished preprocessing the Content dataset, total image number: %d...'%len(self.content_dataset))

    def __getitem__(self, index):
        """Return (content, style, label); the style image is drawn at random."""
        filename = self.content_dataset[index]
        image = Image.open(filename)
        content = self.content_transform(image)
        # NOTE(review): uses the global `random` state, so the content/style
        # pairing is not reproducible per index unless the caller controls
        # the seed between accesses.
        art_index = random.randint(0,self.art_num-1)
        filename,label = self.art_dataset[art_index]
        image = Image.open(filename)
        style = self.style_transform(image)
        return content,style,label

    def __len__(self):
        """Return the number of content images."""
        return self.num_images
|
||||
|
||||
def GetLoader( dataset_roots,
            batch_size=16,
            crop_size=512,
            **kwargs
        ):
    """Build and return a GPU-prefetching loader over content/style pairs.

    Args:
        dataset_roots (dict): must contain "Place365_big" (content root) and
            "WikiArt" (style root).
        batch_size (int): batch size of the underlying DataLoader.
        crop_size (int): random-crop side length for both branches.
        **kwargs: required keys: "color_jitter", "color_config",
            "dataloader_workers", "selected_content_dir",
            "selected_style_dir", "random_seed".

    Returns:
        data_prefetcher: prefetching wrapper around the DataLoader.

    Raises:
        ValueError: if no keyword configuration is supplied.
    """
    if not kwargs:
        # BUG FIX: the original raised ValueError(print(msg)), i.e.
        # ValueError(None), losing the message entirely.
        raise ValueError("Input params error!")

    colorJitterEnable = kwargs["color_jitter"]
    colorConfig = kwargs["color_config"]
    # (was read twice in the original; once is enough)
    num_workers = kwargs["dataloader_workers"]
    place365_root = dataset_roots["Place365_big"]
    wikiart_root = dataset_roots["WikiArt"]
    selected_c_dir = kwargs["selected_content_dir"]
    selected_s_dir = kwargs["selected_style_dir"]
    random_seed = kwargs["random_seed"]

    # Style branch keeps aspect ratio via StyleResize; content is resized
    # to 900 on the short side before cropping.
    s_transforms = [StyleResize()]
    c_transforms = [T.Resize(900)]

    s_transforms.append(T.RandomCrop(crop_size, pad_if_needed=True, padding_mode='reflect'))
    c_transforms.append(T.RandomCrop(crop_size))

    s_transforms.append(T.RandomHorizontalFlip())
    c_transforms.append(T.RandomHorizontalFlip())

    s_transforms.append(T.RandomVerticalFlip())
    c_transforms.append(T.RandomVerticalFlip())

    if colorJitterEnable:
        if colorConfig is not None:
            print("Enable color jitter!")
            colorBrightness = colorConfig["brightness"]
            colorContrast = colorConfig["contrast"]
            colorSaturation = colorConfig["saturation"]
            # Hue jitter is symmetric around zero.
            colorHue = (-colorConfig["hue"], colorConfig["hue"])
            s_transforms.append(T.ColorJitter(brightness=colorBrightness,
                    contrast=colorContrast, saturation=colorSaturation, hue=colorHue))
            c_transforms.append(T.ColorJitter(brightness=colorBrightness,
                    contrast=colorContrast, saturation=colorSaturation, hue=colorHue))
    s_transforms.append(T.ToTensor())
    c_transforms.append(T.ToTensor())

    s_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

    s_transforms = T.Compose(s_transforms)
    c_transforms = T.Compose(c_transforms)

    content_dataset = TotalDataset(place365_root, wikiart_root,
                        selected_c_dir, selected_s_dir,
                        c_transforms, s_transforms, "jpg", random_seed)
    content_data_loader = data.DataLoader(dataset=content_dataset, batch_size=batch_size,
                        drop_last=True, shuffle=True, num_workers=num_workers, pin_memory=True)
    prefetcher = data_prefetcher(content_data_loader)
    return prefetcher
|
||||
|
||||
def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1], clamping in place."""
    rescaled = x.add(1).div(2)
    return rescaled.clamp_(0, 1)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test of the loader pipeline.
    # NOTE(review): this calls `getLoader`, which is not defined in this
    # module (the function above is `GetLoader`, and its signature does not
    # accept these positional arguments), so this demo raises NameError as
    # written — confirm the intended entry point before relying on it.
    from torchvision.utils import save_image
    style_class = ["vangogh","picasso","samuel"]
    # Place365 category sub-directories used as the content selection.
    categories_names = \
        ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
         'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
         'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
         'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
         'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
         'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
         'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
         'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
         'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
         'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
         'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
         'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
         'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
         'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
         'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
         'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
         'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
         'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
         's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
         'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
         'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
         'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
         'b/building_facade', 'c/cemetery']

    # NOTE(review): machine-specific Windows paths — adjust before running.
    s_datapath = "D:\\F_Disk\\data_set\\Art_Data\\data_art_backup"
    c_datapath = "D:\\Downloads\\data_large"
    savepath = "D:\\PatchFace\\PleaseWork\\multi-style-gan\\StyleTransfer\\dataloader_test"

    imsize = 512
    s_datasetloader= getLoader(s_datapath,c_datapath,
                    style_class, categories_names,
                    crop_size=imsize, batch_size=16, num_workers=4)
    wocao = iter(s_datasetloader)
    # Pull 500 batches just to exercise the pipeline.
    for i in range(500):
        print("new batch")
        s_image,c_image,label = next(wocao)
        print(label)
        # print(label)
        # saved_image1 = torch.cat([denorm(image.data),denorm(hahh.data)],3)
        # save_image(denorm(image), "%s\\%d-label-%d.jpg"%(savepath,i), nrow=1, padding=1)
        pass
    # import cv2
    # import os
    # for dir_item in categories_names:
    #     join_path = Path(contentdatapath,dir_item)
    #     if join_path.exists():
    #         print("processing %s"%dir_item,end='\r')
    #         images = join_path.glob('*.%s'%("jpg"))
    #         for item in images:
    #             temp_path = str(item)
    #             # temp = cv2.imread(temp_path)
    #             temp = Image.open(temp_path)
    #             if temp.layers<3:
    #                 print("remove broken image...")
    #                 print("image name:%s"%temp_path)
    #                 del temp
    #                 os.remove(item)
|
||||
@@ -0,0 +1,223 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: data_loader_modify.py
|
||||
# Created Date: Saturday April 4th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 11th October 2021 12:17:58 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from torch.utils import data
|
||||
import torchvision.datasets as dsets
|
||||
from torchvision import transforms as T
|
||||
from data_tools.StyleResize import StyleResize
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
class data_prefetcher():
    """Prefetch content-only batches onto the GPU using a dedicated CUDA stream.

    While the training step consumes the current batch, the next one is
    already being copied host->device asynchronously on ``self.stream``.
    The iterator restarts transparently at the end of each epoch.
    """

    def __init__(self, loader):
        # loader: a torch DataLoader yielding one content tensor per batch.
        self.loader = loader
        self.dataiter = iter(loader)
        self.stream = torch.cuda.Stream()
        # Number of batches per epoch, exposed through __len__.
        self.num_images = len(loader)
        self.preload()

    def preload(self):
        """Fetch the next batch and start its async host->device copy."""
        try:
            self.content = next(self.dataiter)
        except StopIteration:
            # End of epoch: restart the underlying loader.
            self.dataiter = iter(self.loader)
            self.content = next(self.dataiter)

        with torch.cuda.stream(self.stream):
            # Asynchronous copy; next() synchronizes before handing it out.
            self.content= self.content.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched batch and immediately start loading the next."""
        # Wait until the async copy issued by preload() is complete.
        torch.cuda.current_stream().wait_stream(self.stream)
        content = self.content
        self.preload()
        return content

    def __len__(self):
        """Return the number of batches."""
        return self.num_images
|
||||
|
||||
class Place365Dataset(data.Dataset):
    """Content-image dataset built from selected Place365 category folders."""

    def __init__(self,
            content_image_dir,
            selectedContent,
            content_transform,
            subffix='jpg',
            random_seed=1234):
        """Collect image paths from the chosen sub-directories and shuffle them.

        Args:
            content_image_dir (str): root directory of the content images.
            selectedContent (list[str]): sub-directories to index.
            content_transform: transform applied to each loaded image.
            subffix (str): image extension to glob for. Default: 'jpg'.
            random_seed (int): seed for the deterministic shuffle.
        """
        self.content_image_dir = content_image_dir
        self.content_transform = content_transform
        self.selectedContent = selectedContent
        self.subffix = subffix
        self.content_dataset = []
        self.random_seed = random_seed
        self.preprocess()
        self.num_images = len(self.content_dataset)

    def preprocess(self):
        """Index every matching image under each selected sub-directory."""
        print("processing content images...")
        pattern = '*.%s' % (self.subffix)
        for dir_item in self.selectedContent:
            join_path = Path(self.content_image_dir, dir_item)
            if not join_path.exists():
                print("%s dir does not exist!" % dir_item, end='\r')
                continue
            print("processing %s" % dir_item, end='\r')
            self.content_dataset.extend(join_path.glob(pattern))
        # Deterministic shuffle: same seed -> same ordering across runs.
        random.seed(self.random_seed)
        random.shuffle(self.content_dataset)
        print('Finished preprocessing the Content dataset, total image number: %d...' % len(self.content_dataset))

    def __getitem__(self, index):
        """Return the transformed content image at ``index``."""
        img_path = self.content_dataset[index]
        return self.content_transform(Image.open(img_path))

    def __len__(self):
        """Return the number of images."""
        return self.num_images
|
||||
|
||||
def GetLoader( dataset_roots,
            batch_size=16,
            crop_size=512,
            **kwargs
        ):
    """Build and return a GPU-prefetching loader over Place365 content images.

    Args:
        dataset_roots (dict): must contain "Place365_big" (content root).
        batch_size (int): batch size of the underlying DataLoader.
        crop_size (int): random-crop side length.
        **kwargs: required keys: "color_jitter", "color_config",
            "dataloader_workers", "selected_content_dir", "random_seed".

    Returns:
        data_prefetcher: prefetching wrapper around the DataLoader.

    Raises:
        ValueError: if no keyword configuration is supplied.
    """
    if not kwargs:
        # BUG FIX: the original raised ValueError(print(msg)), i.e.
        # ValueError(None), losing the message entirely.
        raise ValueError("Input params error!")

    colorJitterEnable = kwargs["color_jitter"]
    colorConfig = kwargs["color_config"]
    # (was read twice in the original; once is enough)
    num_workers = kwargs["dataloader_workers"]
    place365_root = dataset_roots["Place365_big"]
    selected_c_dir = kwargs["selected_content_dir"]
    random_seed = kwargs["random_seed"]

    c_transforms = [
        T.Resize(900),
        T.RandomCrop(crop_size),
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
    ]

    if colorJitterEnable:
        if colorConfig is not None:
            print("Enable color jitter!")
            colorBrightness = colorConfig["brightness"]
            colorContrast = colorConfig["contrast"]
            colorSaturation = colorConfig["saturation"]
            # Hue jitter is symmetric around zero.
            colorHue = (-colorConfig["hue"], colorConfig["hue"])
            c_transforms.append(T.ColorJitter(brightness=colorBrightness,
                    contrast=colorContrast, saturation=colorSaturation, hue=colorHue))
    c_transforms.append(T.ToTensor())
    c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    c_transforms = T.Compose(c_transforms)

    content_dataset = Place365Dataset(
                        place365_root,
                        selected_c_dir,
                        c_transforms,
                        "jpg",
                        random_seed)
    content_data_loader = data.DataLoader(dataset=content_dataset, batch_size=batch_size,
                        drop_last=True, shuffle=True, num_workers=num_workers, pin_memory=True)
    prefetcher = data_prefetcher(content_data_loader)
    return prefetcher
|
||||
|
||||
def denorm(x):
    """Undo the [-1, 1] normalization, returning values clamped to [0, 1]."""
    halfway = 0.5 * (x + 1)
    return halfway.clamp_(0, 1)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test of the loader pipeline.
    # NOTE(review): this calls `getLoader`, which is not defined in this
    # module (the function above is `GetLoader`, and its signature does not
    # accept these positional arguments), so this demo raises NameError as
    # written — confirm the intended entry point before relying on it.
    from torchvision.utils import save_image
    style_class = ["vangogh","picasso","samuel"]
    # Place365 category sub-directories used as the content selection.
    categories_names = \
        ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
         'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
         'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
         'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
         'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
         'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
         'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
         'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
         'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
         'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
         'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
         'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
         'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
         'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
         'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
         'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
         'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
         'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
         's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
         'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
         'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
         'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
         'b/building_facade', 'c/cemetery']

    # NOTE(review): machine-specific Windows paths — adjust before running.
    s_datapath = "D:\\F_Disk\\data_set\\Art_Data\\data_art_backup"
    c_datapath = "D:\\Downloads\\data_large"
    savepath = "D:\\PatchFace\\PleaseWork\\multi-style-gan\\StyleTransfer\\dataloader_test"

    imsize = 512
    s_datasetloader= getLoader(s_datapath,c_datapath,
                    style_class, categories_names,
                    crop_size=imsize, batch_size=16, num_workers=4)
    wocao = iter(s_datasetloader)
    # Pull 500 batches just to exercise the pipeline.
    for i in range(500):
        print("new batch")
        s_image,c_image,label = next(wocao)
        print(label)
        # print(label)
        # saved_image1 = torch.cat([denorm(image.data),denorm(hahh.data)],3)
        # save_image(denorm(image), "%s\\%d-label-%d.jpg"%(savepath,i), nrow=1, padding=1)
        pass
    # import cv2
    # import os
    # for dir_item in categories_names:
    #     join_path = Path(contentdatapath,dir_item)
    #     if join_path.exists():
    #         print("processing %s"%dir_item,end='\r')
    #         images = join_path.glob('*.%s'%("jpg"))
    #         for item in images:
    #             temp_path = str(item)
    #             # temp = cv2.imread(temp_path)
    #             temp = Image.open(temp_path)
    #             if temp.layers<3:
    #                 print("remove broken image...")
    #                 print("image name:%s"%temp_path)
    #                 del temp
    #                 os.remove(item)
|
||||
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: eval_dataloader_DIV2K.py
|
||||
# Created Date: Tuesday January 12th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 8:29:51 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import glob
|
||||
import torch
|
||||
|
||||
class TestDataset:
    """Batch iterator over the images of a single directory (e.g. a DIV2K split).

    Each call returns a batch of HR image tensors scaled to [-1, 1] together
    with the corresponding file names (without extension). Raises
    StopIteration once the whole directory has been consumed, resetting the
    internal pointer so the dataset can be iterated again.
    """

    def __init__( self,
                path,
                batch_size = 16,
                subffix=('png', 'jpg')):
        """Initialize and preprocess the setX dataset.

        Args:
            path (str): directory holding the test images.
            batch_size (int): images per batch returned by __call__.
            subffix (iterable[str]): file extensions to index. Default:
                ('png', 'jpg') — was a mutable list default, now a tuple.
        """
        self.path = path
        self.subffix = subffix
        self.dataset = []
        # Index of the first image of the next batch.
        self.pointer = 0
        self.batch_size = batch_size
        self.__preprocess__()
        self.num_images = len(self.dataset)

    def __preprocess__(self):
        """Collect [path, stem] pairs for every image with a wanted suffix."""
        print("processing content images...")
        for i_suf in self.subffix:
            temp_path = os.path.join(self.path, '*.%s' % (i_suf))
            for item in glob.glob(temp_path):
                file_name = os.path.splitext(os.path.basename(item))[0]
                self.dataset.append([item, file_name])
        print('Finished preprocessing the content dataset, total image number: %d...' % len(self.dataset))

    def __call__(self):
        """Return one batch as (hr_tensor [B,3,H,W] in [-1,1], name_list)."""
        if self.pointer >= self.num_images:
            self.pointer = 0
            # BUG FIX: the original raised StopIteration(print(msg)), i.e.
            # StopIteration(None); pass the message itself instead.
            raise StopIteration("The end of the story!")
        # Last batch may be smaller than batch_size.
        end = min(self.pointer + self.batch_size, self.num_images)
        for i in range(self.pointer, end):
            filename = self.dataset[i][0]
            hr_img = cv2.imread(filename)
            hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)
            # HWC -> CHW for torch.
            hr_img = hr_img.transpose((2, 0, 1))

            hr_img = torch.from_numpy(hr_img)
            # Scale uint8 [0, 255] to float [-1, 1].
            hr_img = hr_img / 255.0
            hr_img = 2 * (hr_img - 0.5)
            if (i - self.pointer) == 0:
                hr_ls = hr_img.unsqueeze(0)
                nm_ls = [self.dataset[i][1], ]
            else:
                hr_ls = torch.cat((hr_ls, hr_img.unsqueeze(0)), 0)
                nm_ls += [self.dataset[i][1], ]
        self.pointer = end
        return hr_ls, nm_ls

    def __len__(self):
        """Return the number of indexed images."""
        return self.num_images

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.path + ')'
|
||||
@@ -0,0 +1,248 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: PerceptualLoss.py
|
||||
# Created Date: Wednesday January 13th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 6th March 2021 4:42:26 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torchvision.models import vgg as vgg
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
# Ordered layer-name sequences for each torchvision VGG variant; the index of
# a name matches the index of the corresponding module in `vgg.<type>().features`.
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
        'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
        'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
        'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
        'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
        'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
        'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
        'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4',
        'pool5'
    ]
}
|
||||
|
||||
|
||||
def insert_bn(names):
    """Insert a bn layer name after each conv layer name.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers interleaved.
    """
    augmented = []
    for layer in names:
        augmented.append(layer)
        if 'conv' not in layer:
            continue
        # 'convX_Y' gets a companion 'bnX_Y' right after it.
        augmented.append('bn' + layer.replace('conv', ''))
    return augmented
|
||||
|
||||
|
||||
class VGGFeatureExtractor(nn.Module):
|
||||
"""VGG network for feature extraction.
|
||||
|
||||
In this implementation, we allow users to choose whether use normalization
|
||||
in the input feature and the type of vgg network. Note that the pretrained
|
||||
path must fit the vgg type.
|
||||
|
||||
Args:
|
||||
layer_name_list (list[str]): Forward function returns the corresponding
|
||||
features according to the layer_name_list.
|
||||
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
|
||||
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
||||
use_input_norm (bool): If True, normalize the input image. Importantly,
|
||||
the input feature must in the range [0, 1]. Default: True.
|
||||
requires_grad (bool): If true, the parameters of VGG network will be
|
||||
optimized. Default: False.
|
||||
remove_pooling (bool): If true, the max pooling operations in VGG net
|
||||
will be removed. Default: False.
|
||||
pooling_stride (int): The stride of max pooling operation. Default: 2.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
layer_name_list,
|
||||
vgg_type='vgg19',
|
||||
use_input_norm=True,
|
||||
requires_grad=False,
|
||||
remove_pooling=False,
|
||||
pooling_stride=2):
|
||||
super(VGGFeatureExtractor, self).__init__()
|
||||
|
||||
self.layer_name_list = layer_name_list
|
||||
self.use_input_norm = use_input_norm
|
||||
|
||||
self.names = NAMES[vgg_type.replace('_bn', '')]
|
||||
if 'bn' in vgg_type:
|
||||
self.names = insert_bn(self.names)
|
||||
|
||||
# only borrow layers that will be used to avoid unused params
|
||||
max_idx = 0
|
||||
for v in layer_name_list:
|
||||
idx = self.names.index(v)
|
||||
if idx > max_idx:
|
||||
max_idx = idx
|
||||
features = getattr(vgg,
|
||||
vgg_type)(pretrained=True).features[:max_idx + 1]
|
||||
|
||||
modified_net = OrderedDict()
|
||||
for k, v in zip(self.names, features):
|
||||
if 'pool' in k:
|
||||
# if remove_pooling is true, pooling operation will be removed
|
||||
if remove_pooling:
|
||||
continue
|
||||
else:
|
||||
# in some cases, we may want to change the default stride
|
||||
modified_net[k] = nn.MaxPool2d(
|
||||
kernel_size=2, stride=pooling_stride)
|
||||
else:
|
||||
modified_net[k] = v
|
||||
|
||||
self.vgg_net = nn.Sequential(modified_net)
|
||||
|
||||
if not requires_grad:
|
||||
self.vgg_net.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
else:
|
||||
self.vgg_net.train()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
if self.use_input_norm:
|
||||
# the mean is for image with range [0, 1]
|
||||
self.register_buffer(
|
||||
'mean',
|
||||
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||
# the std is for image with range [0, 1]
|
||||
self.register_buffer(
|
||||
'std',
|
||||
torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||
|
||||
def forward(self, x):
    """Extract the requested VGG feature maps.

    Args:
        x (Tensor): Input tensor with shape (n, c, h, w).

    Returns:
        dict[str, Tensor]: Mapping from layer name (those listed in
        ``self.layer_name_list``) to a cloned feature tensor.
    """
    if self.use_input_norm:
        x = (x - self.mean) / self.std

    wanted = self.layer_name_list
    extracted = {}
    for name, module in self.vgg_net._modules.items():
        x = module(x)
        if name in wanted:
            extracted[name] = x.clone()

    return extracted
|
||||
|
||||
class PerceptualLoss(nn.Module):
    """Perceptual loss computed on pretrained VGG features.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and multiplied by this weight.
            Default: 1.0.
        criterion (str): Criterion used for perceptual loss, 'l1' or 'l2'.
            Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 perceptual_weight=1.0,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()

        self.perceptual_weight = perceptual_weight
        self.layer_weights = layer_weights
        # Frozen VGG used purely as a feature extractor for the loss.
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            # BUG FIX: torch.nn.L2loss does not exist (raised AttributeError
            # whenever criterion='l2'); MSELoss is PyTorch's L2 criterion.
            self.criterion = torch.nn.MSELoss()
        else:
            raise NotImplementedError(
                f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Compute the weighted perceptual loss between `x` and `gt`.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor | None: Weighted perceptual loss, or None when
            `perceptual_weight <= 0`.
        """
        # extract vgg features (ground truth is detached: no grad through it)
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss as a weighted sum over the chosen layers
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                percep_loss += self.criterion(
                    x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        return percep_loss
|
||||
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: SliceWassersteinDistance.py
|
||||
# Created Date: Tuesday October 12th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 3:11:23 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SWD(nn.Module):
    """Slicing layer for the sliced Wasserstein distance.

    Projects flattened feature maps onto `direction_num` random unit
    directions and returns the projections sorted along the last axis.

    Args:
        channel (int): Number of input channels (projection dimensionality).
        direction_num (int): Number of random projection directions.
            Default: 16.
    """

    def __init__(self, channel, direction_num=16):
        super().__init__()
        # Number of directions
        self.direc_num = direction_num
        self.channel = channel
        # Non-trainable random seed matrix; re-randomized by update().
        self.seed = nn.Parameter(
            torch.normal(mean=0.0, std=torch.ones(self.direc_num, self.channel)),
            requires_grad=False)
        # BUG FIX: previously `directions` only existed after update(), so
        # calling forward() first raised AttributeError. Initialize it from
        # the freshly drawn seed so the module is usable immediately.
        self.directions = F.normalize(self.seed)

    def update(self):
        """Re-draw the random directions (in place) and re-normalize them."""
        # Generate random directions
        self.seed.normal_()
        # Normalize each direction to unit length
        self.directions = F.normalize(self.seed)

    def forward(self, input):
        """Project `input` onto the directions and sort the projections.

        Args:
            input (Tensor): Tensor of shape (..., channel, h, w).

        Returns:
            Tensor: Sorted projections of shape (..., direc_num, h*w).
        """
        input = input.flatten(-2)
        sliced = self.directions @ input
        sliced, _ = sliced.sort()
        return sliced
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: slice a constant 4x3x5x5 batch and inspect the output.
    sample = torch.ones((4, 3, 5, 5))
    swd_layer = SWD(sample.shape[1])
    swd_layer.update()
    sample_sliced = swd_layer(sample)
    print(sample_sliced.shape)
    print(sample_sliced)
|
||||
@@ -0,0 +1,266 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: test.py
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 7:44:02 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from torch.backends import cudnn
|
||||
from utilities.json_config import readConfig
|
||||
from utilities.reporter import Reporter
|
||||
from utilities.sshupload import fileUploaderClass
|
||||
|
||||
|
||||
def str2bool(v):
    """Parse a command-line flag string: True iff it equals 'true' (any case).

    BUG FIX: the original `v.lower() in ('true')` tested substring membership
    in the STRING 'true' (the parentheses do not make a tuple), so inputs
    like 'rue' or 't' wrongly returned True.
    """
    return v.lower() == 'true'
|
||||
|
||||
####################################################################################
|
||||
# To configure the seting of training\finetune\test
|
||||
#
|
||||
####################################################################################
|
||||
def getParameters():
    """Parse command-line arguments for the test phase.

    Returns:
        argparse.Namespace: Parsed options (version, cuda device,
        checkpoint epoch, test script/dataloader selection, data paths).
    """
    parser = argparse.ArgumentParser()
    # general settings
    parser.add_argument('-v', '--version', type=str, default='fastnst_3',
                        help="version name for train, test, finetune")

    # GPU index; when left at -1 the program runs on CPU
    parser.add_argument('-c', '--cuda', type=int, default=-1)
    parser.add_argument('-e', '--checkpoint_epoch', type=int, default=19,
                        help="checkpoint epoch for test phase or finetune phase")

    # test
    parser.add_argument('-t', '--test_script_name', type=str, default='FastNST')
    parser.add_argument('-b', '--batch_size', type=int, default=1)
    # which remote machine (configured in env/env.json) holds the checkpoints
    parser.add_argument('-n', '--node_name', type=str, default='localhost',
                        choices=['localhost', '4card','8card','new4card'])

    # NOTE(review): store_false means passing this flag DISABLES saving;
    # the default (flag absent) is True.
    parser.add_argument('--save_test_result', action='store_false')

    parser.add_argument('--test_dataloader', type=str, default='dir')

    parser.add_argument('-p', '--test_data_path', type=str, default='G:\\UltraHighStyleTransfer\\benchmark')

    parser.add_argument('--use_specified_data', action='store_true')
    parser.add_argument('--specified_data_paths', type=str, nargs='+', default=[""], help='paths to specified files')

    return parser.parse_args()
|
||||
|
||||
# Keys that must NOT be overwritten by the training-time config.json when it
# is merged into the runtime state (paths and test-time options are local).
# BUG FIX: "test_dataset_path" was listed twice; deduplicated.
ignoreKey = [
    "dataloader_workers",
    "log_root_path",
    "project_root",
    "project_summary",
    "project_checkpoints",
    "project_samples",
    "project_scripts",
    "reporter_path",
    "use_specified_data",
    "specified_data_paths",
    "dataset_path", "cuda",
    "test_script_name",
    "test_dataloader",
    "test_dataset_path",
    "save_test_result",
    "test_batch_size",
    "node_name",
    "checkpoint_epoch",
    "test_dataset_name",
    "use_my_test_date"]
|
||||
|
||||
####################################################################################
|
||||
# This function will create the related directories before the
|
||||
# training\fintune\test starts
|
||||
# Your_log_root (version name)
|
||||
# |---summary/...
|
||||
# |---samples/... (save evaluated images)
|
||||
# |---checkpoints/...
|
||||
# |---scripts/...
|
||||
#
|
||||
####################################################################################
|
||||
def createDirs(sys_state):
    """Create the per-version log directory tree and record the paths.

    Mutates `sys_state` in place, adding: project_root, project_summary,
    project_checkpoints, project_samples, project_scripts, reporter_path.

    Layout: <log_root_path>/<version>/{summary,samples,checkpoints,scripts}

    Args:
        sys_state (dict): Must contain "log_root_path" and "version".
    """
    # exist_ok=True replaces the racy exists()-then-makedirs pattern.
    os.makedirs(sys_state["log_root_path"], exist_ok=True)

    sys_state["project_root"] = os.path.join(sys_state["log_root_path"],
                                             sys_state["version"])
    project_root = sys_state["project_root"]
    os.makedirs(project_root, exist_ok=True)

    # Create each sub-directory and publish its path under the matching key.
    for key, subdir in (("project_summary", "summary"),
                        ("project_checkpoints", "checkpoints"),
                        ("project_samples", "samples"),
                        ("project_scripts", "scripts")):
        sys_state[key] = os.path.join(project_root, subdir)
        os.makedirs(sys_state[key], exist_ok=True)

    # Report file path (not a directory; created lazily by the Reporter).
    sys_state["reporter_path"] = os.path.join(project_root, sys_state["version"] + "_report")
|
||||
|
||||
def main():
    """Entry point for the test phase.

    Parses CLI args, builds the local log directory tree, optionally fetches
    config.json and the generator checkpoint from a remote machine via SSH,
    merges the training-time config into the runtime state, then dynamically
    imports and runs the requested Tester class.
    """
    config = getParameters()
    # speed up the program
    cudnn.benchmark = True

    sys_state = {}

    # set the GPU number
    if config.cuda >= 0:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(config.cuda)

    # read system environment paths
    env_config = readConfig('env/env.json')
    env_config = env_config["path"]

    # obtain all configurations in argparse
    config_dic = vars(config)
    for config_key in config_dic.keys():
        sys_state[config_key] = config_dic[config_key]

    #=======================Test Phase=========================#

    # TODO modify below lines to obtain the configuration
    sys_state["log_root_path"] = env_config["train_log_root"]

    sys_state["test_samples_path"] = os.path.join(env_config["test_log_root"],
                                        sys_state["version"] , "samples")
    # if not config.use_my_test_date:
    #     print("Use public benchmark...")
    #     data_key = config.test_dataset_name.lower()
    #     sys_state["test_dataset_path"] = env_config["test_dataset_paths"][data_key]
    #     if config.test_dataset_name.lower() == "set5" or config.test_dataset_name.lower() =="set14":
    #         sys_state["test_dataloader"] = "setx"
    #     else:
    #         sys_state["test_dataloader"] = config.test_dataset_name.lower()

    #     sys_state["test_dataset_name"] = config.test_dataset_name

    if not os.path.exists(sys_state["test_samples_path"]):
        os.makedirs(sys_state["test_samples_path"])

    # Create dirs
    createDirs(sys_state)
    config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])

    # Read model_config.json from remote machine
    if sys_state["node_name"]!="localhost":
        remote_mac = env_config["remote_machine"]
        nodeinf = remote_mac[sys_state["node_name"]]
        print("ready to fetch related files from server: %s ......"%nodeinf["ip"])
        uploader = fileUploaderClass(nodeinf["ip"],nodeinf["user"],nodeinf["passwd"])

        # Remote paths always use forward slashes regardless of local OS.
        remotebase = os.path.join(nodeinf['base_path'],"train_logs",sys_state["version"]).replace('\\','/')

        # Get the config.json
        print("ready to get the config.json...")
        remoteFile = os.path.join(remotebase, env_config["config_json_name"]).replace('\\','/')
        localFile = config_json

        ssh_state = uploader.sshScpGet(remoteFile, localFile)
        if not ssh_state:
            # NOTE(review): print() returns None, so the Exception carries no
            # message — consider `raise Exception("...%s..." % remoteFile)`.
            raise Exception(print("Get file %s failed! config.json does not exist!"%remoteFile))
        print("success get the config.json from server %s"%nodeinf['ip'])

    # Read model_config.json and merge it into sys_state, skipping keys that
    # must stay local (paths, test options) as listed in ignoreKey.
    json_obj = readConfig(config_json)
    for item in json_obj.items():
        if item[0] in ignoreKey:
            pass
        else:
            sys_state[item[0]] = item[1]

    # Read scripts from remote machine
    if sys_state["node_name"]!="localhost":
        # # Get scripts
        # remoteFile = os.path.join(remotebase, "scripts", sys_state["gScriptName"]+".py").replace('\\','/')
        # localFile  = os.path.join(sys_state["project_scripts"], sys_state["gScriptName"]+".py")
        # ssh_state = uploader.sshScpGet(remoteFile, localFile)
        # if not ssh_state:
        #     raise Exception(print("Get file %s failed! Program exists!"%remoteFile))
        # print("Get the scripts:%s.py successfully"%sys_state["gScriptName"])
        # Get checkpoint of generator (only when not already cached locally)
        localFile = os.path.join(sys_state["project_checkpoints"],
                        "epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
                        sys_state["checkpoint_names"]["generator_name"]))
        if not os.path.exists(localFile):
            remoteFile = os.path.join(remotebase, "checkpoints",
                        "epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
                        sys_state["checkpoint_names"]["generator_name"])).replace('\\','/')
            ssh_state = uploader.sshScpGet(remoteFile, localFile, True)
            if not ssh_state:
                raise Exception(print("Get file %s failed! Checkpoint file does not exist!"%remoteFile))
            print("Get the checkpoint %s successfully!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
                                            sys_state["checkpoint_names"]["generator_name"])))
        else:
            print("%s exists!"%("epoch%d_%s.pth"%(sys_state["checkpoint_epoch"],
                                            sys_state["checkpoint_names"]["generator_name"])))

    # TODO get the checkpoint file path
    # NOTE(review): filename here is "%d_%s.pth" while the download above uses
    # "epoch%d_%s.pth" — verify which pattern the Tester actually loads.
    sys_state["ckp_name"] = {}
    for data_key in sys_state["checkpoint_names"].keys():
        sys_state["ckp_name"][data_key] = os.path.join(sys_state["project_checkpoints"],
                                    "%d_%s.pth"%(sys_state["checkpoint_epoch"],
                                    sys_state["checkpoint_names"][data_key]))

    # Get the test configurations: module prefix for dynamically importing the
    # archived model scripts of this training version.
    sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"]

    # make a reporter
    report_path = os.path.join(env_config["test_log_root"], sys_state["version"],
                        sys_state["version"]+"_report")
    reporter = Reporter(report_path)
    reporter.writeConfig(sys_state)

    # Display the test information
    # TODO modify below lines to display your configuration information
    moduleName = "test_scripts.tester_" + sys_state["test_script_name"]
    print("Start to run test script: {}".format(moduleName))
    print("Test version: %s"%sys_state["version"])
    print("Test Script Name: %s"%sys_state["test_script_name"])

    # Dynamically import the tester module and run it.
    package = __import__(moduleName, fromlist=True)
    testerClass = getattr(package, 'Tester')
    tester = testerClass(sys_state,reporter)
    tester.test()


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: tester_commonn.py
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 8:22:37 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
|
||||
import torch
|
||||
from utilities.utilities import tensor2img
|
||||
|
||||
# from utilities.Reporter import Reporter
|
||||
from tqdm import tqdm
|
||||
|
||||
class Tester(object):
    """Single-image test driver: loads a trained generator and writes its
    outputs for every image produced by the test dataloader.
    """

    def __init__(self, config, reporter):
        """Store config/reporter and build the evaluation dataloader.

        Args:
            config (dict): Runtime state assembled by test.py's main().
            reporter: Log writer with writeInfo/writeModel methods.
        """
        self.config = config
        # logger
        self.reporter = reporter

        #============build evaluation dataloader==============#
        print("Prepare the test dataloader...")
        dlModulename = config["test_dataloader"]
        package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'TestDataset')
        # batch size is fixed to 1 here; images are processed one at a time
        dataloader = dataloaderClass(config["test_data_path"],
                                        1,
                                        ["png","jpg"])
        self.test_loader= dataloader

        self.test_iter = len(dataloader)

    def __init_framework__(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file
        '''
        #===============build models================#
        print("build models...")
        # Import the archived model script of this training version
        # (module path prefix comes from config["com_base"]).
        model_config = self.config["model_configs"]
        script_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        package = __import__(script_name, fromlist=True)
        network_class = getattr(package, class_name)

        # Instantiate the generator with its saved constructor kwargs.
        self.network = network_class(**model_config["g_model"]["module_params"])

        # print and record model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        # run on GPU when a CUDA device index was given
        if self.config["cuda"] >=0:
            self.network = self.network.cuda()

        model_path = os.path.join(self.config["project_checkpoints"],
                        "epoch%d_%s.pth"%(self.config["checkpoint_epoch"],
                        self.config["checkpoint_names"]["generator_name"]))

        self.network.load_state_dict(torch.load(model_path))
        # NOTE(review): format arg is the checkpoints DIR, not the epoch —
        # message prints a path where an epoch number is implied.
        print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))

    def test(self):
        """Run inference over the whole test set and save result images."""
        save_dir = self.config["test_samples_path"]
        ckp_epoch = self.config["checkpoint_epoch"]
        version = self.config["version"]
        batch_size = self.config["batch_size"]
        win_size = self.config["model_configs"]["g_model"]["module_params"]["window_size"]

        # models
        self.__init_framework__()

        total = len(self.test_loader)
        print("total:", total)
        # Start time
        import datetime
        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        with torch.no_grad():
            for _ in tqdm(range(total)):
                contents, img_names = self.test_loader()
                B, C, H, W = contents.shape
                # center-crop to the largest square whose side is a
                # multiple of 32 (presumably a network stride constraint —
                # TODO confirm against the model)
                crop_h = H - H%32
                crop_w = W - W%32
                crop_s = min(crop_h, crop_w)
                contents = contents[:,:,(H//2 - crop_s//2):(crop_s//2 + H//2),
                                    (W//2 - crop_s//2):(crop_s//2 + W//2)]
                if self.config["cuda"] >=0:
                    contents = contents.cuda()
                res = self.network(contents, (crop_s, crop_s))
                print("res shape:", res.shape)
                res = tensor2img(res.cpu())
                temp_img = res[0,:,:,:]
                # tensor2img output is RGB; OpenCV expects BGR on write
                temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
                print(save_dir)
                print(img_names[0])
                cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}.png'.format(
                            img_names[0], version, ckp_epoch)),temp_img)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
|
||||
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: tester_commonn.py
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 4th July 2021 11:32:14 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
|
||||
import torch
|
||||
from utilities.utilities import tensor2img
|
||||
|
||||
# from utilities.Reporter import Reporter
|
||||
from tqdm import tqdm
|
||||
|
||||
class Tester(object):
    """Multi-style test driver: runs every test batch through the generator
    once per style class and saves one output image per input per style.
    """

    def __init__(self, config, reporter):
        """Store config/reporter and build the batched evaluation dataloader.

        Args:
            config (dict): Runtime state assembled by the test entry point.
            reporter: Log writer with writeInfo/writeModel methods.
        """
        self.config = config
        # logger
        self.reporter = reporter

        #============build evaluation dataloader==============#
        print("Prepare the test dataloader...")
        dlModulename = config["test_dataloader"]
        package = __import__("data_tools.test_dataloader_%s"%dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'TestDataset')
        dataloader = dataloaderClass(config["test_data_path"],
                                        config["batch_size"],
                                        ["png","jpg"])
        self.test_loader= dataloader

        # number of iterations = ceil(len(dataloader) / batch_size)
        self.test_iter = len(dataloader)//config["batch_size"]
        if len(dataloader)%config["batch_size"]>0:
            self.test_iter+=1

    def __init_framework__(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file
        '''
        #===============build models================#
        print("build models...")
        # Import the generator class from the components package.
        script_name = "components."+self.config["module_script_name"]
        class_name = self.config["class_name"]
        package = __import__(script_name, fromlist=True)
        network_class = getattr(package, class_name)
        # one class per selected style directory
        n_class = len(self.config["selectedStyleDir"])

        self.network = network_class(self.config["GConvDim"],
                                    self.config["GKS"],
                                    self.config["resNum"],
                                    n_class
                                    #**self.config["module_params"]
                                    )

        # print and record model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        # run on GPU when a CUDA device index was given
        if self.config["cuda"] >=0:
            self.network = self.network.cuda()
        # checkpoint stores the generator weights under the "g_model" key
        self.network.load_state_dict(torch.load(self.config["ckp_name"]["generator_name"])["g_model"])
        print('loaded trained backbone model epoch {}...!'.format(self.config["checkpoint_epoch"]))

    def test(self):
        """Run inference for every style over the test set and save images."""
        save_dir = self.config["test_samples_path"]
        ckp_epoch = self.config["checkpoint_epoch"]
        version = self.config["version"]
        batch_size = self.config["batch_size"]
        style_names = self.config["selectedStyleDir"]
        n_class = len(style_names)

        # models
        self.__init_framework__()

        # condition_labels[i] is a (batch_size, 1) tensor filled with class i
        condition_labels = torch.ones((n_class, batch_size, 1)).long()
        for i in range(n_class):
            condition_labels[i,:,:] = condition_labels[i,:,:]*i
        if self.config["cuda"] >=0:
            condition_labels = condition_labels.cuda()
        total = len(self.test_loader)
        # Start time
        import datetime
        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        with torch.no_grad():
            for _ in tqdm(range(total//batch_size)):
                contents, img_names = self.test_loader()
                for i in range(n_class):
                    if self.config["cuda"] >=0:
                        contents = contents.cuda()
                    # network returns (stylized image, auxiliary output)
                    res, _ = self.network(contents, condition_labels[i, 0, :])
                    res = tensor2img(res.cpu())
                    for t in range(batch_size):
                        temp_img = res[t,:,:,:]
                        # tensor2img output is RGB; OpenCV writes BGR
                        temp_img = cv2.cvtColor(temp_img, cv2.COLOR_RGB2BGR)
                        cv2.imwrite(os.path.join(save_dir,'{}_version_{}_step{}_style_{}.png'.format(
                            img_names[t], version, ckp_epoch, style_names[i])),temp_img)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
|
||||
@@ -0,0 +1,240 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: train.py
|
||||
# Created Date: Tuesday April 28th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 8:50:15 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
from torch.backends import cudnn
|
||||
from utilities.json_config import readConfig, writeConfig
|
||||
from utilities.reporter import Reporter
|
||||
from utilities.yaml_config import getConfigYaml
|
||||
|
||||
|
||||
def str2bool(v):
    """Parse a command-line flag string: True iff it equals 'true' (any case).

    BUG FIX: the original `v.lower() in ('true')` tested substring membership
    in the STRING 'true' (the parentheses do not make a tuple), so inputs
    like 'rue' or 't' wrongly returned True.
    """
    return v.lower() == 'true'
|
||||
|
||||
####################################################################################
|
||||
# To configure the seting of training\finetune\test
|
||||
#
|
||||
####################################################################################
|
||||
def getParameters():
    """Parse command-line arguments for the train/finetune phase.

    Returns:
        argparse.Namespace: Parsed options (version, phase, cuda device,
        checkpoint epoch, experiment description, training yaml file).
    """
    parser = argparse.ArgumentParser()
    # general settings
    parser.add_argument('-v', '--version', type=str, default='liff_warpinvo_0',
                        help="version name for train, test, finetune")

    parser.add_argument('-p', '--phase', type=str, default="train",
                        choices=['train', 'finetune','debug'],
                        help="The phase of current project")

    # GPU index; when set to -1 the program runs on CPU
    parser.add_argument('-c', '--cuda', type=int, default=1)
    parser.add_argument('-e', '--checkpoint_epoch', type=int, default=74,
                        help="checkpoint epoch for test phase or finetune phase")

    # training
    # free-form experiment note recorded alongside the run
    parser.add_argument('--experiment_description', type=str,
                        default="尝试使用Liif+Invo作为上采样和降采样的算子,降采样两个DSF算子,上采样两个DSF算子")

    # yaml file (under the env-configured train_config_path) with the
    # full training hyper-parameters for this run
    parser.add_argument('--train_yaml', type=str, default="train_FastNST_CNN_Resblock.yaml")

    return parser.parse_args()
|
||||
|
||||
# Keys that must NOT be overwritten when a saved config.json is merged back
# into the runtime state (paths and test-time options are machine-local).
# BUG FIX: "test_dataset_path" was listed twice; deduplicated.
ignoreKey = [
    "dataloader_workers",
    "log_root_path",
    "project_root",
    "project_summary",
    "project_checkpoints",
    "project_samples",
    "project_scripts",
    "reporter_path",
    "use_specified_data",
    "specified_data_paths",
    "dataset_path", "cuda",
    "test_script_name",
    "test_dataloader",
    "test_dataset_path",
    "save_test_result",
    "test_batch_size",
    "node_name",
    "checkpoint_epoch",
    "test_dataset_name",
    "use_my_test_date"]
|
||||
|
||||
####################################################################################
|
||||
# This function will create the related directories before the
|
||||
# training\fintune\test starts
|
||||
# Your_log_root (version name)
|
||||
# |---summary/...
|
||||
# |---samples/... (save evaluated images)
|
||||
# |---checkpoints/...
|
||||
# |---scripts/...
|
||||
#
|
||||
####################################################################################
|
||||
def createDirs(sys_state):
    """Create the per-version log directory tree and record the paths.

    Mutates `sys_state` in place, adding: project_root, project_summary,
    project_checkpoints, project_samples, project_scripts, reporter_path.

    Layout: <log_root_path>/<version>/{summary,samples,checkpoints,scripts}

    Args:
        sys_state (dict): Must contain "log_root_path" and "version".
    """
    # exist_ok=True replaces the racy exists()-then-makedirs pattern.
    os.makedirs(sys_state["log_root_path"], exist_ok=True)

    sys_state["project_root"] = os.path.join(sys_state["log_root_path"],
                                             sys_state["version"])
    project_root = sys_state["project_root"]
    os.makedirs(project_root, exist_ok=True)

    # Create each sub-directory and publish its path under the matching key.
    for key, subdir in (("project_summary", "summary"),
                        ("project_checkpoints", "checkpoints"),
                        ("project_samples", "samples"),
                        ("project_scripts", "scripts")):
        sys_state[key] = os.path.join(project_root, subdir)
        os.makedirs(sys_state[key], exist_ok=True)

    # Report file path (not a directory; created lazily by the Reporter).
    sys_state["reporter_path"] = os.path.join(project_root, sys_state["version"] + "_report")
|
||||
|
||||
def main():
    """Entry point: parse CLI args, prepare run directories, then dispatch
    to the selected trainer script via dynamic import.

    Phases:
        * ``train``    -- read a fresh YAML config, create run dirs, snapshot
                          the trainer script + YAML into the project dir.
        * ``finetune`` -- reload the JSON config saved by a previous train
                          run (skipping keys listed in ``ignoreKey``) and
                          import the trainer from the archived scripts.

    NOTE(review): relies on names defined elsewhere in this file / project:
    getParameters, cudnn, readConfig, getConfigYaml, writeConfig, Reporter,
    shutil, ignoreKey, createDirs. None of them are defined in this block.
    """

    config = getParameters()
    # speed up the program (cuDNN autotunes conv algorithms for fixed shapes)
    cudnn.benchmark = True

    from utilities.logo_class import logo_class
    logo_class.print_group_logo()

    # sys_state aggregates argparse values + YAML/JSON config into one dict
    sys_state = {}

    # set the GPU number; cuda < 0 means CPU-only (env var left untouched)
    if config.cuda >= 0:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(config.cuda)

    # read system environment paths
    env_config = readConfig('env/env.json')
    env_config = env_config["path"]

    # obtain all configurations in argparse
    config_dic = vars(config)
    for config_key in config_dic.keys():
        sys_state[config_key] = config_dic[config_key]

    #=======================Train Phase=========================#
    if config.phase == "train":
        # read training configurations from yaml file
        # NOTE(review): "ymal" is a typo for "yaml" (local name only)
        ymal_config = getConfigYaml(os.path.join(env_config["train_config_path"], config.train_yaml))
        for item in ymal_config.items():
            sys_state[item[0]] = item[1]

        # create related dirs (fills project_* keys in sys_state)
        createDirs(sys_state)
        sys_state["log_root_path"] = env_config["train_log_root"]
        # NOTE(review): log_root_path is assigned AFTER being needed —
        # in the original source the assignment comes first; order kept as-is:
        # (see original lines: assignment precedes createDirs)

        # create reporter file
        reporter = Reporter(sys_state["reporter_path"])

        # save the config json so finetune runs can reload it later
        config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
        writeConfig(config_json, sys_state)

        # save the dependent scripts
        # TODO and copy the scripts to the project dir

        # save the trainer script into [train_logs_root]\[version name]\scripts\
        file1 = os.path.join(env_config["train_scripts_path"],
                             "trainer_%s.py"%sys_state["train_script_name"])
        tgtfile1 = os.path.join(sys_state["project_scripts"],
                                "trainer_%s.py"%sys_state["train_script_name"])
        shutil.copyfile(file1,tgtfile1)

        # save the yaml file
        file1 = os.path.join(env_config["train_config_path"], config.train_yaml)
        tgtfile1 = os.path.join(sys_state["project_scripts"], config.train_yaml)
        shutil.copyfile(file1,tgtfile1)

        # TODO replace below lines, here to save the critical scripts

    #=====================Finetune Phase=====================#
    elif config.phase == "finetune":
        sys_state["log_root_path"] = env_config["train_log_root"]
        sys_state["project_root"] = os.path.join(sys_state["log_root_path"], sys_state["version"])

        # reload the config saved by the original train run
        config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
        train_config = readConfig(config_json)
        for item in train_config.items():
            # keys in ignoreKey keep their fresh CLI values instead of the
            # stale values stored in the archived JSON
            if item[0] in ignoreKey:
                pass
            else:
                sys_state[item[0]] = item[1]

        createDirs(sys_state)
        reporter = Reporter(sys_state["reporter_path"])
        # module prefix for importing the archived scripts of this version
        sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"]

    # get the dataset path (copy every entry from the environment config)
    sys_state["dataset_paths"] = {}
    for data_key in env_config["dataset_paths"].keys():
        sys_state["dataset_paths"][data_key] = env_config["dataset_paths"][data_key]

    # display the training information
    moduleName = "train_scripts.trainer_" + sys_state["train_script_name"]
    if config.phase == "finetune":
        # finetune imports the snapshot archived under the project dir
        moduleName = sys_state["com_base"] + "trainer_" + sys_state["train_script_name"]

    # print some important information
    # TODO
    print("Start to run training script: {}".format(moduleName))
    # NOTE(review): "Traning" is a typo in a user-facing string; left
    # unchanged here because changing runtime output is a behavior change.
    print("Traning version: %s"%sys_state["version"])
    print("Dataloader Name: %s"%sys_state["dataloader"])
    # print("Image Size: %d"%sys_state["imsize"])
    print("Batch size: %d"%(sys_state["batch_size"]))

    # Load the training script and start to train
    reporter.writeConfig(sys_state)

    # dynamic import: the trainer module must expose a class named 'Trainer'
    package = __import__(moduleName, fromlist=True)
    trainerClass= getattr(package, 'Trainer')
    trainer = trainerClass(sys_state, reporter)
    trainer.train()


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,307 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 2:18:26 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from components.Transform import Transform_block
|
||||
from utilities.utilities import denorm, Gram, img2tensor255
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
    """Neural style-transfer trainer (full-image style target).

    Trains a generator ``self.gen`` with a VGG16 perceptual loss:
    content loss on ``relu2_2`` features plus a Gram-matrix style loss
    over all returned feature maps. Checkpoints and sample images are
    written under the project directories prepared by the launcher.
    """

    def __init__(self, config, reporter):
        """Store config/reporter, build the train dataloader and tensorboard.

        Args:
            config (dict): merged run configuration (see launcher ``main``).
            reporter: project logger with writeInfo/writeModel/... methods.
        """

        self.config = config
        # logger
        self.reporter = reporter

        # Data loader
        #============build train dataloader==============#
        # TODO to modify the key: "your_train_dataset" to get your train dataset path
        self.train_dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        # dataloader module is resolved dynamically from the config, so the
        # same trainer works with any data_tools.data_loader_<name> module
        dlModulename = config["dataloader"]
        package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'GetLoader')
        self.dataloader_class = dataloaderClass
        dataloader = self.dataloader_class(self.train_dataset,
                                           config["batch_size"],
                                           config["imcrop_size"],
                                           **config["dataset_params"])

        self.train_loader= dataloader

        #========build evaluation dataloader=============#
        # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
        # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]

        # #================================================#
        # print("Prepare the evaluation dataloader...")
        # dlModulename = config["eval_dataloader"]
        # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
        # dataloaderClass = getattr(package, 'EvalDataset')
        # dataloader = dataloaderClass(eval_dataset,
        #                             config["eval_batch_size"])
        # self.eval_loader= dataloader

        # self.eval_iter = len(dataloader)//config["eval_batch_size"]
        # if len(dataloader)%config["eval_batch_size"]>0:
        #     self.eval_iter+=1

        #==============build tensorboard=================#
        if self.config["use_tensorboard"]:
            from utilities.utilities import build_tensorboard
            self.tensorboard_writer = build_tensorboard(self.config["project_summary"])

    # TODO modify this function to build your models
    def __init_framework__(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]

        model_config = self.config["model_configs"]

        if self.config["phase"] == "train":

            # generator script lives in the components package
            gscript_name = "components." + model_config["g_model"]["script"]

            # TODO To save the important scripts
            # save the yaml file
            # snapshot the generator script into the run's scripts dir so a
            # finetune run can import the exact code used for training
            import shutil
            file1 = os.path.join("components", "%s.py"%model_config["g_model"]["script"])
            tgtfile1 = os.path.join(self.config["project_scripts"], "%s.py"%model_config["g_model"]["script"])
            shutil.copyfile(file1,tgtfile1)

        elif self.config["phase"] == "finetune":
            # finetune imports the archived snapshot instead of components/
            gscript_name = self.config["com_base"] + model_config["g_model"]["script"]

        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.gen = gen_class(**model_config["g_model"]["module_params"])

        # print and recorde model structure
        self.reporter.writeInfo("Generator structure:")
        self.reporter.writeModel(self.gen.__str__())

        # train in GPU
        if self.config["cuda"] >=0:
            self.gen = self.gen.cuda()

        # if in finetune phase, load the pretrained checkpoint
        if self.config["phase"] == "finetune":
            model_path = os.path.join(self.config["project_checkpoints"],
                                      "epoch%d_%s.pth"%(self.config["checkpoint_step"],
                                      self.config["checkpoint_names"]["generator_name"]))
            self.gen.load_state_dict(torch.load(model_path))

            print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))

    # TODO modify this function to evaluate your model
    def __evaluation__(self, epoch, step = 0):
        """Compute average PSNR over the evaluation set and log it.

        NOTE(review): this method references ``self.network``,
        ``self.eval_loader`` and ``self.eval_iter``, none of which are set
        by ``__init__`` (the eval-dataloader block there is commented out).
        Calling it as-is raises AttributeError — looks like SR-template
        leftovers; confirm before wiring it into the training loop.
        """
        # Evaluate the checkpoint
        self.network.eval()
        total_psnr = 0
        total_num = 0
        with torch.no_grad():
            for _ in range(self.eval_iter):
                hr, lr = self.eval_loader()

                if self.config["cuda"] >=0:
                    hr = hr.cuda()
                    lr = lr.cuda()
                # map [-1, 1] tensors to [0, 255] before PSNR
                hr = (hr + 1.0)/2.0 * 255.0
                hr = torch.clamp(hr,0.0,255.0)
                lr = (lr + 1.0)/2.0 * 255.0
                lr = torch.clamp(lr,0.0,255.0)
                res = self.network(lr)
                # res = (res + 1.0)/2.0 * 255.0
                # hr = (hr + 1.0)/2.0 * 255.0
                res = torch.clamp(res,0.0,255.0)
                # per-image RMSE over C,H,W, then PSNR = 20*log10(255/RMSE)
                diff = (res-hr) ** 2
                diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
                psnrs = 20. * (255. / diff).log10()
                total_psnr+= psnrs.sum()
                total_num+=res.shape[0]
        final_psnr = total_psnr/total_num
        print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
                                                      epoch, final_psnr))
        self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
        self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)

    # TODO modify this function to configurate the optimizer of your pipeline
    def __setup_optimizers__(self):
        """Build ``self.g_optimizer`` over the generator's trainable params.

        Frozen parameters (requires_grad == False) are excluded and logged.
        Raises NotImplementedError for any optim_type other than 'Adam'.
        NOTE(review): "supperted" in the error string is a typo; left
        unchanged because it is runtime output.
        """
        g_train_opt = self.config['g_optim_config']
        g_optim_params = []
        for k, v in self.gen.named_parameters():
            if v.requires_grad:
                g_optim_params.append(v)
            else:
                self.reporter.writeInfo(f'Params {k} will not be optimized.')
                print(f'Params {k} will not be optimized.')

        optim_type = self.config['optim_type']

        if optim_type == 'Adam':
            self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
        else:
            raise NotImplementedError(
                f'optimizer {optim_type} is not supperted yet.')
        # self.optimizers.append(self.optimizer_g)

    def train(self):
        """Run the full training loop: perceptual-loss optimization,
        periodic logging, checkpointing and sample-image dumps.
        """

        ckpt_dir = self.config["project_checkpoints"]
        # NOTE(review): "log_frep" is presumably "log_freq" (local name only)
        log_frep = self.config["log_step"]
        model_freq = self.config["model_save_epoch"]
        total_epoch = self.config["total_epoch"]
        batch_size = self.config["batch_size"]
        style_img = self.config["style_img_path"]

        # prep_weights= self.config["layersWeight"]
        content_w = self.config["content_weight"]
        style_w = self.config["style_weight"]
        crop_size = self.config["imcrop_size"]

        sample_dir = self.config["project_samples"]

        #===============build framework================#
        self.__init_framework__()

        #===============build optimizer================#
        # Optimizer
        # TODO replace below lines to build your optimizer
        print("build the optimizer...")
        self.__setup_optimizers__()

        #===============build losses===================#
        # TODO replace below lines to build your losses
        MSE_loss = torch.nn.MSELoss()

        # set the start point for training loop
        if self.config["phase"] == "finetune":
            start = self.config["checkpoint_epoch"] - 1
        else:
            start = 0

        # print("prepare the fixed labels...")
        # fix_label = [i for i in range(n_class)]
        # fix_label = torch.tensor(fix_label).long().cuda()
        # fix_label = fix_label.view(n_class,1)
        # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)

        # Start time
        import datetime
        print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        from utilities.logo_class import logo_class
        logo_class.print_start_training()
        start_time = time.time()

        # Caculate the epoch number
        # NOTE(review): assumes len(train_loader) counts samples, not
        # batches — confirm against the GetLoader implementation.
        step_epoch = len(self.train_loader)
        step_epoch = step_epoch // batch_size
        print("Total step = %d in each epoch"%step_epoch)

        VGG = VGG16().cuda()

        MEAN_VAL = 127.5
        SCALE_VAL= 127.5
        # Get Style Features
        # ImageNet BGR mean offsets expected by this VGG16; the _11 variant
        # additionally shifts inputs produced in [-1, 1] * SCALE_VAL space
        imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
        imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()

        style_tensor = img2tensor255(style_img).cuda()
        style_tensor = style_tensor.add(imagenet_neg_mean)
        B, C, H, W = style_tensor.shape
        # NOTE(review): style_tensor is rebound here from an image tensor to
        # the VGG feature dict — confusing name reuse; the sibling trainer
        # uses a separate "style_features" name for the same step.
        style_tensor = VGG(style_tensor.expand([batch_size, C, H, W]))
        # style_features = VGG(style_tensor)
        # precompute the Gram matrix of every style feature map once
        style_gram = {}
        for key, value in style_tensor.items():
            style_gram[key] = Gram(value)
        del style_tensor
        # step_epoch = 2
        for epoch in range(start, total_epoch):
            for step in range(step_epoch):
                self.gen.train()

                content_images = self.train_loader.next()
                fake_image = self.gen(content_images)
                # both generated and content images are rescaled into the
                # VGG's expected 0..255-mean-subtracted input space
                generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
                content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
                content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])

                # style loss: Gram-matrix MSE summed over all feature layers
                style_loss = 0.0
                for key, value in generated_features.items():
                    s_loss = MSE_loss(Gram(value), style_gram[key])
                    style_loss += s_loss

                # backward & optimize
                g_loss = content_loss* content_w + style_loss* style_w
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Print out log info
                if (step + 1) % log_frep == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    # cumulative steps
                    cum_step = (step_epoch * epoch + step + 1)

                    epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
                    print(epochinformation)
                    self.reporter.writeInfo(epochinformation)

                    if self.config["use_tensorboard"]:
                        self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)

            #===============adjust learning rate============#
            # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
            #     print("Learning rate decay")
            #     for p in self.optimizer.param_groups:
            #         p['lr'] *= self.config["lr_decay"]
            #         print("Current learning rate is %f"%p['lr'])

            #===============save checkpoints================#
            if (epoch+1) % model_freq==0:
                print("Save epoch %d model checkpoint!"%(epoch+1))
                torch.save(self.gen.state_dict(),
                           os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
                           self.config["checkpoint_names"]["generator_name"])))

                torch.cuda.empty_cache()
                # dump the last generated batch of this epoch as a sample grid
                print('Sample images {}_fake.jpg'.format(epoch + 1))
                self.gen.eval()
                with torch.no_grad():
                    sample = fake_image
                    saved_image1 = denorm(sample.cpu().data)
                    save_image(saved_image1,
                               os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 7:38:36 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from components.Transform import Transform_block
|
||||
from utilities.utilities import denorm, Gram, img2tensor255crop
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
    """Neural style-transfer trainer (cropped style target variant).

    Same pipeline as the sibling trainer, except the style image is loaded
    with ``img2tensor255crop`` at ``imcrop_size`` instead of full-size, and
    ``__init_framework__`` does not snapshot the generator script.
    """

    def __init__(self, config, reporter):
        """Store config/reporter, build the train dataloader and tensorboard.

        Args:
            config (dict): merged run configuration (see launcher ``main``).
            reporter: project logger with writeInfo/writeModel/... methods.
        """

        self.config = config
        # logger
        self.reporter = reporter

        # Data loader
        #============build train dataloader==============#
        # TODO to modify the key: "your_train_dataset" to get your train dataset path
        self.train_dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        # dataloader module is resolved dynamically from the config
        dlModulename = config["dataloader"]
        package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'GetLoader')
        self.dataloader_class = dataloaderClass
        dataloader = self.dataloader_class(self.train_dataset,
                                           config["batch_size"],
                                           config["imcrop_size"],
                                           **config["dataset_params"])

        self.train_loader= dataloader

        #========build evaluation dataloader=============#
        # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
        # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]

        # #================================================#
        # print("Prepare the evaluation dataloader...")
        # dlModulename = config["eval_dataloader"]
        # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
        # dataloaderClass = getattr(package, 'EvalDataset')
        # dataloader = dataloaderClass(eval_dataset,
        #                             config["eval_batch_size"])
        # self.eval_loader= dataloader

        # self.eval_iter = len(dataloader)//config["eval_batch_size"]
        # if len(dataloader)%config["eval_batch_size"]>0:
        #     self.eval_iter+=1

        #==============build tensorboard=================#
        if self.config["use_tensorboard"]:
            from utilities.utilities import build_tensorboard
            self.tensorboard_writer = build_tensorboard(self.config["project_summary"])

    # TODO modify this function to build your models
    def __init_framework__(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]

        model_config = self.config["model_configs"]

        if self.config["phase"] == "train":
            # generator script lives in the components package
            gscript_name = "components." + model_config["g_model"]["script"]

        elif self.config["phase"] == "finetune":
            # finetune imports the archived snapshot instead of components/
            gscript_name = self.config["com_base"] + model_config["g_model"]["script"]

        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.gen = gen_class(**model_config["g_model"]["module_params"])

        # print and recorde model structure
        self.reporter.writeInfo("Generator structure:")
        self.reporter.writeModel(self.gen.__str__())

        # train in GPU
        if self.config["cuda"] >=0:
            self.gen = self.gen.cuda()

        # if in finetune phase, load the pretrained checkpoint
        if self.config["phase"] == "finetune":
            model_path = os.path.join(self.config["project_checkpoints"],
                                      "epoch%d_%s.pth"%(self.config["checkpoint_step"],
                                      self.config["checkpoint_names"]["generator_name"]))
            self.gen.load_state_dict(torch.load(model_path))

            print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))

    # TODO modify this function to evaluate your model
    def __evaluation__(self, epoch, step = 0):
        """Compute average PSNR over the evaluation set and log it.

        NOTE(review): references ``self.network``, ``self.eval_loader`` and
        ``self.eval_iter``, none of which are set by ``__init__`` (the
        eval-dataloader block there is commented out). Calling it as-is
        raises AttributeError — SR-template leftovers; confirm before use.
        """
        # Evaluate the checkpoint
        self.network.eval()
        total_psnr = 0
        total_num = 0
        with torch.no_grad():
            for _ in range(self.eval_iter):
                hr, lr = self.eval_loader()

                if self.config["cuda"] >=0:
                    hr = hr.cuda()
                    lr = lr.cuda()
                # map [-1, 1] tensors to [0, 255] before PSNR
                hr = (hr + 1.0)/2.0 * 255.0
                hr = torch.clamp(hr,0.0,255.0)
                lr = (lr + 1.0)/2.0 * 255.0
                lr = torch.clamp(lr,0.0,255.0)
                res = self.network(lr)
                # res = (res + 1.0)/2.0 * 255.0
                # hr = (hr + 1.0)/2.0 * 255.0
                res = torch.clamp(res,0.0,255.0)
                # per-image RMSE over C,H,W, then PSNR = 20*log10(255/RMSE)
                diff = (res-hr) ** 2
                diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
                psnrs = 20. * (255. / diff).log10()
                total_psnr+= psnrs.sum()
                total_num+=res.shape[0]
        final_psnr = total_psnr/total_num
        print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
                                                      epoch, final_psnr))
        self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
        self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)

    # TODO modify this function to configurate the optimizer of your pipeline
    def __setup_optimizers__(self):
        """Build ``self.g_optimizer`` over the generator's trainable params.

        Frozen parameters (requires_grad == False) are excluded and logged.
        Raises NotImplementedError for any optim_type other than 'Adam'.
        NOTE(review): "supperted" in the error string is a typo; left
        unchanged because it is runtime output.
        """
        g_train_opt = self.config['g_optim_config']
        g_optim_params = []
        for k, v in self.gen.named_parameters():
            if v.requires_grad:
                g_optim_params.append(v)
            else:
                self.reporter.writeInfo(f'Params {k} will not be optimized.')
                print(f'Params {k} will not be optimized.')

        optim_type = self.config['optim_type']

        if optim_type == 'Adam':
            self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
        else:
            raise NotImplementedError(
                f'optimizer {optim_type} is not supperted yet.')
        # self.optimizers.append(self.optimizer_g)

    def train(self):
        """Run the full training loop: perceptual-loss optimization,
        periodic logging, checkpointing and sample-image dumps.
        """

        ckpt_dir = self.config["project_checkpoints"]
        # NOTE(review): "log_frep" is presumably "log_freq" (local name only)
        log_frep = self.config["log_step"]
        model_freq = self.config["model_save_epoch"]
        total_epoch = self.config["total_epoch"]
        batch_size = self.config["batch_size"]
        style_img = self.config["style_img_path"]

        # prep_weights= self.config["layersWeight"]
        content_w = self.config["content_weight"]
        style_w = self.config["style_weight"]
        crop_size = self.config["imcrop_size"]

        sample_dir = self.config["project_samples"]

        #===============build framework================#
        self.__init_framework__()

        #===============build optimizer================#
        # Optimizer
        # TODO replace below lines to build your optimizer
        print("build the optimizer...")
        self.__setup_optimizers__()

        #===============build losses===================#
        # TODO replace below lines to build your losses
        MSE_loss = torch.nn.MSELoss()

        # set the start point for training loop
        if self.config["phase"] == "finetune":
            start = self.config["checkpoint_epoch"] - 1
        else:
            start = 0

        # print("prepare the fixed labels...")
        # fix_label = [i for i in range(n_class)]
        # fix_label = torch.tensor(fix_label).long().cuda()
        # fix_label = fix_label.view(n_class,1)
        # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)

        # Start time
        import datetime
        print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        from utilities.logo_class import logo_class
        logo_class.print_start_training()
        start_time = time.time()

        # Caculate the epoch number
        # NOTE(review): assumes len(train_loader) counts samples, not
        # batches — confirm against the GetLoader implementation.
        step_epoch = len(self.train_loader)
        step_epoch = step_epoch // batch_size
        print("Total step = %d in each epoch"%step_epoch)

        VGG = VGG16().cuda()

        MEAN_VAL = 127.5
        SCALE_VAL= 127.5
        # Get Style Features
        # ImageNet BGR mean offsets expected by this VGG16; the _11 variant
        # additionally shifts inputs produced in [-1, 1] * SCALE_VAL space
        imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
        imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()

        # style image is cropped to the training crop size (variant-specific)
        style_tensor = img2tensor255crop(style_img,crop_size).cuda()
        style_tensor = style_tensor.add(imagenet_neg_mean)
        B, C, H, W = style_tensor.shape
        style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
        # precompute the Gram matrix of every style feature map once
        style_gram = {}
        for key, value in style_features.items():
            style_gram[key] = Gram(value)
        # step_epoch = 2
        for epoch in range(start, total_epoch):
            for step in range(step_epoch):
                self.gen.train()

                content_images = self.train_loader.next()
                fake_image = self.gen(content_images)
                # both generated and content images are rescaled into the
                # VGG's expected 0..255-mean-subtracted input space
                generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
                content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
                content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])

                # style loss: Gram-matrix MSE summed over all feature layers
                style_loss = 0.0
                for key, value in generated_features.items():
                    s_loss = MSE_loss(Gram(value), style_gram[key])
                    style_loss += s_loss

                # backward & optimize
                g_loss = content_loss* content_w + style_loss* style_w
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Print out log info
                if (step + 1) % log_frep == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    # cumulative steps
                    cum_step = (step_epoch * epoch + step + 1)

                    epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
                    print(epochinformation)
                    self.reporter.writeInfo(epochinformation)

                    if self.config["use_tensorboard"]:
                        self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)

            #===============adjust learning rate============#
            # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
            #     print("Learning rate decay")
            #     for p in self.optimizer.param_groups:
            #         p['lr'] *= self.config["lr_decay"]
            #         print("Current learning rate is %f"%p['lr'])

            #===============save checkpoints================#
            if (epoch+1) % model_freq==0:
                print("Save epoch %d model checkpoint!"%(epoch+1))
                torch.save(self.gen.state_dict(),
                           os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
                           self.config["checkpoint_names"]["generator_name"])))

                torch.cuda.empty_cache()
                # dump the last generated batch of this epoch as a sample grid
                print('Sample images {}_fake.jpg'.format(epoch + 1))
                self.gen.eval()
                with torch.no_grad():
                    sample = fake_image
                    saved_image1 = denorm(sample.cpu().data)
                    save_image(saved_image1,
                               os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
@@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 9:25:13 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from utilities.utilities import denorm, Gram, img2tensor255crop
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
# Data loader
|
||||
#============build train dataloader==============#
|
||||
# TODO to modify the key: "your_train_dataset" to get your train dataset path
|
||||
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
self.dataloader_class = dataloaderClass
|
||||
dataloader = self.dataloader_class(self.train_dataset,
|
||||
config["batch_size"],
|
||||
config["imcrop_size"],
|
||||
**config["dataset_params"])
|
||||
|
||||
self.train_loader= dataloader
|
||||
|
||||
#========build evaluation dataloader=============#
|
||||
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
|
||||
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
|
||||
|
||||
# #================================================#
|
||||
# print("Prepare the evaluation dataloader...")
|
||||
# dlModulename = config["eval_dataloader"]
|
||||
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
|
||||
# dataloaderClass = getattr(package, 'EvalDataset')
|
||||
# dataloader = dataloaderClass(eval_dataset,
|
||||
# config["eval_batch_size"])
|
||||
# self.eval_loader= dataloader
|
||||
|
||||
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
|
||||
# if len(dataloader)%config["eval_batch_size"]>0:
|
||||
# self.eval_iter+=1
|
||||
|
||||
#==============build tensorboard=================#
|
||||
if self.config["use_tensorboard"]:
|
||||
from utilities.utilities import build_tensorboard
|
||||
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
    def __evaluation__(self, epoch, step = 0):
        '''
        Template evaluation hook: computes the average PSNR of the network
        output against high-resolution references over the evaluation set,
        then logs it to the reporter and tensorboard.

        NOTE(review): this method references self.network, self.eval_iter and
        self.eval_loader, none of which are assigned in the visible code (the
        evaluation-dataloader setup in __init__ is commented out) -- calling
        it as-is would raise AttributeError. It looks like leftover template
        code; confirm before wiring it in.

        Args:
            epoch (int): current epoch number, used for logging only.
            step (int): current step number, forwarded to the train log.
        '''
        # Evaluate the checkpoint
        self.network.eval()
        total_psnr = 0
        total_num = 0
        with torch.no_grad():
            for _ in range(self.eval_iter):
                # each call is assumed to yield one (high-res, low-res)
                # batch pair -- TODO confirm against EvalDataset
                hr, lr = self.eval_loader()

                if self.config["cuda"] >=0:
                    hr = hr.cuda()
                    lr = lr.cuda()
                # map tensors from [-1, 1] to [0, 255] before measuring PSNR
                hr = (hr + 1.0)/2.0 * 255.0
                hr = torch.clamp(hr,0.0,255.0)
                lr = (lr + 1.0)/2.0 * 255.0
                lr = torch.clamp(lr,0.0,255.0)
                res = self.network(lr)
                # res = (res + 1.0)/2.0 * 255.0
                # hr = (hr + 1.0)/2.0 * 255.0
                res = torch.clamp(res,0.0,255.0)
                # per-sample RMSE over the trailing three dims, then
                # PSNR = 20 * log10(255 / RMSE)
                diff = (res-hr) ** 2
                diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
                psnrs = 20. * (255. / diff).log10()
                total_psnr+= psnrs.sum()
                total_num+=res.shape[0]
        # dataset-wide mean PSNR (per-sample PSNRs averaged over all samples)
        final_psnr = total_psnr/total_num
        print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
                epoch, final_psnr))
        self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
        self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
g_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
    def train(self):
        '''
        Gram-matrix style-transfer training loop.

        Builds the generator and optimizer, extracts VGG16 features of the
        (cropped) style image once, then for each step feeds a content batch
        through the generator and minimizes

            g_loss = content_weight * MSE(relu2_2 of fake, relu2_2 of content)
                   + style_weight   * sum_k MSE(Gram(fake_k), Gram(style_k))

        Logs every config["log_step"] steps; every config["model_save_epoch"]
        epochs it saves a generator checkpoint and a sample image grid.
        '''

        ckpt_dir    = self.config["project_checkpoints"]
        log_frep    = self.config["log_step"]
        model_freq  = self.config["model_save_epoch"]
        total_epoch = self.config["total_epoch"]
        batch_size  = self.config["batch_size"]
        style_img   = self.config["style_img_path"]

        # prep_weights= self.config["layersWeight"]
        content_w   = self.config["content_weight"]
        style_w     = self.config["style_weight"]
        crop_size   = self.config["imcrop_size"]

        sample_dir  = self.config["project_samples"]


        #===============build framework================#
        self.__init_framework__()

        #===============build optimizer================#
        # Optimizer
        # TODO replace below lines to build your optimizer
        print("build the optimizer...")
        self.__setup_optimizers__()

        #===============build losses===================#
        # TODO replace below lines to build your losses
        MSE_loss = torch.nn.MSELoss()


        # set the start point for training loop
        if self.config["phase"] == "finetune":
            start = self.config["checkpoint_epoch"] - 1
        else:
            start = 0

        # print("prepare the fixed labels...")
        # fix_label = [i for i in range(n_class)]
        # fix_label = torch.tensor(fix_label).long().cuda()
        # fix_label = fix_label.view(n_class,1)
        # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)

        # Start time
        import datetime
        print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        from utilities.logo_class import logo_class
        logo_class.print_start_training()
        start_time = time.time()

        # Caculate the epoch number
        # NOTE(review): assumes len(self.train_loader) counts samples, so
        # steps per epoch = samples // batch_size -- confirm against
        # GetLoader's __len__.
        step_epoch = len(self.train_loader)
        step_epoch = step_epoch // batch_size
        print("Total step = %d in each epoch"%step_epoch)

        VGG = VGG16().cuda()

        MEAN_VAL = 127.5
        SCALE_VAL= 127.5
        # Get Style Features
        # negated BGR ImageNet means for VGG preprocessing; the *_11 variant
        # is shifted by +127.5 to accept inputs scaled from [-1, 1]
        imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
        imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()

        # style features / Gram matrices computed once, outside the loop
        style_tensor = img2tensor255crop(style_img,crop_size).cuda()
        style_tensor = style_tensor.add(imagenet_neg_mean)
        B, C, H, W = style_tensor.shape
        style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
        style_gram = {}
        for key, value in style_features.items():
            style_gram[key] = Gram(value)
        # step_epoch = 2
        for epoch in range(start, total_epoch):
            for step in range(step_epoch):
                self.gen.train()

                content_images = self.train_loader.next()
                fake_image = self.gen(content_images)
                # generator in/out presumably lies in [-1, 1]; *SCALE_VAL maps
                # it to ~[-127.5, 127.5] before the mean shift -- TODO confirm
                generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
                content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
                # content preserved at the relu2_2 feature level only
                content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])

                # Gram-matrix style loss accumulated over every VGG tap
                style_loss = 0.0
                for key, value in generated_features.items():
                    s_loss = MSE_loss(Gram(value), style_gram[key])
                    style_loss += s_loss

                # backward & optimize
                g_loss = content_loss* content_w + style_loss* style_w
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()



                # Print out log info
                if (step + 1) % log_frep == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    # cumulative steps
                    cum_step = (step_epoch * epoch + step + 1)

                    epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
                    print(epochinformation)
                    self.reporter.writeInfo(epochinformation)

                    if self.config["use_tensorboard"]:
                        self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
                        # NOTE(review): style_loss is passed as a tensor while
                        # the others use .item() -- works, but inconsistent
                        self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)

            #===============adjust learning rate============#
            # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
            #     print("Learning rate decay")
            #     for p in self.optimizer.param_groups:
            #         p['lr'] *= self.config["lr_decay"]
            #         print("Current learning rate is %f"%p['lr'])

            #===============save checkpoints================#
            if (epoch+1) % model_freq==0:
                print("Save epoch %d model checkpoint!"%(epoch+1))
                torch.save(self.gen.state_dict(),
                        os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
                            self.config["checkpoint_names"]["generator_name"])))

                torch.cuda.empty_cache()
                # the sample grid is the last training batch's generator output
                print('Sample images {}_fake.jpg'.format(epoch + 1))
                self.gen.eval()
                with torch.no_grad():
                    sample = fake_image
                    saved_image1 = denorm(sample.cpu().data)
                    save_image(saved_image1,
                            os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
@@ -0,0 +1,300 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 19th October 2021 2:28:24 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from utilities.utilities import denorm, img2tensor255crop
|
||||
from losses.SliceWassersteinDistance import SWD
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
# Data loader
|
||||
#============build train dataloader==============#
|
||||
# TODO to modify the key: "your_train_dataset" to get your train dataset path
|
||||
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
self.dataloader_class = dataloaderClass
|
||||
dataloader = self.dataloader_class(self.train_dataset,
|
||||
config["batch_size"],
|
||||
config["imcrop_size"],
|
||||
**config["dataset_params"])
|
||||
|
||||
self.train_loader= dataloader
|
||||
|
||||
#========build evaluation dataloader=============#
|
||||
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
|
||||
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
|
||||
|
||||
# #================================================#
|
||||
# print("Prepare the evaluation dataloader...")
|
||||
# dlModulename = config["eval_dataloader"]
|
||||
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
|
||||
# dataloaderClass = getattr(package, 'EvalDataset')
|
||||
# dataloader = dataloaderClass(eval_dataset,
|
||||
# config["eval_batch_size"])
|
||||
# self.eval_loader= dataloader
|
||||
|
||||
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
|
||||
# if len(dataloader)%config["eval_batch_size"]>0:
|
||||
# self.eval_iter+=1
|
||||
|
||||
#==============build tensorboard=================#
|
||||
if self.config["use_tensorboard"]:
|
||||
from utilities.utilities import build_tensorboard
|
||||
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_epoch"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
    def __evaluation__(self, epoch, step = 0):
        '''
        Template evaluation hook: computes the average PSNR of the network
        output against high-resolution references over the evaluation set,
        then logs it to the reporter and tensorboard.

        NOTE(review): this method references self.network, self.eval_iter and
        self.eval_loader, none of which are assigned in the visible code (the
        evaluation-dataloader setup in __init__ is commented out) -- calling
        it as-is would raise AttributeError. It looks like leftover template
        code; confirm before wiring it in.

        Args:
            epoch (int): current epoch number, used for logging only.
            step (int): current step number, forwarded to the train log.
        '''
        # Evaluate the checkpoint
        self.network.eval()
        total_psnr = 0
        total_num = 0
        with torch.no_grad():
            for _ in range(self.eval_iter):
                # each call is assumed to yield one (high-res, low-res)
                # batch pair -- TODO confirm against EvalDataset
                hr, lr = self.eval_loader()

                if self.config["cuda"] >=0:
                    hr = hr.cuda()
                    lr = lr.cuda()
                # map tensors from [-1, 1] to [0, 255] before measuring PSNR
                hr = (hr + 1.0)/2.0 * 255.0
                hr = torch.clamp(hr,0.0,255.0)
                lr = (lr + 1.0)/2.0 * 255.0
                lr = torch.clamp(lr,0.0,255.0)
                res = self.network(lr)
                # res = (res + 1.0)/2.0 * 255.0
                # hr = (hr + 1.0)/2.0 * 255.0
                res = torch.clamp(res,0.0,255.0)
                # per-sample RMSE over the trailing three dims, then
                # PSNR = 20 * log10(255 / RMSE)
                diff = (res-hr) ** 2
                diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
                psnrs = 20. * (255. / diff).log10()
                total_psnr+= psnrs.sum()
                total_num+=res.shape[0]
        # dataset-wide mean PSNR (per-sample PSNRs averaged over all samples)
        final_psnr = total_psnr/total_num
        print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
                epoch, final_psnr))
        self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
        self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
g_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
    def train(self):
        '''
        Sliced-Wasserstein-distance (SWD) style-transfer training loop.

        Builds the generator and optimizer, extracts VGG16 features of the
        (cropped) style image once, builds one SWD projector per VGG tap,
        then for each step minimizes

            g_loss = content_weight * MSE(relu2_2 of fake, relu2_2 of content)
                   + style_weight   * sum_k MSE(SWD_k(fake_k), SWD_k(style_k))

        Logs every config["log_step"] steps; every config["model_save_epoch"]
        epochs it saves a generator checkpoint and a sample image grid.
        '''

        ckpt_dir    = self.config["project_checkpoints"]
        log_frep    = self.config["log_step"]
        model_freq  = self.config["model_save_epoch"]
        total_epoch = self.config["total_epoch"]
        batch_size  = self.config["batch_size"]
        style_img   = self.config["style_img_path"]

        # prep_weights= self.config["layersWeight"]
        content_w   = self.config["content_weight"]
        style_w     = self.config["style_weight"]
        crop_size   = self.config["imcrop_size"]
        # number of random projection directions for the SWD modules
        swd_dim     = self.config["swd_dim"]
        sample_dir  = self.config["project_samples"]


        #===============build framework================#
        self.__init_framework__()

        #===============build optimizer================#
        # Optimizer
        # TODO replace below lines to build your optimizer
        print("build the optimizer...")
        self.__setup_optimizers__()

        #===============build losses===================#
        # TODO replace below lines to build your losses
        MSE_loss = torch.nn.MSELoss()


        # set the start point for training loop
        if self.config["phase"] == "finetune":
            start = self.config["checkpoint_epoch"] - 1
        else:
            start = 0

        # print("prepare the fixed labels...")
        # fix_label = [i for i in range(n_class)]
        # fix_label = torch.tensor(fix_label).long().cuda()
        # fix_label = fix_label.view(n_class,1)
        # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)

        # Start time
        import datetime
        print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        from utilities.logo_class import logo_class
        logo_class.print_start_training()
        start_time = time.time()

        # Caculate the epoch number
        # NOTE(review): assumes len(self.train_loader) counts samples, so
        # steps per epoch = samples // batch_size -- confirm against
        # GetLoader's __len__.
        step_epoch = len(self.train_loader)
        step_epoch = step_epoch // batch_size
        print("Total step = %d in each epoch"%step_epoch)

        VGG = VGG16().cuda()

        MEAN_VAL = 127.5
        SCALE_VAL= 127.5
        # Get Style Features
        # negated BGR ImageNet means for VGG preprocessing; the *_11 variant
        # is shifted by +127.5 to accept inputs scaled from [-1, 1]
        imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
        imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()

        # swd = SWD()
        # style features computed once, outside the loop; one SWD module per
        # VGG tap, sized by that tap's channel count
        style_tensor = img2tensor255crop(style_img,crop_size).cuda()
        style_tensor = style_tensor.add(imagenet_neg_mean)
        B, C, H, W = style_tensor.shape
        style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
        swd_list = {}
        for key, value in style_features.items():

            swd_list[key] = SWD(value.shape[1],swd_dim).cuda()
        # step_epoch = 2
        for epoch in range(start, total_epoch):
            for step in range(step_epoch):
                self.gen.train()

                content_images = self.train_loader.next()
                fake_image = self.gen(content_images)
                # generator in/out presumably lies in [-1, 1]; *SCALE_VAL maps
                # it to ~[-127.5, 127.5] before the mean shift -- TODO confirm
                generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
                content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
                # content preserved at the relu2_2 feature level only
                content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])

                style_loss = 0.0
                for key, value in generated_features.items():
                    # NOTE(review): update() is called with no arguments each
                    # step -- presumably it re-draws the random projection
                    # directions; verify against losses.SliceWassersteinDistance
                    swd_list[key].update()
                    s_loss = MSE_loss(swd_list[key](value), swd_list[key](style_features[key]))
                    style_loss += s_loss

                # backward & optimize
                g_loss = content_loss* content_w + style_loss* style_w
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()



                # Print out log info
                if (step + 1) % log_frep == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    # cumulative steps
                    cum_step = (step_epoch * epoch + step + 1)

                    epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
                    print(epochinformation)
                    self.reporter.writeInfo(epochinformation)

                    if self.config["use_tensorboard"]:
                        self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
                        # NOTE(review): style_loss is passed as a tensor while
                        # the others use .item() -- works, but inconsistent
                        self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)

            #===============adjust learning rate============#
            # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
            #     print("Learning rate decay")
            #     for p in self.optimizer.param_groups:
            #         p['lr'] *= self.config["lr_decay"]
            #         print("Current learning rate is %f"%p['lr'])

            #===============save checkpoints================#
            if (epoch+1) % model_freq==0:
                print("Save epoch %d model checkpoint!"%(epoch+1))
                torch.save(self.gen.state_dict(),
                        os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
                            self.config["checkpoint_names"]["generator_name"])))

                torch.cuda.empty_cache()
                # the sample grid is the last training batch's generator output
                print('Sample images {}_fake.jpg'.format(epoch + 1))
                self.gen.eval()
                with torch.no_grad():
                    sample = fake_image
                    saved_image1 = denorm(sample.cpu().data)
                    save_image(saved_image1,
                            os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
@@ -0,0 +1,382 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_condition_SN_multiscale.py
|
||||
# Created Date: Saturday April 18th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 6th July 2021 7:36:42 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from components.Transform import Transform_block
|
||||
from utilities.utilities import denorm
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, config, reporter):
|
||||
|
||||
self.config = config
|
||||
# logger
|
||||
self.reporter = reporter
|
||||
|
||||
# Data loader
|
||||
#============build train dataloader==============#
|
||||
# TODO to modify the key: "your_train_dataset" to get your train dataset path
|
||||
self.train_dataset = config["dataset_paths"][config["dataset_name"]]
|
||||
#================================================#
|
||||
print("Prepare the train dataloader...")
|
||||
dlModulename = config["dataloader"]
|
||||
package = __import__("data_tools.dataloader_%s"%dlModulename, fromlist=True)
|
||||
dataloaderClass = getattr(package, 'GetLoader')
|
||||
self.dataloader_class = dataloaderClass
|
||||
# dataloader = self.dataloader_class(self.train_dataset,
|
||||
# config["batch_size_list"][0],
|
||||
# config["imcrop_size_list"][0],
|
||||
# **config["dataset_params"])
|
||||
|
||||
# self.train_loader= dataloader
|
||||
|
||||
#========build evaluation dataloader=============#
|
||||
# TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
|
||||
# eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
|
||||
|
||||
# #================================================#
|
||||
# print("Prepare the evaluation dataloader...")
|
||||
# dlModulename = config["eval_dataloader"]
|
||||
# package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
|
||||
# dataloaderClass = getattr(package, 'EvalDataset')
|
||||
# dataloader = dataloaderClass(eval_dataset,
|
||||
# config["eval_batch_size"])
|
||||
# self.eval_loader= dataloader
|
||||
|
||||
# self.eval_iter = len(dataloader)//config["eval_batch_size"]
|
||||
# if len(dataloader)%config["eval_batch_size"]>0:
|
||||
# self.eval_iter+=1
|
||||
|
||||
#==============build tensorboard=================#
|
||||
if self.config["use_tensorboard"]:
|
||||
from utilities.utilities import build_tensorboard
|
||||
self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
dscript_name = "components." + model_config["d_model"]["script"]
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
dscript_name = self.config["com_base"] + model_config["d_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
class_name = model_config["d_model"]["class_name"]
|
||||
package = __import__(dscript_name, fromlist=True)
|
||||
dis_class = getattr(package, class_name)
|
||||
self.dis = dis_class(**model_config["d_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
self.reporter.writeInfo("Discriminator structure:")
|
||||
self.reporter.writeModel(self.dis.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
self.dis = self.dis.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["discriminator_name"]))
|
||||
self.dis.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
    def __evaluation__(self, epoch, step = 0):
        '''
        Template evaluation hook: computes the average PSNR of the network
        output against high-resolution references over the evaluation set,
        then logs it to the reporter and tensorboard.

        NOTE(review): this method references self.network, self.eval_iter and
        self.eval_loader, none of which are assigned in the visible code (the
        evaluation-dataloader setup in __init__ is commented out) -- calling
        it as-is would raise AttributeError. It looks like leftover template
        code; confirm before wiring it in.

        Args:
            epoch (int): current epoch number, used for logging only.
            step (int): current step number, forwarded to the train log.
        '''
        # Evaluate the checkpoint
        self.network.eval()
        total_psnr = 0
        total_num = 0
        with torch.no_grad():
            for _ in range(self.eval_iter):
                # each call is assumed to yield one (high-res, low-res)
                # batch pair -- TODO confirm against EvalDataset
                hr, lr = self.eval_loader()

                if self.config["cuda"] >=0:
                    hr = hr.cuda()
                    lr = lr.cuda()
                # map tensors from [-1, 1] to [0, 255] before measuring PSNR
                hr = (hr + 1.0)/2.0 * 255.0
                hr = torch.clamp(hr,0.0,255.0)
                lr = (lr + 1.0)/2.0 * 255.0
                lr = torch.clamp(lr,0.0,255.0)
                res = self.network(lr)
                # res = (res + 1.0)/2.0 * 255.0
                # hr = (hr + 1.0)/2.0 * 255.0
                res = torch.clamp(res,0.0,255.0)
                # per-sample RMSE over the trailing three dims, then
                # PSNR = 20 * log10(255 / RMSE)
                diff = (res-hr) ** 2
                diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
                psnrs = 20. * (255. / diff).log10()
                total_psnr+= psnrs.sum()
                total_num+=res.shape[0]
        # dataset-wide mean PSNR (per-sample PSNRs averaged over all samples)
        final_psnr = total_psnr/total_num
        print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
                epoch, final_psnr))
        self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
        self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
d_train_opt = self.config['d_optim_config']
|
||||
g_optim_params = []
|
||||
d_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
for k, v in self.dis.named_parameters():
|
||||
if v.requires_grad:
|
||||
d_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
self.d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
def train(self):
|
||||
|
||||
ckpt_dir = self.config["project_checkpoints"]
|
||||
log_frep = self.config["log_step"]
|
||||
model_freq = self.config["model_save_epoch"]
|
||||
total_epoch = self.config["total_epoch"]
|
||||
|
||||
n_class = len(self.config["selected_style_dir"])
|
||||
# prep_weights= self.config["layersWeight"]
|
||||
feature_w = self.config["feature_weight"]
|
||||
transform_w = self.config["transform_weight"]
|
||||
d_step = self.config["d_step"]
|
||||
g_step = self.config["g_step"]
|
||||
|
||||
batch_size_list = self.config["batch_size_list"]
|
||||
switch_epoch_list = self.config["switch_epoch_list"]
|
||||
imcrop_size_list = self.config["imcrop_size_list"]
|
||||
sample_dir = self.config["project_samples"]
|
||||
|
||||
current_epoch_index = 0
|
||||
|
||||
#===============build framework================#
|
||||
self.__init_framework__()
|
||||
|
||||
#===============build optimizer================#
|
||||
# Optimizer
|
||||
# TODO replace below lines to build your optimizer
|
||||
print("build the optimizer...")
|
||||
self.__setup_optimizers__()
|
||||
|
||||
#===============build losses===================#
|
||||
# TODO replace below lines to build your losses
|
||||
Transform = Transform_block().cuda()
|
||||
L1_loss = torch.nn.L1Loss()
|
||||
MSE_loss = torch.nn.MSELoss()
|
||||
Hinge_loss = torch.nn.ReLU().cuda()
|
||||
|
||||
|
||||
# set the start point for training loop
|
||||
if self.config["phase"] == "finetune":
|
||||
start = self.config["checkpoint_epoch"] - 1
|
||||
else:
|
||||
start = 0
|
||||
|
||||
|
||||
output_size = self.dis.get_outputs_len()
|
||||
|
||||
print("prepare the fixed labels...")
|
||||
fix_label = [i for i in range(n_class)]
|
||||
fix_label = torch.tensor(fix_label).long().cuda()
|
||||
# fix_label = fix_label.view(n_class,1)
|
||||
# fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)
|
||||
|
||||
# Start time
|
||||
import datetime
|
||||
print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
|
||||
|
||||
from utilities.logo_class import logo_class
|
||||
logo_class.print_start_training()
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(start, total_epoch):
|
||||
|
||||
# switch training image size
|
||||
if epoch in switch_epoch_list:
|
||||
print('Current epoch: {}'.format(epoch))
|
||||
print('***Redefining the dataloader for progressive training.***')
|
||||
print('***Current spatial size is {} and batch size is {}.***'.format(
|
||||
imcrop_size_list[current_epoch_index], batch_size_list[current_epoch_index]))
|
||||
del self.train_loader
|
||||
self.train_loader = self.dataloader_class(self.train_dataset,
|
||||
batch_size_list[current_epoch_index],
|
||||
imcrop_size_list[current_epoch_index],
|
||||
**self.config["dataset_params"])
|
||||
# Caculate the epoch number
|
||||
step_epoch = len(self.train_loader)
|
||||
step_epoch = step_epoch // (d_step + g_step)
|
||||
print("Total step = %d in each epoch"%step_epoch)
|
||||
current_epoch_index += 1
|
||||
|
||||
for step in range(step_epoch):
|
||||
self.dis.train()
|
||||
self.gen.train()
|
||||
|
||||
# ================== Train D ================== #
|
||||
# Compute loss with real images
|
||||
for _ in range(d_step):
|
||||
content_images,style_images,label = self.train_loader.next()
|
||||
label = label.long()
|
||||
|
||||
d_out = self.dis(style_images,label)
|
||||
d_loss_real = 0
|
||||
for i in range(output_size):
|
||||
temp = Hinge_loss(1 - d_out[i]).mean()
|
||||
d_loss_real += temp
|
||||
|
||||
d_loss_photo = 0
|
||||
d_out = self.dis(content_images,label)
|
||||
for i in range(output_size):
|
||||
temp = Hinge_loss(1 + d_out[i]).mean()
|
||||
d_loss_photo += temp
|
||||
|
||||
fake_image,_= self.gen(content_images,label)
|
||||
d_out = self.dis(fake_image.detach(),label)
|
||||
d_loss_fake = 0
|
||||
for i in range(output_size):
|
||||
temp = Hinge_loss(1 + d_out[i]).mean()
|
||||
# temp *= prep_weights[i]
|
||||
d_loss_fake += temp
|
||||
|
||||
# Backward + Optimize
|
||||
d_loss = d_loss_real + d_loss_photo + d_loss_fake
|
||||
self.d_optimizer.zero_grad()
|
||||
d_loss.backward()
|
||||
self.d_optimizer.step()
|
||||
|
||||
# ================== Train G ================== #
|
||||
for _ in range(g_step):
|
||||
|
||||
content_images,_,_ = self.train_loader.next()
|
||||
fake_image,real_feature = self.gen(content_images,label)
|
||||
fake_feature = self.gen(fake_image, get_feature=True)
|
||||
d_out = self.dis(fake_image,label.long())
|
||||
|
||||
g_feature_loss = L1_loss(fake_feature,real_feature)
|
||||
g_transform_loss = MSE_loss(Transform(content_images), Transform(fake_image))
|
||||
g_loss_fake = 0
|
||||
for i in range(output_size):
|
||||
temp = -d_out[i].mean()
|
||||
# temp *= prep_weights[i]
|
||||
g_loss_fake += temp
|
||||
|
||||
# backward & optimize
|
||||
g_loss = g_loss_fake + g_feature_loss* feature_w + g_transform_loss* transform_w
|
||||
self.g_optimizer.zero_grad()
|
||||
g_loss.backward()
|
||||
self.g_optimizer.step()
|
||||
|
||||
|
||||
# Print out log info
|
||||
if (step + 1) % log_frep == 0:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed = str(datetime.timedelta(seconds=elapsed))
|
||||
|
||||
# cumulative steps
|
||||
cum_step = (step_epoch * epoch + step + 1)
|
||||
|
||||
epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, d_loss_real: {:.4f}, \\\
|
||||
d_loss_photo: {:.4f}, d_loss_fake: {:.4f}, g_loss: {:.4f}, g_loss_fake: {:.4f}, \\\
|
||||
g_feature_loss: {:.4f}, g_transform_loss: {:.4f}".format(self.config["version"],
|
||||
epoch + 1, total_epoch, elapsed, step + 1, step_epoch,
|
||||
d_loss.item(), d_loss_real.item(), d_loss_photo.item(),
|
||||
d_loss_fake.item(), g_loss.item(), g_loss_fake.item(),\
|
||||
g_feature_loss.item(), g_transform_loss.item())
|
||||
print(epochinformation)
|
||||
self.reporter.writeRawInfo(epochinformation)
|
||||
|
||||
if self.config["use_tensorboard"]:
|
||||
self.tensorboard_writer.add_scalar('data/d_loss', d_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/d_loss_real', d_loss_real.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/d_loss_photo', d_loss_photo.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/d_loss_fake', d_loss_fake.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/g_loss_fake', g_loss_fake.item(), cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/g_feature_loss', g_feature_loss, cum_step)
|
||||
self.tensorboard_writer.add_scalar('data/g_transform_loss', g_transform_loss, cum_step)
|
||||
|
||||
#===============adjust learning rate============#
|
||||
if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
|
||||
print("Learning rate decay")
|
||||
for p in self.optimizer.param_groups:
|
||||
p['lr'] *= self.config["lr_decay"]
|
||||
print("Current learning rate is %f"%p['lr'])
|
||||
|
||||
#===============save checkpoints================#
|
||||
if (epoch+1) % model_freq==0:
|
||||
print("Save epoch %d model checkpoint!"%(epoch+1))
|
||||
torch.save(self.gen.state_dict(),
|
||||
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
|
||||
self.config["checkpoint_names"]["generator_name"])))
|
||||
torch.save(self.dis.state_dict(),
|
||||
os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
|
||||
self.config["checkpoint_names"]["discriminator_name"])))
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
print('Sample images {}_fake.jpg'.format(step + 1))
|
||||
self.gen.eval()
|
||||
with torch.no_grad():
|
||||
sample = content_images[0, :, :, :].unsqueeze(0)
|
||||
saved_image1 = denorm(sample.cpu().data)
|
||||
for index in range(n_class):
|
||||
fake_images,_ = self.gen(sample, fix_label[index].unsqueeze(0))
|
||||
saved_image1 = torch.cat((saved_image1, denorm(fake_images.cpu().data)), 0)
|
||||
save_image(saved_image1,
|
||||
os.path.join(sample_dir, '{}_fake.jpg'.format(step + 1)),nrow=3)
|
||||
@@ -0,0 +1,295 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_naiv512.py
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 9th January 2022 12:31:03 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from utilities.utilities import denorm, Gram, img2tensor255crop
|
||||
from pretrained_weights.vgg import VGG16
|
||||
|
||||
class Trainer(object):
    """Training pipeline for the fast neural-style-transfer generator.

    Wires together the train dataloader, the optional tensorboard writer,
    the generator network, its optimizer, and the training loop itself.
    """

    def __init__(self, config, reporter):
        """
        Args:
            config:   dict-like experiment configuration (dataset paths,
                      hyper-parameters, logging options, ...).
            reporter: project logger object (used via writeInfo / writeModel
                      elsewhere in this class).
        """
        self.config = config
        # logger
        self.reporter = reporter

        # Data loader
        #============build train dataloader==============#
        # TODO to modify the key: "your_train_dataset" to get your train dataset path
        self.train_dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        # The dataloader module is resolved dynamically from the config, e.g.
        # dataloader="place365" -> data_tools.data_loader_place365.GetLoader
        dlModulename = config["dataloader"]
        package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'GetLoader')
        # the class itself is kept so train() can rebuild loaders if needed
        self.dataloader_class = dataloaderClass
        dataloader = self.dataloader_class(self.train_dataset,
                                        config["batch_size"],
                                        config["imcrop_size"],
                                        **config["dataset_params"])

        self.train_loader= dataloader

        #========build evaluation dataloader=============#
        # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
        # NOTE(review): this block is disabled, so self.eval_loader and
        # self.eval_iter are never created — __evaluation__ depends on them
        # and cannot run until this is restored.
        # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]

        # #================================================#
        # print("Prepare the evaluation dataloader...")
        # dlModulename = config["eval_dataloader"]
        # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
        # dataloaderClass = getattr(package, 'EvalDataset')
        # dataloader = dataloaderClass(eval_dataset,
        #                             config["eval_batch_size"])
        # self.eval_loader= dataloader

        # self.eval_iter = len(dataloader)//config["eval_batch_size"]
        # if len(dataloader)%config["eval_batch_size"]>0:
        #     self.eval_iter+=1

        #==============build tensorboard=================#
        if self.config["use_tensorboard"]:
            from utilities.utilities import build_tensorboard
            self.tensorboard_writer = build_tensorboard(self.config["project_summary"])
||||
# TODO modify this function to build your models
|
||||
def __init_framework__(self):
|
||||
'''
|
||||
This function is designed to define the framework,
|
||||
and print the framework information into the log file
|
||||
'''
|
||||
#===============build models================#
|
||||
print("build models...")
|
||||
# TODO [import models here]
|
||||
|
||||
model_config = self.config["model_configs"]
|
||||
|
||||
if self.config["phase"] == "train":
|
||||
gscript_name = "components." + model_config["g_model"]["script"]
|
||||
|
||||
elif self.config["phase"] == "finetune":
|
||||
gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
|
||||
|
||||
class_name = model_config["g_model"]["class_name"]
|
||||
package = __import__(gscript_name, fromlist=True)
|
||||
gen_class = getattr(package, class_name)
|
||||
self.gen = gen_class(**model_config["g_model"]["module_params"])
|
||||
|
||||
# print and recorde model structure
|
||||
self.reporter.writeInfo("Generator structure:")
|
||||
self.reporter.writeModel(self.gen.__str__())
|
||||
|
||||
# train in GPU
|
||||
if self.config["cuda"] >=0:
|
||||
self.gen = self.gen.cuda()
|
||||
|
||||
# if in finetune phase, load the pretrained checkpoint
|
||||
if self.config["phase"] == "finetune":
|
||||
model_path = os.path.join(self.config["project_checkpoints"],
|
||||
"epoch%d_%s.pth"%(self.config["checkpoint_step"],
|
||||
self.config["checkpoint_names"]["generator_name"]))
|
||||
self.gen.load_state_dict(torch.load(model_path))
|
||||
|
||||
print('loaded trained backbone model epoch {}...!'.format(self.config["project_checkpoints"]))
|
||||
|
||||
|
||||
# TODO modify this function to evaluate your model
|
||||
def __evaluation__(self, epoch, step = 0):
|
||||
# Evaluate the checkpoint
|
||||
self.network.eval()
|
||||
total_psnr = 0
|
||||
total_num = 0
|
||||
with torch.no_grad():
|
||||
for _ in range(self.eval_iter):
|
||||
hr, lr = self.eval_loader()
|
||||
|
||||
if self.config["cuda"] >=0:
|
||||
hr = hr.cuda()
|
||||
lr = lr.cuda()
|
||||
hr = (hr + 1.0)/2.0 * 255.0
|
||||
hr = torch.clamp(hr,0.0,255.0)
|
||||
lr = (lr + 1.0)/2.0 * 255.0
|
||||
lr = torch.clamp(lr,0.0,255.0)
|
||||
res = self.network(lr)
|
||||
# res = (res + 1.0)/2.0 * 255.0
|
||||
# hr = (hr + 1.0)/2.0 * 255.0
|
||||
res = torch.clamp(res,0.0,255.0)
|
||||
diff = (res-hr) ** 2
|
||||
diff = diff.mean(dim=-1).mean(dim=-1).mean(dim=-1).sqrt()
|
||||
psnrs = 20. * (255. / diff).log10()
|
||||
total_psnr+= psnrs.sum()
|
||||
total_num+=res.shape[0]
|
||||
final_psnr = total_psnr/total_num
|
||||
print("[{}], Epoch [{}], psnr: {:.4f}".format(self.config["version"],
|
||||
epoch, final_psnr))
|
||||
self.reporter.writeTrainLog(epoch,step,"psnr: {:.4f}".format(final_psnr))
|
||||
self.tensorboard_writer.add_scalar('metric/loss', final_psnr, epoch)
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def __setup_optimizers__(self):
|
||||
g_train_opt = self.config['g_optim_config']
|
||||
g_optim_params = []
|
||||
for k, v in self.gen.named_parameters():
|
||||
if v.requires_grad:
|
||||
g_optim_params.append(v)
|
||||
else:
|
||||
self.reporter.writeInfo(f'Params {k} will not be optimized.')
|
||||
print(f'Params {k} will not be optimized.')
|
||||
|
||||
optim_type = self.config['optim_type']
|
||||
|
||||
if optim_type == 'Adam':
|
||||
self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'optimizer {optim_type} is not supperted yet.')
|
||||
# self.optimizers.append(self.optimizer_g)
|
||||
|
||||
|
||||
    def train(self):
        """
        Full training loop for the fast neural-style-transfer generator.

        Each step samples a content batch from self.train_loader, generates a
        stylized image, and optimizes a weighted sum of:
          * content loss — MSE between VGG16 relu2_2 features of the output
            and of the content input;
          * style loss  — MSE between Gram matrices of the output features
            and the fixed style image's features (precomputed once).
        Progress is logged every config["log_step"] steps; a checkpoint and a
        sample image are written every config["model_save_epoch"] epochs.
        """
        ckpt_dir = self.config["project_checkpoints"]
        log_frep = self.config["log_step"]     # NOTE: "frep" is a typo of "freq"
        model_freq = self.config["model_save_epoch"]
        total_epoch = self.config["total_epoch"]
        batch_size = self.config["batch_size"]
        style_img = self.config["style_img_path"]

        # prep_weights= self.config["layersWeight"]
        content_w = self.config["content_weight"]
        style_w = self.config["style_weight"]
        crop_size = self.config["imcrop_size"]

        sample_dir = self.config["project_samples"]

        #===============build framework================#
        self.__init_framework__()

        #===============build optimizer================#
        # Optimizer
        # TODO replace below lines to build your optimizer
        print("build the optimizer...")
        self.__setup_optimizers__()

        #===============build losses===================#
        # TODO replace below lines to build your losses
        MSE_loss = torch.nn.MSELoss()

        # set the start point for training loop
        if self.config["phase"] == "finetune":
            start = self.config["checkpoint_epoch"] - 1
        else:
            start = 0

        # print("prepare the fixed labels...")
        # fix_label = [i for i in range(n_class)]
        # fix_label = torch.tensor(fix_label).long().cuda()
        # fix_label = fix_label.view(n_class,1)
        # fix_label = torch.zeros(n_class, n_class).cuda().scatter_(1, fix_label, 1)

        # Start time
        import datetime
        print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        from utilities.logo_class import logo_class
        logo_class.print_start_training()
        start_time = time.time()

        # Calculate the number of steps per epoch.
        # NOTE(review): len(self.train_loader) is divided by batch_size here,
        # while the sibling trainer divides by (d_step + g_step) — confirm the
        # custom GetLoader's __len__ returns samples rather than batches.
        step_epoch = len(self.train_loader)
        step_epoch = step_epoch // batch_size
        print("Total step = %d in each epoch"%step_epoch)

        VGG = VGG16().cuda()

        # constants for mapping generator output (presumably in [-1, 1]) onto
        # the 0-255 range expected by the ImageNet-mean shift — TODO confirm
        MEAN_VAL = 127.5
        SCALE_VAL= 127.5
        # Get Style Features
        # BGR ImageNet means (Caffe-style VGG16 preprocessing)
        imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).cuda()
        # the same shift pre-combined with the +MEAN_VAL offset applied after *SCALE_VAL
        imagenet_neg_mean_11= torch.tensor([-103.939 + MEAN_VAL, -116.779 + MEAN_VAL, -123.68 + MEAN_VAL], dtype=torch.float32).reshape(1,3,1,1).cuda()

        # the style image is fixed, so its Gram matrices are precomputed once
        style_tensor = img2tensor255crop(style_img,crop_size).cuda()
        style_tensor = style_tensor.add(imagenet_neg_mean)
        B, C, H, W = style_tensor.shape
        style_features = VGG(style_tensor.expand([batch_size, C, H, W]))
        style_gram = {}
        for key, value in style_features.items():
            style_gram[key] = Gram(value)
        # step_epoch = 2
        for epoch in range(start, total_epoch):
            for step in range(step_epoch):
                self.gen.train()

                content_images = self.train_loader.next()
                fake_image = self.gen(content_images)
                # VGG features of the stylized output and the content input
                generated_features = VGG((fake_image*SCALE_VAL).add(imagenet_neg_mean_11))
                content_features = VGG((content_images*SCALE_VAL).add(imagenet_neg_mean_11))
                content_loss = MSE_loss(generated_features['relu2_2'], content_features['relu2_2'])

                # style loss: Gram-matrix MSE accumulated over all VGG layers
                style_loss = 0.0
                for key, value in generated_features.items():
                    s_loss = MSE_loss(Gram(value), style_gram[key])
                    style_loss += s_loss

                # backward & optimize
                g_loss = content_loss* content_w + style_loss* style_w
                self.g_optimizer.zero_grad()
                g_loss.backward()
                self.g_optimizer.step()


                # Print out log info
                if (step + 1) % log_frep == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    # cumulative steps
                    cum_step = (step_epoch * epoch + step + 1)

                    epochinformation="[{}], Elapsed [{}], Epoch [{}/{}], Step [{}/{}], content_loss: {:.4f}, style_loss: {:.4f}, g_loss: {:.4f}".format(self.config["version"], elapsed, epoch + 1, total_epoch, step + 1, step_epoch, content_loss.item(), style_loss.item(), g_loss.item())
                    print(epochinformation)
                    self.reporter.writeInfo(epochinformation)

                    if self.config["use_tensorboard"]:
                        self.tensorboard_writer.add_scalar('data/g_loss', g_loss.item(), cum_step)
                        self.tensorboard_writer.add_scalar('data/content_loss', content_loss.item(), cum_step)
                        # NOTE(review): style_loss is passed as a tensor here
                        # (no .item()), unlike the two scalars above
                        self.tensorboard_writer.add_scalar('data/style_loss', style_loss, cum_step)

            #===============adjust learning rate============#
            # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
            #     print("Learning rate decay")
            #     for p in self.optimizer.param_groups:
            #         p['lr'] *= self.config["lr_decay"]
            #         print("Current learning rate is %f"%p['lr'])

            #===============save checkpoints================#
            if (epoch+1) % model_freq==0:
                print("Save epoch %d model checkpoint!"%(epoch+1))
                torch.save(self.gen.state_dict(),
                        os.path.join(ckpt_dir, 'epoch{}_{}.pth'.format(epoch + 1,
                        self.config["checkpoint_names"]["generator_name"])))

                torch.cuda.empty_cache()
                print('Sample images {}_fake.jpg'.format(epoch + 1))
                self.gen.eval()
                with torch.no_grad():
                    # save the last generated batch as a qualitative sample
                    sample = fake_image
                    saved_image1 = denorm(sample.cpu().data)
                    save_image(saved_image1,
                            os.path.join(sample_dir, '{}_fake.jpg'.format(epoch + 1)),nrow=4)
|
||||
@@ -0,0 +1,83 @@
|
||||
# Related scripts
|
||||
train_script_name: FastNST
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: FastNST
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 32
|
||||
g_kernel_size: 3
|
||||
res_num: 6
|
||||
n_class: 11
|
||||
image_size: 256
|
||||
window_size: 8
|
||||
|
||||
# Training information
|
||||
total_epoch: 120
|
||||
batch_size: 16
|
||||
imcrop_size: 256
|
||||
max2Keep: 10
|
||||
|
||||
# Dataset
|
||||
# FIXME(review): machine-specific absolute Windows path — the sibling configs use a relative "images/mosaic.jpg"
style_img_path: "G:\\UltraHighStyleTransfer\\reference\\fast-neural-style-pytorch-master\\fast-neural-style-pytorch-master\\images\\mosaic.jpg"
|
||||
dataloader: place365
|
||||
dataset_name: styletransfer
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
color_jitter: Enable
|
||||
color_config:
|
||||
brightness: 0.05
|
||||
contrast: 0.05
|
||||
saturation: 0.05
|
||||
hue: 0.05
|
||||
|
||||
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
|
||||
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
|
||||
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
|
||||
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
|
||||
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
|
||||
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
|
||||
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
|
||||
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
|
||||
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
|
||||
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
|
||||
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
|
||||
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
|
||||
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
|
||||
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
|
||||
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
|
||||
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
|
||||
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
|
||||
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
|
||||
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
|
||||
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
|
||||
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
|
||||
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
|
||||
'b/building_facade', 'c/cemetery']
|
||||
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: !!float 2e-4
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
content_weight: 2.0
|
||||
style_weight: 1.0
|
||||
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# Log
|
||||
log_step: 100
|
||||
model_save_epoch: 1
|
||||
use_tensorboard: True
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
@@ -0,0 +1,108 @@
|
||||
# Related scripts
|
||||
train_script_name: FastNST_CNN
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: FastNST_CNN
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 32
|
||||
g_kernel_size: 3
|
||||
res_num: 6
|
||||
n_class: 11
|
||||
image_size: 256
|
||||
window_size: 8
|
||||
|
||||
# Training information
|
||||
total_epoch: 120
|
||||
batch_size: 16
|
||||
imcrop_size: 256
|
||||
max2Keep: 10
|
||||
|
||||
# Dataset
|
||||
style_img_path: "images/mosaic.jpg"
|
||||
dataloader: place365
|
||||
dataset_name: styletransfer
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
color_jitter: Enable
|
||||
color_config:
|
||||
brightness: 0.05
|
||||
contrast: 0.05
|
||||
saturation: 0.05
|
||||
hue: 0.05
|
||||
|
||||
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
|
||||
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
|
||||
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
|
||||
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
|
||||
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
|
||||
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
|
||||
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
|
||||
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
|
||||
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
|
||||
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
|
||||
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
|
||||
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
|
||||
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
|
||||
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
|
||||
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
|
||||
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
|
||||
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
|
||||
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
|
||||
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
|
||||
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
|
||||
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
|
||||
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
|
||||
'b/building_facade', 'c/cemetery']
|
||||
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
|
||||
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
|
||||
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
|
||||
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
|
||||
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
|
||||
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
|
||||
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
|
||||
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
|
||||
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
|
||||
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
|
||||
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
|
||||
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
|
||||
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
|
||||
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
|
||||
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
|
||||
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
|
||||
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
|
||||
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
|
||||
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
|
||||
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
|
||||
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
|
||||
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
|
||||
# 'b_building_facade',
|
||||
# 'c_cemetery'
|
||||
# ]
|
||||
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: !!float 2e-4
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
content_weight: 2.0
|
||||
style_weight: 1.0
|
||||
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# Log
|
||||
log_step: 100
|
||||
model_save_epoch: 1
|
||||
use_tensorboard: True
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
@@ -0,0 +1,108 @@
|
||||
# Related scripts
|
||||
train_script_name: FastNST_CNN  # NOTE(review): g_model.script below is FastNST_CNN_Resblock — confirm this trainer script is the intended one
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: FastNST_CNN_Resblock
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 32
|
||||
g_kernel_size: 3
|
||||
res_num: 6
|
||||
n_class: 11
|
||||
image_size: 256
|
||||
window_size: 8
|
||||
|
||||
# Training information
|
||||
total_epoch: 120
|
||||
batch_size: 16
|
||||
imcrop_size: 256
|
||||
max2Keep: 10
|
||||
|
||||
# Dataset
|
||||
style_img_path: "images/mosaic.jpg"
|
||||
dataloader: place365
|
||||
dataset_name: styletransfer
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
color_jitter: Enable
|
||||
color_config:
|
||||
brightness: 0.05
|
||||
contrast: 0.05
|
||||
saturation: 0.05
|
||||
hue: 0.05
|
||||
|
||||
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
|
||||
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
|
||||
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
|
||||
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
|
||||
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
|
||||
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
|
||||
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
|
||||
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
|
||||
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
|
||||
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
|
||||
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
|
||||
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
|
||||
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
|
||||
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
|
||||
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
|
||||
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
|
||||
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
|
||||
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
|
||||
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
|
||||
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
|
||||
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
|
||||
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
|
||||
'b/building_facade', 'c/cemetery']
|
||||
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
|
||||
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
|
||||
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
|
||||
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
|
||||
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
|
||||
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
|
||||
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
|
||||
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
|
||||
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
|
||||
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
|
||||
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
|
||||
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
|
||||
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
|
||||
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
|
||||
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
|
||||
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
|
||||
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
|
||||
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
|
||||
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
|
||||
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
|
||||
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
|
||||
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
|
||||
# 'b_building_facade',
|
||||
# 'c_cemetery'
|
||||
# ]
|
||||
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: !!float 2e-4
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
content_weight: 2.0
|
||||
style_weight: 1.0
|
||||
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# Log
|
||||
log_step: 100
|
||||
model_save_epoch: 1
|
||||
use_tensorboard: True
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
@@ -0,0 +1,110 @@
|
||||
# Related scripts
|
||||
train_script_name: FastNST_Liif
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: FastNST_Liif
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 32
|
||||
g_kernel_size: 3
|
||||
res_num: 6
|
||||
n_class: 11
|
||||
image_size: 256
|
||||
mlp_hidden_list: [32,32]
|
||||
batch_size: 10
|
||||
window_size: 8
|
||||
|
||||
# Training information
|
||||
total_epoch: 120
|
||||
batch_size: 10
|
||||
imcrop_size: 256
|
||||
max2Keep: 10
|
||||
|
||||
# Dataset
|
||||
style_img_path: "images/mosaic.jpg"
|
||||
dataloader: place365
|
||||
dataset_name: styletransfer
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
color_jitter: Enable
|
||||
color_config:
|
||||
brightness: 0.05
|
||||
contrast: 0.05
|
||||
saturation: 0.05
|
||||
hue: 0.05
|
||||
|
||||
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
|
||||
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
|
||||
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
|
||||
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
|
||||
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
|
||||
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
|
||||
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
|
||||
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
|
||||
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
|
||||
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
|
||||
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
|
||||
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
|
||||
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
|
||||
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
|
||||
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
|
||||
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
|
||||
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
|
||||
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
|
||||
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
|
||||
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
|
||||
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
|
||||
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
|
||||
'b/building_facade', 'c/cemetery']
|
||||
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
|
||||
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
|
||||
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
|
||||
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
|
||||
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
|
||||
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
|
||||
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
|
||||
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
|
||||
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
|
||||
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
|
||||
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
|
||||
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
|
||||
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
|
||||
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
|
||||
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
|
||||
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
|
||||
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
|
||||
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
|
||||
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
|
||||
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
|
||||
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
|
||||
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
|
||||
# 'b_building_facade',
|
||||
# 'c_cemetery'
|
||||
# ]
|
||||
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: !!float 2e-4
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
content_weight: 2.0
|
||||
style_weight: 1.0
|
||||
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# Log
|
||||
log_step: 100
|
||||
model_save_epoch: 1
|
||||
use_tensorboard: True
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
@@ -0,0 +1,109 @@
|
||||
# Related scripts
|
||||
train_script_name: FastNST_Liif
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: FastNST_Liif_warp
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 32
|
||||
g_kernel_size: 3
|
||||
res_num: 6
|
||||
n_class: 11
|
||||
image_size: 256
|
||||
batch_size: 16
|
||||
window_size: 8
|
||||
|
||||
# Training information
|
||||
total_epoch: 120
|
||||
batch_size: 16
|
||||
imcrop_size: 256
|
||||
max2Keep: 10
|
||||
|
||||
# Dataset
|
||||
style_img_path: "images/mosaic.jpg"
|
||||
dataloader: place365
|
||||
dataset_name: styletransfer
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
color_jitter: Enable
|
||||
color_config:
|
||||
brightness: 0.05
|
||||
contrast: 0.05
|
||||
saturation: 0.05
|
||||
hue: 0.05
|
||||
|
||||
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
|
||||
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
|
||||
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
|
||||
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
|
||||
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
|
||||
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
|
||||
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
|
||||
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
|
||||
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
|
||||
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
|
||||
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
|
||||
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
|
||||
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
|
||||
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
|
||||
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
|
||||
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
|
||||
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
|
||||
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
|
||||
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
|
||||
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
|
||||
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
|
||||
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
|
||||
'b/building_facade', 'c/cemetery']
|
||||
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
|
||||
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
|
||||
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
|
||||
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
|
||||
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
|
||||
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
|
||||
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
|
||||
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
|
||||
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
|
||||
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
|
||||
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
|
||||
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
|
||||
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
|
||||
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
|
||||
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
|
||||
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
|
||||
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
|
||||
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
|
||||
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
|
||||
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
|
||||
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
|
||||
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
|
||||
# 'b_building_facade',
|
||||
# 'c_cemetery'
|
||||
# ]
|
||||
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: !!float 2e-4
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
content_weight: 2.0
|
||||
style_weight: 1.0
|
||||
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# Log
|
||||
log_step: 100
|
||||
model_save_epoch: 1
|
||||
use_tensorboard: True
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
@@ -0,0 +1,109 @@
|
||||
# Related scripts
|
||||
train_script_name: FastNST_Liif
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: FastNST_Liif_warp
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 32
|
||||
g_kernel_size: 3
|
||||
res_num: 6
|
||||
n_class: 11
|
||||
image_size: 256
|
||||
batch_size: 16
|
||||
window_size: 8
|
||||
|
||||
# Training information
|
||||
total_epoch: 120
|
||||
batch_size: 16
|
||||
imcrop_size: 256
|
||||
max2Keep: 10
|
||||
|
||||
# Dataset
|
||||
style_img_path: "images/mosaic.jpg"
|
||||
dataloader: place365
|
||||
dataset_name: styletransfer
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
color_jitter: Enable
|
||||
color_config:
|
||||
brightness: 0.05
|
||||
contrast: 0.05
|
||||
saturation: 0.05
|
||||
hue: 0.05
|
||||
|
||||
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
|
||||
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
|
||||
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
|
||||
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
|
||||
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
|
||||
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
|
||||
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
|
||||
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
|
||||
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
|
||||
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
|
||||
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
|
||||
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
|
||||
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
|
||||
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
|
||||
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
|
||||
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
|
||||
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
|
||||
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
|
||||
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
|
||||
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
|
||||
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
|
||||
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
|
||||
'b/building_facade', 'c/cemetery']
|
||||
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
|
||||
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
|
||||
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
|
||||
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
|
||||
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
|
||||
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
|
||||
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
|
||||
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
|
||||
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
|
||||
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
|
||||
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
|
||||
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
|
||||
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
|
||||
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
|
||||
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
|
||||
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
|
||||
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
|
||||
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
|
||||
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
|
||||
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
|
||||
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
|
||||
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
|
||||
# 'b_building_facade',
|
||||
# 'c_cemetery'
|
||||
# ]
|
||||
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: !!float 2e-4
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
content_weight: 2.0
|
||||
style_weight: 1.0
|
||||
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# Log
|
||||
log_step: 100
|
||||
model_save_epoch: 1
|
||||
use_tensorboard: True
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
@@ -0,0 +1,109 @@
|
||||
# Related scripts
|
||||
train_script_name: FastNST_SWD
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: FastNST_CNN_Resblock
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 32
|
||||
g_kernel_size: 3
|
||||
res_num: 6
|
||||
n_class: 11
|
||||
image_size: 256
|
||||
window_size: 8
|
||||
|
||||
# Training information
|
||||
total_epoch: 120
|
||||
batch_size: 16
|
||||
imcrop_size: 256
|
||||
max2Keep: 10
|
||||
|
||||
# Dataset
|
||||
style_img_path: "images/mosaic.jpg"
|
||||
dataloader: place365
|
||||
dataset_name: styletransfer
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
color_jitter: Enable
|
||||
color_config:
|
||||
brightness: 0.05
|
||||
contrast: 0.05
|
||||
saturation: 0.05
|
||||
hue: 0.05
|
||||
|
||||
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
|
||||
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
|
||||
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
|
||||
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
|
||||
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
|
||||
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
|
||||
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
|
||||
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
|
||||
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
|
||||
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
|
||||
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
|
||||
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
|
||||
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
|
||||
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
|
||||
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
|
||||
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
|
||||
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
|
||||
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
|
||||
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
|
||||
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
|
||||
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
|
||||
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
|
||||
'b/building_facade', 'c/cemetery']
|
||||
# selected_content_dir: ['a_abbey', 'a_arch', 'a_amphitheater', 'a_aqueduct', 'a_arena_rodeo', 'a_athletic_field_outdoor',
|
||||
# 'b_badlands', 'b_balcony_exterior', 'b_bamboo_forest', 'b_barn', 'b_barndoor', 'b_baseball_field',
|
||||
# 'b_basilica', 'b_bayou', 'b_beach', 'b_beach_house', 'b_beer_garden', 'b_boardwalk', 'b_boathouse',
|
||||
# 'b_botanical_garden', 'b_bullring', 'b_butte', 'c_cabin_outdoor', 'c_campsite', 'c_campus',
|
||||
# 'c_canal_natural', 'c_canal_urban', 'c_canyon', 'c_castle', 'c_church_outdoor', 'c_chalet',
|
||||
# 'c_cliff', 'c_coast', 'c_corn_field', 'c_corral', 'c_cottage', 'c_courtyard', 'c_crevasse',
|
||||
# 'd_dam', 'd_desert_vegetation', 'd_desert_road', 'd_doorway_outdoor', 'f_farm', 'f_fairway',
|
||||
# 'f_field_cultivated', 'f_field_wild', 'f_field_road', 'f_fishpond', 'f_florist_shop_indoor',
|
||||
# 'f_forest_broadleaf', 'f_forest_path', 'f_forest_road', 'f_formal_garden', 'g_gazebo_exterior',
|
||||
# 'g_glacier', 'g_golf_course', 'g_greenhouse_indoor', 'g_greenhouse_outdoor', 'g_grotto', 'g_gorge',
|
||||
# 'h_hayfield', 'h_herb_garden', 'h_hot_spring', 'h_house', 'h_hunting_lodge_outdoor', 'i_ice_floe',
|
||||
# 'i_ice_shelf', 'i_iceberg', 'i_inn_outdoor', 'i_islet', 'j_japanese_garden', 'k_kasbah',
|
||||
# 'k_kennel_outdoor', 'l_lagoon', 'l_lake_natural', 'l_lawn', 'l_library_outdoor', 'l_lighthouse',
|
||||
# 'm_mansion', 'm_marsh', 'm_mausoleum', 'm_moat_water', 'm_mosque_outdoor', 'm_mountain',
|
||||
# 'm_mountain_path', 'm_mountain_snowy', 'o_oast_house', 'o_ocean', 'o_orchard', 'p_park',
|
||||
# 'p_pasture', 'p_pavilion', 'p_picnic_area', 'p_pier', 'p_pond', 'r_raft', 'r_railroad_track',
|
||||
# 'r_rainforest', 'r_rice_paddy', 'r_river', 'r_rock_arch', 'r_roof_garden', 'r_rope_bridge',
|
||||
# 'r_ruin', 's_schoolhouse', 's_sky', 's_snowfield', 's_swamp', 's_swimming_hole',
|
||||
# 's_synagogue_outdoor', 't_temple_asia', 't_topiary_garden', 't_tree_farm', 't_tree_house',
|
||||
# 'u_underwater_ocean_deep', 'u_utility_room', 'v_valley', 'v_vegetable_garden', 'v_viaduct',
|
||||
# 'v_village', 'v_vineyard', 'v_volcano', 'w_waterfall', 'w_watering_hole', 'w_wave',
|
||||
# 'w_wheat_field', 'z_zen_garden', 'a_alcove', 'a_apartment-building_outdoor', 'a_artists_loft',
|
||||
# 'b_building_facade',
|
||||
# 'c_cemetery'
|
||||
# ]
|
||||
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: !!float 2e-4
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
content_weight: 1.0
|
||||
style_weight: 10.0
|
||||
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
swd_dim: 32
|
||||
|
||||
# Log
|
||||
log_step: 100
|
||||
model_save_epoch: 1
|
||||
use_tensorboard: True
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
@@ -0,0 +1,98 @@
|
||||
# Related scripts
|
||||
train_script_name: gan
|
||||
|
||||
# models' scripts
|
||||
model_configs:
|
||||
g_model:
|
||||
script: Conditional_Generator_Noskip
|
||||
class_name: Generator
|
||||
module_params:
|
||||
g_conv_dim: 32
|
||||
g_kernel_size: 3
|
||||
res_num: 8
|
||||
n_class: 11
|
||||
d_model:
|
||||
script: Conditional_Discriminator_Projection_big
|
||||
class_name: Discriminator
|
||||
module_params:
|
||||
d_conv_dim: 32
|
||||
d_kernel_size: 5
|
||||
|
||||
# Training information
|
||||
total_epoch: 120
|
||||
batch_size_list: [8, 4, 2]
|
||||
switch_epoch_list: [0, 5, 10]
|
||||
imcrop_size_list: [256, 512, 768]
|
||||
max2Keep: 10
|
||||
movingAverage: 0.05
|
||||
d_success_threshold: 0.8
|
||||
d_step: 3
|
||||
g_step: 1
|
||||
|
||||
# Dataset
|
||||
dataloader: condition
|
||||
dataset_name: styletransfer
|
||||
dataset_params:
|
||||
random_seed: 1234
|
||||
dataloader_workers: 8
|
||||
color_jitter: Enable
|
||||
color_config:
|
||||
brightness: 0.05
|
||||
contrast: 0.05
|
||||
saturation: 0.05
|
||||
hue: 0.05
|
||||
selected_style_dir: ['berthe-morisot','edvard-munch',
|
||||
'ernst-ludwig-kirchner','jackson-pollock','kandinsky','monet',
|
||||
'nicholas','paul-cezanne','picasso','samuel','vangogh']
|
||||
selected_content_dir: ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
|
||||
'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
|
||||
'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
|
||||
'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
|
||||
'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
|
||||
'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
|
||||
'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
|
||||
'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
|
||||
'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
|
||||
'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
|
||||
'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
|
||||
'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
|
||||
'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
|
||||
'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
|
||||
'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
|
||||
'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
|
||||
'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
|
||||
'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
|
||||
's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
|
||||
'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
|
||||
'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
|
||||
'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
|
||||
'b/building_facade', 'c/cemetery']
|
||||
|
||||
|
||||
eval_dataloader: DIV2K_hdf5
|
||||
eval_dataset_name: DF2K_H5_Eval
|
||||
eval_batch_size: 2
|
||||
|
||||
# Dataset
|
||||
|
||||
# Optimizer
|
||||
optim_type: Adam
|
||||
g_optim_config:
|
||||
lr: !!float 2e-4
|
||||
betas: [0.9, 0.99]
|
||||
d_optim_config:
|
||||
lr: !!float 2e-4
|
||||
betas: [0.9, 0.99]
|
||||
|
||||
feature_weight: 50.0
|
||||
transform_weight: 50.0
|
||||
layers_weight: [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# Log
|
||||
log_step: 1000
|
||||
sampleStep: 2000
|
||||
model_save_epoch: 1
|
||||
useTensorboard: True
|
||||
checkpoint_names:
|
||||
generator_name: Generator
|
||||
discriminator_name: Discriminator
|
||||
@@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: checkpoint_manager.py
|
||||
# Created Date: Sunday July 12th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 27th July 2020 11:01:16 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
class CheckpointManager(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
self.maxCkpNum = -1
|
||||
self.ckpList = []
|
||||
self.modelsDict = {} # key model name, value model
|
||||
self.currentEpoch = 0
|
||||
|
||||
|
||||
def registerModels(self):
|
||||
pass
|
||||
|
||||
def __updateCkpList__(self):
|
||||
pass
|
||||
|
||||
def saveModel(self):
|
||||
pass
|
||||
|
||||
def loadModel(self):
|
||||
pass
|
||||
|
||||
def saveLR(self):
|
||||
pass
|
||||
|
||||
def loadLR(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
def loadPretrainedModel(chechpointStep,modelSavePath,gModel,dModel,cuda,**kwargs):
|
||||
gModel.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_LocalG.pth'.format(chechpointStep)),map_location=cuda))
|
||||
dModel.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_GlobalD.pth'.format(chechpointStep)),map_location=cuda))
|
||||
print('loaded trained models (epoch: {}) successful!'.format(chechpointStep))
|
||||
if not kwargs:
|
||||
return
|
||||
for k,v in kwargs.items():
|
||||
v.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_{}.pth'.format(chechpointStep,k)),map_location=cuda))
|
||||
print("Loaded param %s"%k)
|
||||
|
||||
def loadPretrainedModelByDict(chechpointStep,modelSavePath,cuda,**kwargs):
|
||||
if not kwargs:
|
||||
return
|
||||
for k,v in kwargs.items():
|
||||
v.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_{}.pth'.format(chechpointStep,k)),map_location=cuda))
|
||||
print("Loaded param %s"%k)
|
||||
|
||||
def loadLR(chechpointStep,modelSavePath,dlr,glr):
|
||||
glr.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_LocalGlr.pth'.format(chechpointStep))))
|
||||
dlr.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_GlobalDlr.pth'.format(chechpointStep))))
|
||||
print("Generator learning rate:%f"%glr.get_lr()[0])
|
||||
print("Discriminator learning rate:%f"%dlr.get_lr()[0])
|
||||
|
||||
def saveLR(step,modelSavePath,dlr,glr):
|
||||
torch.save(glr.state_dict(),os.path.join(modelSavePath, 'Epoch{}_LocalGlr.pth'.format(step + 1)))
|
||||
torch.save(dlr.state_dict(),os.path.join(modelSavePath, 'Epoch{}_GlobalDlr.pth'.format(step + 1)))
|
||||
print("Epoch:{} models learning rate saved!".format(step+1))
|
||||
|
||||
|
||||
def saveModel(step,modelSavePath,gModel,dModel,**kwargs):
|
||||
torch.save(gModel.state_dict(),
|
||||
os.path.join(modelSavePath, 'Epoch{}_LocalG.pth'.format(step + 1)))
|
||||
torch.save(dModel.state_dict(),
|
||||
os.path.join(modelSavePath, 'Epoch{}_GlobalD.pth'.format(step + 1)))
|
||||
print("Epoch:{} models saved!".format(step+1))
|
||||
if not kwargs:
|
||||
return
|
||||
for k,v in kwargs.items():
|
||||
torch.save(v.state_dict(),
|
||||
os.path.join(modelSavePath, 'Epoch{}_{}.pth'.format(step + 1,k)))
|
||||
print("Epoch:{} models param {} saved!".format(step+1,k))
|
||||
|
||||
def saveModelByDict(step,modelSavePath,**kwargs):
|
||||
if not kwargs:
|
||||
return
|
||||
for k,v in kwargs.items():
|
||||
torch.save(v.state_dict(),
|
||||
os.path.join(modelSavePath, 'Epoch{}_{}.pth'.format(step + 1,k)))
|
||||
print("Epoch:{} models param {} saved!".format(step+1,k))
|
||||
@@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: figure.py
|
||||
# Created Date: Tuesday October 13th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 13th October 2020 2:54:30 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def plot_loss_curve(losses, save_path):
    """Plot every loss series in ``losses`` on one figure and save it.

    Args:
        losses (dict[str, Sequence[float]]): name -> per-iteration values.
        save_path (str): Destination image path.
    """
    for key, values in losses.items():
        plt.plot(range(len(values)), values, label=key)
    plt.xlabel('iteration')
    # Fixed: was an f-string with no placeholder (lint F541).
    plt.title('loss curve')
    plt.legend()
    plt.savefig(save_path)
    # Reset the implicit figure so the next call starts from a blank canvas.
    plt.clf()
|
||||
@@ -0,0 +1,15 @@
|
||||
import json
|
||||
|
||||
|
||||
def readConfig(path):
    """Load a JSON config file, tolerating double-encoded JSON strings."""
    with open(path, 'r') as cf:
        parsed = json.loads(cf.read())
    # Some configs were saved as a JSON string of JSON; unwrap once more.
    if isinstance(parsed, str):
        parsed = json.loads(parsed)
    return parsed
|
||||
|
||||
def writeConfig(path, info):
    """Serialize ``info`` to ``path`` as pretty-printed JSON."""
    with open(path, 'w') as cf:
        # write(), not writelines(): the payload is a single string
        # (writelines iterated it character by character).
        cf.write(json.dumps(info, indent=4))
|
||||
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: learningrate_scheduler.py
|
||||
# Created Date: Tuesday January 5th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 5th January 2021 2:04:00 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
# Refer to basicSR https://github.com/xinntao/BasicSR
|
||||
|
||||
|
||||
import math
|
||||
from collections import Counter
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
class MultiStepRestartLR(_LRScheduler):
    """ MultiStep with restarts learning rate scheme.

    Args:
        optimizer (torch.optim.Optimizer): Torch optimizer.
        milestones (list): Iterations that will decrease learning rate.
        gamma (float): Decrease ratio. Default: 0.1.
        restarts (list): Restart iterations. Default: [0].
        restart_weights (list): Restart weights at each restart iteration.
            Default: [1].
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self,
                 optimizer,
                 milestones,
                 gamma=0.1,
                 restarts=(0,),
                 restart_weights=(1,),
                 last_epoch=-1):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.restarts = restarts
        self.restart_weights = restart_weights
        # Fixed: removed leftover debug print()s of restarts/restart_weights
        # that fired on every construction.
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # On a restart iteration, reset every group to initial_lr * weight.
        if self.last_epoch in self.restarts:
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [
                group['initial_lr'] * weight
                for group in self.optimizer.param_groups
            ]
        # Between milestones, keep the current lr.
        if self.last_epoch not in self.milestones:
            return [group['lr'] for group in self.optimizer.param_groups]
        # On a milestone, decay by gamma once per occurrence in `milestones`.
        return [
            group['lr'] * self.gamma**self.milestones[self.last_epoch]
            for group in self.optimizer.param_groups
        ]
|
||||
|
||||
|
||||
def get_position_from_periods(iteration, cumulative_period):
    """Get the position from a period list.

    It will return the index of the right-closest number in the period list.
    For example, the cumulative_period = [100, 200, 300, 400],
    if iteration == 50, return 0;
    if iteration == 210, return 2;
    if iteration == 300, return 2.

    Args:
        iteration (int): Current iteration.
        cumulative_period (list[int]): Cumulative period list.

    Returns:
        int: The position of the right-closest number in the period list
            (None if iteration is beyond the last entry).
    """
    return next((i for i, period in enumerate(cumulative_period)
                 if iteration <= period), None)


class CosineAnnealingRestartLR(_LRScheduler):
    """ Cosine annealing with restarts learning rate scheme.

    An example of config:
    periods = [10, 10, 10, 10]
    restart_weights = [1, 0.5, 0.5, 0.5]
    eta_min=1e-7

    It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
    scheduler will restart with the weights in restart_weights.

    Args:
        optimizer (torch.optim.Optimizer): Torch optimizer.
        periods (list): Period for each cosine anneling cycle.
        restart_weights (list): Restart weights at each restart iteration.
            Default: [1].
        eta_min (float): The minimum lr. Default: 0.
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self,
                 optimizer,
                 periods,
                 restart_weights=(1,),
                 eta_min=0,
                 last_epoch=-1):
        # Fixed: the default used to be `(1)`, a plain int -- len() on it
        # raised TypeError whenever restart_weights was left at the default.
        self.periods = periods
        self.restart_weights = restart_weights
        self.eta_min = eta_min
        assert (len(self.periods) == len(self.restart_weights)
                ), 'periods and restart_weights should have the same length.'
        # Prefix sums of the periods: cycle i ends at cumulative_period[i].
        self.cumulative_period = [
            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
        ]
        super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        idx = get_position_from_periods(self.last_epoch,
                                        self.cumulative_period)
        current_weight = self.restart_weights[idx]
        nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
        current_period = self.periods[idx]

        # Standard cosine annealing within the current cycle, scaled by the
        # cycle's restart weight.
        return [
            self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
            (1 + math.cos(math.pi * (
                (self.last_epoch - nearest_restart) / current_period)))
            for base_lr in self.base_lrs
        ]
|
||||
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: logo_class.py
|
||||
# Created Date: Tuesday June 29th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 11th October 2021 12:39:55 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
class logo_class:
    # Static-only holder for start-up ASCII banners; never instantiated.

    @staticmethod
    def print_group_logo():
        """Print the research-group banner to stdout."""
        logo_str = """

███╗ ██╗██████╗ ███████╗██╗ ██████╗ ███████╗ ██╗████████╗██╗ ██╗
████╗ ██║██╔══██╗██╔════╝██║██╔════╝ ██╔════╝ ██║╚══██╔══╝██║ ██║
██╔██╗ ██║██████╔╝███████╗██║██║ ███╗ ███████╗ ██║ ██║ ██║ ██║
██║╚██╗██║██╔══██╗╚════██║██║██║ ██║ ╚════██║██ ██║ ██║ ██║ ██║
██║ ╚████║██║ ██║███████║██║╚██████╔╝ ███████║╚█████╔╝ ██║ ╚██████╔╝
╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═════╝ ╚══════╝ ╚════╝ ╚═╝ ╚═════╝
Neural Rendering Special Interesting Group of SJTU

"""
        print(logo_str)

    @staticmethod
    def print_start_training():
        """Print the 'Start Training' banner to stdout."""
        logo_str = """
_____ __ __ ______ _ _
/ ___/ / /_ ____ _ _____ / /_ /_ __/_____ ____ _ (_)____ (_)____ ____ _
\__ \ / __// __ `// ___// __/ / / / ___// __ `// // __ \ / // __ \ / __ `/
___/ // /_ / /_/ // / / /_ / / / / / /_/ // // / / // // / / // /_/ /
/____/ \__/ \__,_//_/ \__/ /_/ /_/ \__,_//_//_/ /_//_//_/ /_/ \__, /
/____/
"""
        print(logo_str)


if __name__ == "__main__":
    # Manual smoke test: show one of the banners.
    # logo_class.print_group_logo()
    logo_class.print_start_training()
|
||||
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Reporter.py
|
||||
# Created Date: Tuesday September 24th 2019
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 4th July 2021 11:50:12 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2019 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import datetime
|
||||
import os
|
||||
|
||||
class Reporter:
    """Append-only text logger; each instance writes to a timestamped file."""

    def __init__(self, reportPath):
        # Final path is "<reportPath>-YYYYmmddHHMMSS.log".
        self.path = reportPath
        self.withTimeStamp = False  # kept for backward compatibility
        self.index = 1              # running line counter across all writes
        self.timeStrFormat = '%Y-%m-%d %H:%M:%S'
        timeStr = datetime.datetime.strftime(datetime.datetime.now(),
                                             '%Y%m%d%H%M%S')
        self.path = self.path + "-%s.log" % timeStr
        if not os.path.exists(self.path):
            # Touch the file so later 'a+' opens always succeed.
            f = open(self.path, 'w')
            f.close()

    def writeInfo(self, strLine):
        """Append a timestamped [info] line."""
        with open(self.path, 'a+') as logf:
            timeStr = datetime.datetime.strftime(datetime.datetime.now(),
                                                 self.timeStrFormat)
            logf.writelines("[%d]-[%s]-[info] %s\n"
                            % (self.index, timeStr, strLine))
        self.index += 1

    def writeConfig(self, config):
        """Append one [parameters] line per entry of the ``config`` dict."""
        with open(self.path, 'a+') as logf:
            for item in config.items():
                text = "[%d]-[parameters] %s--%s\n" % (self.index, item[0],
                                                       str(item[1]))
                logf.writelines(text)
                self.index += 1

    def writeModel(self, modelText):
        """Append a [model] line (e.g. a printed network architecture)."""
        with open(self.path, 'a+') as logf:
            logf.writelines("[%d]-[model] %s\n" % (self.index, modelText))
        self.index += 1

    def writeRawInfo(self, strLine):
        """Append a timestamped [info] line (raw variant)."""
        with open(self.path, 'a+') as logf:
            timeStr = datetime.datetime.strftime(datetime.datetime.now(),
                                                 self.timeStrFormat)
            # BUG FIX: the format string had 2 placeholders but 3 arguments,
            # which raised TypeError on every call; include the timestamp.
            logf.writelines("[%d]-[%s]-[info] %s\n"
                            % (self.index, timeStr, strLine))
        self.index += 1

    def writeTrainLog(self, epoch, step, logText):
        """Append a timestamped per-step training log line."""
        with open(self.path, 'a+') as logf:
            timeStr = datetime.datetime.strftime(datetime.datetime.now(),
                                                 self.timeStrFormat)
            logf.writelines("[%d]-[%s]-[logInfo]-epoch[%d]-step[%d] %s\n"
                            % (self.index, timeStr, epoch, step, logText))
        self.index += 1
|
||||
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: save_heatmap.py
|
||||
# Created Date: Friday January 15th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 15th January 2021 10:23:13 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
def SaveHeatmap(heatmaps, path, row=-1, dpi=72):
    """Render each map of a B x 1 x H x W tensor as a heatmap and tile them.

    Args:
        heatmaps: array-like of shape (B, 1, H, W).
        path (str): Output image path of the tiled result.
        row (int): Number of tile rows; anything < 1 lays all maps out in
            a single row.
        dpi (int): Matplotlib save resolution.
    """
    batch_size = heatmaps.shape[0]
    temp_path = ".temp/"
    if not os.path.exists(temp_path):
        os.makedirs(temp_path)
    final_img = None
    if row < 1:
        col = batch_size
        row = 1
    else:
        col = batch_size // row
        if row * col < batch_size:
            col += 1

    row_i = 0
    col_i = 0

    for i in range(batch_size):
        img_path = os.path.join(temp_path, 'temp_batch_{}.png'.format(i))
        sns.heatmap(heatmaps[i, 0, :, :],
                    vmin=0, vmax=heatmaps[i, 0, :, :].max(), cbar=False)
        plt.savefig(img_path, dpi=dpi, bbox_inches='tight', pad_inches=0)
        # BUG FIX: clear the current figure -- without this every subsequent
        # heatmap is drawn on top of all previous ones.
        plt.clf()
        img = cv2.imread(img_path)
        if i == 0:
            H, W, C = img.shape
            final_img = np.zeros((H * row, W * col, C))
        final_img[H * row_i:H * (row_i + 1), W * col_i:W * (col_i + 1), :] = img
        col_i += 1
        if col_i >= col:
            col_i = 0
            row_i += 1
    cv2.imwrite(path, final_img)


if __name__ == "__main__":
    # Manual smoke test: tile 16 random 10x10 maps in one row.
    random_map = np.random.randn(16, 1, 10, 10)
    SaveHeatmap(random_map, "./wocao.png", 1)
|
||||
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: sshupload.py
|
||||
# Created Date: Tuesday September 24th 2019
|
||||
# Author: Lcx
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th January 2021 2:02:12 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2019 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import hashlib
import os

import paramiko
|
||||
# ssh传输类:
|
||||
|
||||
# SFTP transfer helper class
class fileUploaderClass(object):
    """Small paramiko-based SFTP helper (put/get/list/delete/rename/md5).

    Each public method opens a fresh connection, performs one operation and
    closes both the SFTP session and the SSH transport.
    """

    def __init__(self, serverIp, userName, passWd, port=22):
        self.__ip__ = serverIp
        self.__userName__ = userName
        self.__passWd__ = passWd
        self.__port__ = port
        self.__ssh__ = paramiko.SSHClient()
        self.__ssh__.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    def __openSftp__(self):
        # Connect and return a single SFTP session. The original code built
        # one client via SFTPClient.from_transport() and immediately leaked
        # it by overwriting with open_sftp(); only one is needed.
        self.__ssh__.connect(self.__ip__, self.__port__,
                             self.__userName__, self.__passWd__)
        return self.__ssh__.open_sftp()

    def sshScpPut(self, localFile, remoteFile):
        """Upload ``localFile``, creating remote directories as needed."""
        sftp = self.__openSftp__()
        remoteDir = remoteFile.split("/")
        if remoteFile[0] == '/':
            sftp.chdir('/')

        for item in remoteDir[0:-1]:
            if item == "":
                continue
            try:
                sftp.chdir(item)
            except IOError:
                # Directory missing: create it, then enter it.
                sftp.mkdir(item)
                sftp.chdir(item)
        sftp.put(localFile, remoteDir[-1])
        sftp.close()
        self.__ssh__.close()
        print("ssh localfile:%s remotefile:%s success" % (localFile, remoteFile))

    def sshScpGetNames(self, remoteDir):
        """Return the entry names under ``remoteDir``."""
        sftp = self.__openSftp__()
        names = sftp.listdir(remoteDir)
        # Fixed: the original returned without closing, leaking the session.
        sftp.close()
        self.__ssh__.close()
        return names

    def sshScpGet(self, remoteFile, localFile, showProgress=False):
        """Download ``remoteFile``; return False if it does not exist."""
        sftp = self.__openSftp__()
        try:
            sftp.stat(remoteFile)
            print("Remote file exists!")
        except IOError:
            print("Remote file does not exist!")
            # Fixed: close the session on the early-return path too.
            sftp.close()
            self.__ssh__.close()
            return False
        if showProgress:
            sftp.get(remoteFile, localFile, callback=self.__putCallBack__)
        else:
            sftp.get(remoteFile, localFile)
        sftp.close()
        self.__ssh__.close()
        return True

    def __putCallBack__(self, transferred, total):
        # Progress callback for sftp.get/put.
        print("current transferred %3.1f percent" % (transferred / total * 100), end='\r')

    def sshScpGetmd5(self, file_path):
        """Return ``(True, md5_hexdigest)`` of a remote file, else ``(False, None)``."""
        sftp = self.__openSftp__()
        try:
            # Fixed: hashlib was used here but never imported (NameError).
            file = sftp.open(file_path, 'rb')
            res = (True, hashlib.new('md5', file.read()).hexdigest())
            return res
        except Exception:
            return (False, None)
        finally:
            sftp.close()
            self.__ssh__.close()

    def sshScpRename(self, oldpath, newpath):
        """Rename/move a remote path."""
        sftp = self.__openSftp__()
        sftp.rename(oldpath, newpath)
        sftp.close()
        self.__ssh__.close()
        print("ssh oldpath:%s newpath:%s success" % (oldpath, newpath))

    def sshScpDelete(self, path):
        """Delete a single remote file."""
        sftp = self.__openSftp__()
        sftp.remove(path)
        sftp.close()
        self.__ssh__.close()
        print("ssh delete:%s success" % (path))

    def sshScpDeleteDir(self, path):
        """Recursively delete a remote directory tree."""
        sftp = self.__openSftp__()
        self.__rm__(sftp, path)
        sftp.close()
        self.__ssh__.close()

    def __rm__(self, sftp, path):
        # Recursive delete: try as a directory first, fall back to file.
        try:
            files = sftp.listdir(path=path)
            print(files)
            for f in files:
                filepath = os.path.join(path, f).replace('\\', '/')
                self.__rm__(sftp, filepath)
            sftp.rmdir(path)
            print("ssh delete:%s success" % (path))
        except IOError:
            print(path)
            sftp.remove(path)
            print("ssh delete:%s success" % (path))
|
||||
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: transfer_checkpoint.py
|
||||
# Created Date: Wednesday February 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 4th February 2021 1:27:09 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import init as init
|
||||
import os
|
||||
import numpy as np
|
||||
import scipy.io as io
|
||||
|
||||
class RepSRPlain_pixel(nn.Module):
    """Plain (re-parameterized) RepSR network used at inference time.

    Six 3x3 convolutions (LeakyReLU in between) followed by two nearest-
    neighbour x2 upsampling stages, i.e. an overall x4 super-resolution
    factor.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 32.
        num_layer (int): Kept for interface compatibility; not read by
            this implementation. Default: 3.
        upsampling (int): Nominal upscale factor (the forward pass is
            hard-wired to x4 via two x2 interpolations). Default: 4.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 num_feat=32,
                 num_layer=3,
                 upsampling=4):
        super(RepSRPlain_pixel, self).__init__()

        self.scale = upsampling
        self.ssqu = upsampling ** 2  # kept for checkpoint/interface compatibility

        # Re-parameterized "plain" trunk: single 3x3 convs (branches already
        # merged into these weights by the conversion script).
        self.rep1 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.rep2 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1)
        self.rep3 = nn.Conv2d(num_feat * 2, num_feat * 2, 3, 1, 1)
        self.rep4 = nn.Conv2d(num_feat * 2, num_feat * 2, 3, 1, 1)
        self.rep5 = nn.Conv2d(num_feat * 2, num_feat * 2, 3, 1, 1)
        self.rep6 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)

        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.activator = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # self.activator = nn.ReLU(inplace=True)

        # default_init_weights([self.conv_up1,self.conv_up2,self.conv_hr,self.conv_last], 0.1)

    def forward(self, x):
        f_d = self.activator(self.rep1(x))
        f_d = self.activator(self.rep2(f_d))
        f_d = self.activator(self.rep3(f_d))
        f_d = self.activator(self.rep4(f_d))
        f_d = self.activator(self.rep5(f_d))
        f_d = self.activator(self.rep6(f_d))

        # Two x2 nearest-neighbour upsampling stages -> x4 total.
        feat = self.activator(
            self.conv_up1(F.interpolate(f_d, scale_factor=2, mode='nearest')))
        feat = self.activator(
            self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.activator(self.conv_hr(feat)))
        return out
|
||||
|
||||
def create_identity_conv(dim, kernel_size=3, device='cuda'):
    """Build a conv weight that acts as the identity on ``dim`` channels.

    Args:
        dim (int): Number of input/output channels.
        kernel_size (int): Square kernel size. Default: 3.
        device (str or torch.device): Allocation device. Default: 'cuda'
            (preserves the original hard-coded behaviour; pass 'cpu' to
            use without a GPU).

    Returns:
        torch.Tensor: (dim, dim, k, k) weight with 1.0 at each channel's
            kernel centre on the diagonal.
    """
    kernel = torch.zeros((dim, dim, kernel_size, kernel_size), device=device)
    for i_dim in range(dim):
        kernel[i_dim, i_dim, kernel_size // 2, kernel_size // 2] = 1.0
    return kernel
|
||||
|
||||
def fill_conv_kernel(in_tensor, kernel_size=3, device='cuda'):
    """Embed a 1x1 conv weight into the centre of a k x k kernel.

    Args:
        in_tensor (torch.Tensor): 1x1 conv weight of shape
            (out_ch, in_ch, 1, 1).
        kernel_size (int): Target square kernel size. Default: 3.
        device (str or torch.device): Allocation device. Default: 'cuda'
            (preserves the original hard-coded behaviour; pass 'cpu' to
            use without a GPU).

    Returns:
        torch.Tensor: (out_ch, in_ch, k, k) weight, zero everywhere except
            the centre tap, which carries the 1x1 weights.
    """
    shape = in_tensor.shape
    kernel = torch.zeros(shape[0], shape[1], kernel_size, kernel_size,
                         device=device)
    for i_dim in range(shape[0]):
        kernel[i_dim, :, kernel_size // 2, kernel_size // 2] = in_tensor[i_dim, :, 0, 0]
    return kernel
|
||||
|
||||
if __name__ == "__main__":
    # Convert a trained RepSR checkpoint into its re-parameterized "plain"
    # form: fold each rep block's 3x3 + 1x1 (+ identity, when square)
    # branches into a single 3x3 convolution, then save the merged weights.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(0)
    # NOTE(review): hard-coded Windows paths and epoch -- adjust per run.
    base_path = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\"
    version = "repsr_pixel_0"
    epoch = 73
    path_ckp= os.path.join(base_path,version,"checkpoints\\epoch%d_RepSR.pth"%epoch)
    path_plain_ckp= os.path.join(base_path,version,"checkpoints\\epoch%d_RepSR_Plain.pth"%epoch)
    network = RepSRPlain_pixel(3,
                               3,
                               64,
                               3,
                               4
                               )
    network = network.cuda()

    # Target layout: the plain network's (randomly initialized) state dict,
    # whose entries are overwritten below with the merged weights.
    wocao = network.state_dict()
    # for data_key in wocao.keys():
    #     print(data_key)
    #     print(wocao[data_key].shape)
    wocao_cpk = torch.load(path_ckp)

    # for data_key in wocao_cpk.keys():
    #     print(data_key)
    #     print(wocao_cpk[data_key].shape)
    name_list = ["rep1","rep2","rep3","rep4","rep5","rep6"]
    other_list = ["conv_up1","conv_up2","conv_hr","conv_last"]
    for i_name in name_list:
        # Merge the 1x1 branch into the 3x3 kernel by centre-embedding it.
        temp= wocao_cpk[i_name+".conv3.weight"] + fill_conv_kernel(wocao_cpk[i_name+".conv1x1.weight"])
        wocao[i_name+".weight"] = temp
        temp= wocao_cpk[i_name+".conv3.bias"] + wocao_cpk[i_name+".conv1x1.bias"]
        wocao[i_name+".bias"] = temp

        # Square (in_ch == out_ch) blocks also carry an identity shortcut;
        # fold it in as an identity kernel added to the merged weight.
        if wocao_cpk[i_name+".conv3.weight"].shape[0] == wocao_cpk[i_name+".conv3.weight"].shape[1]:
            print("include identity")
            temp = wocao[i_name+".weight"] + create_identity_conv(wocao_cpk[i_name+".conv3.weight"].shape[0])
            wocao[i_name+".weight"] = temp

    for i_name in other_list:
        # Non-reparameterized layers are copied through unchanged.
        wocao[i_name+".weight"] = wocao_cpk[i_name+".weight"]
        wocao[i_name+".bias"] = wocao_cpk[i_name+".bias"]

    torch.save(wocao,path_plain_ckp)

    # wocao = torch.load(path_plain_ckp)
    # for data_key in wocao.keys():
    #     result1 = wocao[data_key].cpu().numpy()
    #     # np.savetxt(i_name+"_conv3_weight.txt",result1)
    #     str_temp = ("%s"%data_key).replace(".","_")
    #     io.savemat(str_temp+".mat",{str_temp:result1})

    # for data_key in wocao.keys():
    #     print(data_key)
    #     print(wocao[data_key].shape)
|
||||
@@ -0,0 +1,335 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: utilities.py
|
||||
# Created Date: Monday April 6th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 2:18:05 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
|
||||
# Gram Matrix
|
||||
# Gram Matrix
def Gram(tensor: torch.Tensor):
    """Return the batched Gram matrix of a (B, C, H, W) feature tensor,
    normalized by C*H*W; output shape is (B, C, C)."""
    B, C, H, W = tensor.shape
    flat = tensor.view(B, C, H * W)
    return torch.bmm(flat, flat.transpose(1, 2)) / (C * H * W)
|
||||
|
||||
def build_tensorboard(summary_path):
    """Create a tensorboardX ``SummaryWriter`` logging into ``summary_path``."""
    # Imported lazily so the module loads without tensorboardX installed.
    from tensorboardX import SummaryWriter
    return SummaryWriter(log_dir=summary_path)
|
||||
|
||||
|
||||
|
||||
def denorm(x):
    """Map a [-1, 1] tensor into [0, 1], clamping out-of-range values."""
    return ((x + 1) / 2).clamp_(0, 1)
|
||||
|
||||
def tensor2img(img_tensor):
    """Convert a [B, C, H, W] tensor in [-1, 1] to a [B, H, W, C] float
    array in [0, 255]."""
    arr = img_tensor.numpy()
    arr = np.clip((arr + 1) / 2, 0.0, 1.0) * 255
    return arr.transpose((0, 2, 3, 1))
|
||||
|
||||
def img2tensor255(path, max_size=None):
    """Load an image file into a (1, C, H, W) float tensor scaled to [0, 255].

    Args:
        path (str): Image file path.
        max_size (int, optional): If given, rescale so the longest side
            equals ``max_size`` (aspect ratio preserved).
    """
    image = Image.open(path)
    if max_size is None:
        itot_t = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])
    else:
        # BUG FIX: PIL images expose .size (W, H), not a numpy-style .shape;
        # the old code raised AttributeError here. Also dropped ToPILImage,
        # which cannot be applied to an already-PIL image.
        W, H = image.size
        image_size = tuple([int((float(max_size) / max([H, W])) * x)
                            for x in [H, W]])
        itot_t = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])

    # Convert image to tensor
    tensor = itot_t(image)

    # Add the batch_size dimension
    tensor = tensor.unsqueeze(dim=0)
    return tensor
|
||||
|
||||
def img2tensor255crop(path, crop_size=256):
    """Load an image, centre-crop it, and return a (1, C, H, W) tensor
    scaled to [0, 255]."""
    pipeline = transforms.Compose([
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    # Crop + convert, then prepend the batch dimension.
    return pipeline(Image.open(path)).unsqueeze(dim=0)
|
||||
|
||||
# def img2tensor255(path, crop_size=None):
|
||||
# """
|
||||
# Input image tensor shape must be [B C H W]
|
||||
# the return image numpy array shape is [B H W C]
|
||||
# """
|
||||
# img = cv2.imread(path)
|
||||
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float)
|
||||
# img = torch.from_numpy(img).transpose((2,0,1)).unsqueeze(0)
|
||||
# return img
|
||||
|
||||
def img2tensor1(img_tensor):
    """Convert a [B, C, H, W] tensor in [-1, 1] to a [B, H, W, C] float
    array in [0, 255].

    NOTE(review): despite the name this converts tensor -> image array and
    is functionally a duplicate of ``tensor2img``.
    """
    out = np.clip((img_tensor.numpy() + 1) / 2, 0.0, 1.0)
    return (out * 255).transpose((0, 2, 3, 1))
|
||||
|
||||
def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    convertion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
            [0, 1].

    Raises:
        TypeError: If dtype is neither np.uint8 nor np.float32.
    """
    img_type = img.dtype
    out = img.astype(np.float32)
    if img_type == np.uint8:
        out /= 255.
    elif img_type != np.float32:
        raise TypeError('The img type should be np.float32 or np.uint8, '
                        f'but got {img_type}')
    return out
|
||||
|
||||
def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    np.uint8 output is rounded into [0, 255]; np.float32 output is scaled
    down into [0, 1]. It is mainly used for post-processing images in
    colorspace convertion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): Desired output type/range.

    Returns:
        (ndarray): The converted image with desired type and range.

    Raises:
        TypeError: If ``dst_type`` is neither np.uint8 nor np.float32.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, '
                        f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.  # in-place, matching the original's side effect
    return img.astype(dst_type)
|
||||
|
||||
|
||||
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same
            type and range as input image.
    """
    src_type = img.dtype
    work = _convert_input_type_range(img)
    if y_only:
        # NOTE(review): these are the RGB-order Y coefficients (the original
        # carried an '#RGB' remark here) -- confirm the expected channel
        # order against the callers.
        converted = np.dot(work, [65.481, 128.553, 24.966]) + 16.0
    else:
        converted = np.matmul(
            work, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                   [65.481, -37.797, 112.0]]) + [16, 128, 128]
    return _convert_output_type_range(converted, src_type)
|
||||
|
||||
def to_y_channel(img):
    """Return the Y channel of YCbCr, as float in [0, 255] (not rounded).

    Args:
        img (ndarray): Images with range [0, 255]. 3-channel HWC input is
            converted via ``bgr2ycbcr``; anything else passes through.

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    img = img.astype(np.float32) / 255.
    if img.ndim == 3 and img.shape[2] == 3:
        img = bgr2ycbcr(img, y_only=True)[..., None]
    return img * 255.
|
||||
|
||||
def calculate_psnr(img1,
                   img2,
                   # crop_border=0,
                   test_y_channel=True):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Images with range [0, 255], 'HWC' order.
        img2 (ndarray): Images with range [0, 255], 'HWC' order.
        test_y_channel (bool): Compare on the Y channel of YCbCr only.
            Default: True.
            (Docstring fixed: it previously documented removed
            crop_border/input_order parameters and a False default.)

    Returns:
        float: psnr result in dB (inf for identical images).
    """

    # if crop_border != 0:
    #     img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
    #     img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))
|
||||
|
||||
|
||||
def _ssim(img1, img2):
    """Calculate SSIM (structural similarity) for one-channel images.

    Helper for :func:`calculate_ssim`. Uses an 11x11 Gaussian window
    (sigma 1.5) and the usual C1/C2 stabilising constants for 8-bit data.

    Args:
        img1 (ndarray): Image with range [0, 255], order 'HWC'.
        img2 (ndarray): Image with range [0, 255], order 'HWC'.

    Returns:
        float: Mean SSIM over the valid (border-cropped) region.
    """
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)

    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    # Local means; crop 5 px per side to drop the filter's border effect.
    mu_a = cv2.filter2D(a, -1, window)[5:-5, 5:-5]
    mu_b = cv2.filter2D(b, -1, window)[5:-5, 5:-5]
    mu_a_sq = mu_a ** 2
    mu_b_sq = mu_b ** 2
    mu_ab = mu_a * mu_b

    # Local (co)variances via E[x^2] - E[x]^2.
    var_a = cv2.filter2D(a ** 2, -1, window)[5:-5, 5:-5] - mu_a_sq
    var_b = cv2.filter2D(b ** 2, -1, window)[5:-5, 5:-5] - mu_b_sq
    cov_ab = cv2.filter2D(a * b, -1, window)[5:-5, 5:-5] - mu_ab

    numerator = (2 * mu_ab + c1) * (2 * cov_ab + c2)
    denominator = (mu_a_sq + mu_b_sq + c1) * (var_a + var_b + c2)
    return (numerator / denominator).mean()
|
||||
|
||||
|
||||
def calculate_ssim(img1, img2, test_y_channel=True):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray): Image with range [0, 255], order 'HWC'.
        img2 (ndarray): Image with range [0, 255], same shape as ``img1``.
        test_y_channel (bool): Compare on the Y channel of YCbCr only.
            Default: True.

    Returns:
        float: SSIM result (mean over channels).
    """
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    # Per-channel SSIM, averaged; expects a trailing channel axis.
    ssims = [_ssim(img1[..., c], img2[..., c]) for c in range(img1.shape[2])]
    return np.array(ssims).mean()
|
||||
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Config_from_yaml.py
|
||||
# Created Date: Monday February 17th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 28th February 2020 4:30:01 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import yaml
|
||||
|
||||
def getConfigYaml(yaml_file):
    """Load a YAML configuration file into a dict.

    Args:
        yaml_file (str): Path to the YAML configuration file.

    Returns:
        dict: Parsed configuration mapping.

    Side effects:
        On invalid YAML, prints an error message and terminates the
        process with status -1 (preserves the original fail-fast behaviour).
    """
    with open(yaml_file, 'r') as config_file:
        try:
            # FullLoader resolves standard YAML tags without arbitrary
            # Python object construction.
            return yaml.load(config_file, Loader=yaml.FullLoader)
        except yaml.YAMLError:
            # Bug fix: PyYAML signals parse errors with yaml.YAMLError,
            # not ValueError, so the original handler could never fire.
            print('INVALID YAML file format.. Please provide a good yaml file')
            exit(-1)
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: load the training config and flatten it into a state dict.
    a = getConfigYaml("./train_256.yaml")
    # dict(...) copies every key/value pair — equivalent to the original
    # item-by-item loop, but idiomatic.
    sys_state = dict(a)
|
||||
Reference in New Issue
Block a user