init
This commit is contained in:
@@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: checkpoint_manager.py
|
||||
# Created Date: Sunday July 12th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 27th July 2020 11:01:16 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
class CheckpointManager(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
self.maxCkpNum = -1
|
||||
self.ckpList = []
|
||||
self.modelsDict = {} # key model name, value model
|
||||
self.currentEpoch = 0
|
||||
|
||||
|
||||
def registerModels(self):
|
||||
pass
|
||||
|
||||
def __updateCkpList__(self):
|
||||
pass
|
||||
|
||||
def saveModel(self):
|
||||
pass
|
||||
|
||||
def loadModel(self):
|
||||
pass
|
||||
|
||||
def saveLR(self):
|
||||
pass
|
||||
|
||||
def loadLR(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
def loadPretrainedModel(chechpointStep,modelSavePath,gModel,dModel,cuda,**kwargs):
|
||||
gModel.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_LocalG.pth'.format(chechpointStep)),map_location=cuda))
|
||||
dModel.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_GlobalD.pth'.format(chechpointStep)),map_location=cuda))
|
||||
print('loaded trained models (epoch: {}) successful!'.format(chechpointStep))
|
||||
if not kwargs:
|
||||
return
|
||||
for k,v in kwargs.items():
|
||||
v.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_{}.pth'.format(chechpointStep,k)),map_location=cuda))
|
||||
print("Loaded param %s"%k)
|
||||
|
||||
def loadPretrainedModelByDict(chechpointStep,modelSavePath,cuda,**kwargs):
|
||||
if not kwargs:
|
||||
return
|
||||
for k,v in kwargs.items():
|
||||
v.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_{}.pth'.format(chechpointStep,k)),map_location=cuda))
|
||||
print("Loaded param %s"%k)
|
||||
|
||||
def loadLR(chechpointStep,modelSavePath,dlr,glr):
|
||||
glr.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_LocalGlr.pth'.format(chechpointStep))))
|
||||
dlr.load_state_dict(torch.load(os.path.join(
|
||||
modelSavePath, 'Epoch{}_GlobalDlr.pth'.format(chechpointStep))))
|
||||
print("Generator learning rate:%f"%glr.get_lr()[0])
|
||||
print("Discriminator learning rate:%f"%dlr.get_lr()[0])
|
||||
|
||||
def saveLR(step,modelSavePath,dlr,glr):
|
||||
torch.save(glr.state_dict(),os.path.join(modelSavePath, 'Epoch{}_LocalGlr.pth'.format(step + 1)))
|
||||
torch.save(dlr.state_dict(),os.path.join(modelSavePath, 'Epoch{}_GlobalDlr.pth'.format(step + 1)))
|
||||
print("Epoch:{} models learning rate saved!".format(step+1))
|
||||
|
||||
|
||||
def saveModel(step,modelSavePath,gModel,dModel,**kwargs):
|
||||
torch.save(gModel.state_dict(),
|
||||
os.path.join(modelSavePath, 'Epoch{}_LocalG.pth'.format(step + 1)))
|
||||
torch.save(dModel.state_dict(),
|
||||
os.path.join(modelSavePath, 'Epoch{}_GlobalD.pth'.format(step + 1)))
|
||||
print("Epoch:{} models saved!".format(step+1))
|
||||
if not kwargs:
|
||||
return
|
||||
for k,v in kwargs.items():
|
||||
torch.save(v.state_dict(),
|
||||
os.path.join(modelSavePath, 'Epoch{}_{}.pth'.format(step + 1,k)))
|
||||
print("Epoch:{} models param {} saved!".format(step+1,k))
|
||||
|
||||
def saveModelByDict(step,modelSavePath,**kwargs):
|
||||
if not kwargs:
|
||||
return
|
||||
for k,v in kwargs.items():
|
||||
torch.save(v.state_dict(),
|
||||
os.path.join(modelSavePath, 'Epoch{}_{}.pth'.format(step + 1,k)))
|
||||
print("Epoch:{} models param {} saved!".format(step+1,k))
|
||||
@@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: figure.py
|
||||
# Created Date: Tuesday October 13th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 13th October 2020 2:54:30 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def plot_loss_curve(losses, save_path):
|
||||
for key in losses.keys():
|
||||
plt.plot(range(len(losses[key])), losses[key], label=key)
|
||||
plt.xlabel('iteration')
|
||||
plt.title(f'loss curve')
|
||||
plt.legend()
|
||||
plt.savefig(save_path)
|
||||
plt.clf()
|
||||
@@ -0,0 +1,15 @@
|
||||
import json
|
||||
|
||||
|
||||
def readConfig(path):
|
||||
with open(path,'r') as cf:
|
||||
nodelocaltionstr = cf.read()
|
||||
nodelocaltioninf = json.loads(nodelocaltionstr)
|
||||
if isinstance(nodelocaltioninf,str):
|
||||
nodelocaltioninf = json.loads(nodelocaltioninf)
|
||||
return nodelocaltioninf
|
||||
|
||||
def writeConfig(path, info):
|
||||
with open(path, 'w') as cf:
|
||||
configjson = json.dumps(info, indent=4)
|
||||
cf.writelines(configjson)
|
||||
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: learningrate_scheduler.py
|
||||
# Created Date: Tuesday January 5th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 5th January 2021 2:04:00 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
# Refer to basicSR https://github.com/xinntao/BasicSR
|
||||
|
||||
|
||||
import math
|
||||
from collections import Counter
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
class MultiStepRestartLR(_LRScheduler):
|
||||
""" MultiStep with restarts learning rate scheme.
|
||||
|
||||
Args:
|
||||
optimizer (torch.nn.optimizer): Torch optimizer.
|
||||
milestones (list): Iterations that will decrease learning rate.
|
||||
gamma (float): Decrease ratio. Default: 0.1.
|
||||
restarts (list): Restart iterations. Default: [0].
|
||||
restart_weights (list): Restart weights at each restart iteration.
|
||||
Default: [1].
|
||||
last_epoch (int): Used in _LRScheduler. Default: -1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
milestones,
|
||||
gamma=0.1,
|
||||
restarts=(0,),
|
||||
restart_weights=(1,),
|
||||
last_epoch=-1):
|
||||
self.milestones = Counter(milestones)
|
||||
self.gamma = gamma
|
||||
self.restarts = restarts
|
||||
self.restart_weights = restart_weights
|
||||
print(type(self.restarts),self.restarts)
|
||||
print(type(self.restart_weights),self.restart_weights)
|
||||
assert len(self.restarts) == len(
|
||||
self.restart_weights), 'restarts and their weights do not match.'
|
||||
super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch in self.restarts:
|
||||
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
||||
return [
|
||||
group['initial_lr'] * weight
|
||||
for group in self.optimizer.param_groups
|
||||
]
|
||||
if self.last_epoch not in self.milestones:
|
||||
return [group['lr'] for group in self.optimizer.param_groups]
|
||||
return [
|
||||
group['lr'] * self.gamma**self.milestones[self.last_epoch]
|
||||
for group in self.optimizer.param_groups
|
||||
]
|
||||
|
||||
|
||||
def get_position_from_periods(iteration, cumulative_period):
|
||||
"""Get the position from a period list.
|
||||
|
||||
It will return the index of the right-closest number in the period list.
|
||||
For example, the cumulative_period = [100, 200, 300, 400],
|
||||
if iteration == 50, return 0;
|
||||
if iteration == 210, return 2;
|
||||
if iteration == 300, return 2.
|
||||
|
||||
Args:
|
||||
iteration (int): Current iteration.
|
||||
cumulative_period (list[int]): Cumulative period list.
|
||||
|
||||
Returns:
|
||||
int: The position of the right-closest number in the period list.
|
||||
"""
|
||||
for i, period in enumerate(cumulative_period):
|
||||
if iteration <= period:
|
||||
return i
|
||||
|
||||
|
||||
class CosineAnnealingRestartLR(_LRScheduler):
|
||||
""" Cosine annealing with restarts learning rate scheme.
|
||||
|
||||
An example of config:
|
||||
periods = [10, 10, 10, 10]
|
||||
restart_weights = [1, 0.5, 0.5, 0.5]
|
||||
eta_min=1e-7
|
||||
|
||||
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
|
||||
scheduler will restart with the weights in restart_weights.
|
||||
|
||||
Args:
|
||||
optimizer (torch.nn.optimizer): Torch optimizer.
|
||||
periods (list): Period for each cosine anneling cycle.
|
||||
restart_weights (list): Restart weights at each restart iteration.
|
||||
Default: [1].
|
||||
eta_min (float): The mimimum lr. Default: 0.
|
||||
last_epoch (int): Used in _LRScheduler. Default: -1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
periods,
|
||||
restart_weights=(1),
|
||||
eta_min=0,
|
||||
last_epoch=-1):
|
||||
self.periods = periods
|
||||
self.restart_weights = restart_weights
|
||||
self.eta_min = eta_min
|
||||
assert (len(self.periods) == len(self.restart_weights)
|
||||
), 'periods and restart_weights should have the same length.'
|
||||
self.cumulative_period = [
|
||||
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
|
||||
]
|
||||
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
idx = get_position_from_periods(self.last_epoch,
|
||||
self.cumulative_period)
|
||||
current_weight = self.restart_weights[idx]
|
||||
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
|
||||
current_period = self.periods[idx]
|
||||
|
||||
return [
|
||||
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
|
||||
(1 + math.cos(math.pi * (
|
||||
(self.last_epoch - nearest_restart) / current_period)))
|
||||
for base_lr in self.base_lrs
|
||||
]
|
||||
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: logo_class.py
|
||||
# Created Date: Tuesday June 29th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 11th October 2021 12:39:55 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
class logo_class:
|
||||
|
||||
@staticmethod
|
||||
def print_group_logo():
|
||||
logo_str = """
|
||||
|
||||
███╗ ██╗██████╗ ███████╗██╗ ██████╗ ███████╗ ██╗████████╗██╗ ██╗
|
||||
████╗ ██║██╔══██╗██╔════╝██║██╔════╝ ██╔════╝ ██║╚══██╔══╝██║ ██║
|
||||
██╔██╗ ██║██████╔╝███████╗██║██║ ███╗ ███████╗ ██║ ██║ ██║ ██║
|
||||
██║╚██╗██║██╔══██╗╚════██║██║██║ ██║ ╚════██║██ ██║ ██║ ██║ ██║
|
||||
██║ ╚████║██║ ██║███████║██║╚██████╔╝ ███████║╚█████╔╝ ██║ ╚██████╔╝
|
||||
╚═╝ ╚═══╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═════╝ ╚══════╝ ╚════╝ ╚═╝ ╚═════╝
|
||||
Neural Rendering Special Interesting Group of SJTU
|
||||
|
||||
"""
|
||||
print(logo_str)
|
||||
|
||||
@staticmethod
|
||||
def print_start_training():
|
||||
logo_str = """
|
||||
_____ __ __ ______ _ _
|
||||
/ ___/ / /_ ____ _ _____ / /_ /_ __/_____ ____ _ (_)____ (_)____ ____ _
|
||||
\__ \ / __// __ `// ___// __/ / / / ___// __ `// // __ \ / // __ \ / __ `/
|
||||
___/ // /_ / /_/ // / / /_ / / / / / /_/ // // / / // // / / // /_/ /
|
||||
/____/ \__/ \__,_//_/ \__/ /_/ /_/ \__,_//_//_/ /_//_//_/ /_/ \__, /
|
||||
/____/
|
||||
"""
|
||||
print(logo_str)
|
||||
|
||||
if __name__=="__main__":
|
||||
# logo_class.print_group_logo()
|
||||
logo_class.print_start_training()
|
||||
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Reporter.py
|
||||
# Created Date: Tuesday September 24th 2019
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 4th July 2021 11:50:12 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2019 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import datetime
|
||||
import os
|
||||
|
||||
class Reporter:
|
||||
def __init__(self,reportPath):
|
||||
self.path = reportPath
|
||||
self.withTimeStamp = False
|
||||
self.index = 1
|
||||
self.timeStrFormat = '%Y-%m-%d %H:%M:%S'
|
||||
timeStr = datetime.datetime.strftime(datetime.datetime.now(),'%Y%m%d%H%M%S')
|
||||
self.path = self.path + "-%s.log"%timeStr
|
||||
if not os.path.exists(self.path):
|
||||
f = open(self.path,'w')
|
||||
f.close()
|
||||
|
||||
def writeInfo(self,strLine):
|
||||
with open(self.path,'a+') as logf:
|
||||
timeStr = datetime.datetime.strftime(datetime.datetime.now(),self.timeStrFormat)
|
||||
logf.writelines("[%d]-[%s]-[info] %s\n"%(self.index,timeStr,strLine))
|
||||
self.index += 1
|
||||
|
||||
def writeConfig(self,config):
|
||||
with open(self.path,'a+') as logf:
|
||||
for item in config.items():
|
||||
text = "[%d]-[parameters] %s--%s\n"%(self.index,item[0],str(item[1]))
|
||||
logf.writelines(text)
|
||||
self.index +=1
|
||||
|
||||
def writeModel(self,modelText):
|
||||
with open(self.path,'a+') as logf:
|
||||
logf.writelines("[%d]-[model] %s\n"%(self.index,modelText))
|
||||
self.index += 1
|
||||
|
||||
def writeRawInfo(self, strLine):
|
||||
with open(self.path,'a+') as logf:
|
||||
timeStr = datetime.datetime.strftime(datetime.datetime.now(),self.timeStrFormat)
|
||||
logf.writelines("[%d]-[info] %s\n"%(self.index,timeStr,strLine))
|
||||
self.index += 1
|
||||
|
||||
def writeTrainLog(self, epoch, step, logText):
|
||||
with open(self.path,'a+') as logf:
|
||||
timeStr = datetime.datetime.strftime(datetime.datetime.now(),self.timeStrFormat)
|
||||
logf.writelines("[%d]-[%s]-[logInfo]-epoch[%d]-step[%d] %s\n"%(self.index,timeStr,epoch,step,logText))
|
||||
self.index += 1
|
||||
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: save_heatmap.py
|
||||
# Created Date: Friday January 15th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 15th January 2021 10:23:13 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
def SaveHeatmap(heatmaps, path, row=-1, dpi=72):
|
||||
"""
|
||||
The input tensor must be B X 1 X H X W
|
||||
"""
|
||||
batch_size = heatmaps.shape[0]
|
||||
temp_path = ".temp/"
|
||||
if not os.path.exists(temp_path):
|
||||
os.makedirs(temp_path)
|
||||
final_img = None
|
||||
if row < 1:
|
||||
col = batch_size
|
||||
row = 1
|
||||
else:
|
||||
col = batch_size // row
|
||||
if row * col <batch_size:
|
||||
col +=1
|
||||
|
||||
row_i = 0
|
||||
col_i = 0
|
||||
|
||||
for i in range(batch_size):
|
||||
img_path = os.path.join(temp_path,'temp_batch_{}.png'.format(i))
|
||||
sns.heatmap(heatmaps[i,0,:,:],vmin=0,vmax=heatmaps[i,0,:,:].max(),cbar=False)
|
||||
plt.savefig(img_path, dpi=dpi, bbox_inches = 'tight', pad_inches = 0)
|
||||
img = cv2.imread(img_path)
|
||||
if i == 0:
|
||||
H,W,C = img.shape
|
||||
final_img = np.zeros((H*row,W*col,C))
|
||||
final_img[H*row_i:H*(row_i+1),W*col_i:W*(col_i+1),:] = img
|
||||
col_i += 1
|
||||
if col_i >= col:
|
||||
col_i = 0
|
||||
row_i += 1
|
||||
cv2.imwrite(path,final_img)
|
||||
|
||||
if __name__ == "__main__":
|
||||
random_map = np.random.randn(16,1,10,10)
|
||||
SaveHeatmap(random_map,"./wocao.png",1)
|
||||
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: sshupload.py
|
||||
# Created Date: Tuesday September 24th 2019
|
||||
# Author: Lcx
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th January 2021 2:02:12 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2019 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import paramiko,os
|
||||
# ssh传输类:
|
||||
|
||||
class fileUploaderClass(object):
|
||||
def __init__(self,serverIp,userName,passWd,port=22):
|
||||
self.__ip__ = serverIp
|
||||
self.__userName__ = userName
|
||||
self.__passWd__ = passWd
|
||||
self.__port__ = port
|
||||
self.__ssh__ = paramiko.SSHClient()
|
||||
self.__ssh__.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
def sshScpPut(self,localFile,remoteFile):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
sftp = self.__ssh__.open_sftp()
|
||||
remoteDir = remoteFile.split("/")
|
||||
if remoteFile[0]=='/':
|
||||
sftp.chdir('/')
|
||||
|
||||
for item in remoteDir[0:-1]:
|
||||
if item == "":
|
||||
continue
|
||||
try:
|
||||
sftp.chdir(item)
|
||||
except:
|
||||
sftp.mkdir(item)
|
||||
sftp.chdir(item)
|
||||
sftp.put(localFile,remoteDir[-1])
|
||||
sftp.close()
|
||||
self.__ssh__.close()
|
||||
print("ssh localfile:%s remotefile:%s success"%(localFile,remoteFile))
|
||||
|
||||
def sshScpGetNames(self,remoteDir):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
sftp = self.__ssh__.open_sftp()
|
||||
wocao = sftp.listdir(remoteDir)
|
||||
return wocao
|
||||
|
||||
def sshScpGet(self, remoteFile, localFile, showProgress=False):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
try:
|
||||
sftp.stat(remoteFile)
|
||||
print("Remote file exists!")
|
||||
except:
|
||||
print("Remote file does not exist!")
|
||||
return False
|
||||
sftp = self.__ssh__.open_sftp()
|
||||
if showProgress:
|
||||
sftp.get(remoteFile, localFile,callback=self.__putCallBack__)
|
||||
else:
|
||||
sftp.get(remoteFile, localFile)
|
||||
sftp.close()
|
||||
self.__ssh__.close()
|
||||
return True
|
||||
|
||||
def __putCallBack__(self,transferred,total):
|
||||
print("current transferred %3.1f percent"%(transferred/total*100),end='\r')
|
||||
|
||||
def sshScpGetmd5(self, file_path):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__, self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
sftp = self.__ssh__.open_sftp()
|
||||
try:
|
||||
file = sftp.open(file_path, 'rb')
|
||||
res = (True,hashlib.new('md5', file.read()).hexdigest())
|
||||
sftp.close()
|
||||
self.__ssh__.close()
|
||||
return res
|
||||
except:
|
||||
sftp.close()
|
||||
self.__ssh__.close()
|
||||
return (False,None)
|
||||
|
||||
def sshScpRename(self, oldpath, newpath):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
sftp = self.__ssh__.open_sftp()
|
||||
sftp.rename(oldpath,newpath)
|
||||
sftp.close()
|
||||
self.__ssh__.close()
|
||||
print("ssh oldpath:%s newpath:%s success"%(oldpath,newpath))
|
||||
|
||||
def sshScpDelete(self,path):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
sftp = self.__ssh__.open_sftp()
|
||||
sftp.remove(path)
|
||||
sftp.close()
|
||||
self.__ssh__.close()
|
||||
print("ssh delete:%s success"%(path))
|
||||
|
||||
def sshScpDeleteDir(self,path):
|
||||
self.__ssh__.connect(self.__ip__, self.__port__ , self.__userName__, self.__passWd__)
|
||||
sftp = paramiko.SFTPClient.from_transport(self.__ssh__.get_transport())
|
||||
sftp = self.__ssh__.open_sftp()
|
||||
self.__rm__(sftp,path)
|
||||
sftp.close()
|
||||
self.__ssh__.close()
|
||||
|
||||
def __rm__(self,sftp,path):
|
||||
try:
|
||||
files = sftp.listdir(path=path)
|
||||
print(files)
|
||||
for f in files:
|
||||
filepath = os.path.join(path, f).replace('\\','/')
|
||||
self.__rm__(sftp,filepath)
|
||||
sftp.rmdir(path)
|
||||
print("ssh delete:%s success"%(path))
|
||||
except:
|
||||
print(path)
|
||||
sftp.remove(path)
|
||||
print("ssh delete:%s success"%(path))
|
||||
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: transfer_checkpoint.py
|
||||
# Created Date: Wednesday February 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 4th February 2021 1:27:09 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import init as init
|
||||
import os
|
||||
import numpy as np
|
||||
import scipy.io as io
|
||||
|
||||
class RepSRPlain_pixel(nn.Module):
|
||||
"""Networks consisting of Residual in Residual Dense Block, which is used
|
||||
in ESRGAN.
|
||||
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
|
||||
Currently, it supports x4 upsampling scale factor.
|
||||
Args:
|
||||
num_in_ch (int): Channel number of inputs.
|
||||
num_out_ch (int): Channel number of outputs.
|
||||
num_feat (int): Channel number of intermediate features.
|
||||
Default: 64
|
||||
num_block (int): Block number in the trunk network. Defaults: 23
|
||||
num_grow_ch (int): Channels for each growth. Default: 32.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_in_ch,
|
||||
num_out_ch,
|
||||
num_feat=32,
|
||||
num_layer = 3,
|
||||
upsampling=4):
|
||||
super(RepSRPlain_pixel, self).__init__()
|
||||
|
||||
self.scale = upsampling
|
||||
self.ssqu = upsampling ** 2
|
||||
|
||||
self.rep1 = nn.Conv2d(num_in_ch, num_feat,3,1,1)
|
||||
self.rep2 = nn.Conv2d(num_feat, num_feat*2,3,1,1)
|
||||
self.rep3 = nn.Conv2d(num_feat*2, num_feat*2,3,1,1)
|
||||
self.rep4 = nn.Conv2d(num_feat*2, num_feat*2,3,1,1)
|
||||
self.rep5 = nn.Conv2d(num_feat*2, num_feat*2,3,1,1)
|
||||
self.rep6 = nn.Conv2d(num_feat*2, num_feat,3,1,1)
|
||||
|
||||
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||
|
||||
|
||||
self.activator = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
# self.activator = nn.ReLU(inplace=True)
|
||||
|
||||
# default_init_weights([self.conv_up1,self.conv_up2,self.conv_hr,self.conv_last], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
f_d = self.activator(self.rep1(x))
|
||||
f_d = self.activator(self.rep2(f_d))
|
||||
f_d = self.activator(self.rep3(f_d))
|
||||
f_d = self.activator(self.rep4(f_d))
|
||||
f_d = self.activator(self.rep5(f_d))
|
||||
f_d = self.activator(self.rep6(f_d))
|
||||
|
||||
feat = self.activator(
|
||||
self.conv_up1(F.interpolate(f_d, scale_factor=2, mode='nearest')))
|
||||
feat = self.activator(
|
||||
self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.activator(self.conv_hr(feat)))
|
||||
return out
|
||||
|
||||
def create_identity_conv(dim,kernel_size=3):
|
||||
zeros = torch.zeros((dim,dim,kernel_size,kernel_size)).cuda()
|
||||
for i_dim in range(dim):
|
||||
zeros[i_dim,i_dim,kernel_size//2,kernel_size//2] = 1.0
|
||||
return zeros
|
||||
|
||||
def fill_conv_kernel(in_tensor,kernel_size=3):
|
||||
shape = in_tensor.shape
|
||||
zeros = torch.zeros(shape[0],shape[1],kernel_size,kernel_size).cuda()
|
||||
for i_dim in range(shape[0]):
|
||||
zeros[i_dim,:,kernel_size//2,kernel_size//2] = in_tensor[i_dim,:,0,0]
|
||||
return zeros
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(0)
|
||||
base_path = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\"
|
||||
version = "repsr_pixel_0"
|
||||
epoch = 73
|
||||
path_ckp= os.path.join(base_path,version,"checkpoints\\epoch%d_RepSR.pth"%epoch)
|
||||
path_plain_ckp= os.path.join(base_path,version,"checkpoints\\epoch%d_RepSR_Plain.pth"%epoch)
|
||||
network = RepSRPlain_pixel(3,
|
||||
3,
|
||||
64,
|
||||
3,
|
||||
4
|
||||
)
|
||||
network = network.cuda()
|
||||
|
||||
|
||||
|
||||
wocao = network.state_dict()
|
||||
# for data_key in wocao.keys():
|
||||
# print(data_key)
|
||||
# print(wocao[data_key].shape)
|
||||
wocao_cpk = torch.load(path_ckp)
|
||||
|
||||
# for data_key in wocao_cpk.keys():
|
||||
# print(data_key)
|
||||
# print(wocao_cpk[data_key].shape)
|
||||
name_list = ["rep1","rep2","rep3","rep4","rep5","rep6"]
|
||||
other_list = ["conv_up1","conv_up2","conv_hr","conv_last"]
|
||||
for i_name in name_list:
|
||||
temp= wocao_cpk[i_name+".conv3.weight"] + fill_conv_kernel(wocao_cpk[i_name+".conv1x1.weight"])
|
||||
wocao[i_name+".weight"] = temp
|
||||
temp= wocao_cpk[i_name+".conv3.bias"] + wocao_cpk[i_name+".conv1x1.bias"]
|
||||
wocao[i_name+".bias"] = temp
|
||||
|
||||
if wocao_cpk[i_name+".conv3.weight"].shape[0] == wocao_cpk[i_name+".conv3.weight"].shape[1]:
|
||||
print("include identity")
|
||||
temp = wocao[i_name+".weight"] + create_identity_conv(wocao_cpk[i_name+".conv3.weight"].shape[0])
|
||||
wocao[i_name+".weight"] = temp
|
||||
|
||||
for i_name in other_list:
|
||||
wocao[i_name+".weight"] = wocao_cpk[i_name+".weight"]
|
||||
wocao[i_name+".bias"] = wocao_cpk[i_name+".bias"]
|
||||
|
||||
torch.save(wocao,path_plain_ckp)
|
||||
|
||||
# wocao = torch.load(path_plain_ckp)
|
||||
# for data_key in wocao.keys():
|
||||
# result1 = wocao[data_key].cpu().numpy()
|
||||
# # np.savetxt(i_name+"_conv3_weight.txt",result1)
|
||||
# str_temp = ("%s"%data_key).replace(".","_")
|
||||
# io.savemat(str_temp+".mat",{str_temp:result1})
|
||||
|
||||
# for data_key in wocao.keys():
|
||||
# print(data_key)
|
||||
# print(wocao[data_key].shape)
|
||||
@@ -0,0 +1,335 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: utilities.py
|
||||
# Created Date: Monday April 6th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 12th October 2021 2:18:05 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
|
||||
# Gram Matrix
|
||||
def Gram(tensor: torch.Tensor):
|
||||
B, C, H, W = tensor.shape
|
||||
x = tensor.view(B, C, H*W)
|
||||
x_t = x.transpose(1, 2)
|
||||
return torch.bmm(x, x_t) / (C*H*W)
|
||||
|
||||
def build_tensorboard(summary_path):
|
||||
from tensorboardX import SummaryWriter
|
||||
# from logger import Logger
|
||||
# self.logger = Logger(self.log_path)
|
||||
writer = SummaryWriter(log_dir=summary_path)
|
||||
return writer
|
||||
|
||||
|
||||
|
||||
def denorm(x):
|
||||
out = (x + 1) / 2
|
||||
return out.clamp_(0, 1)
|
||||
|
||||
def tensor2img(img_tensor):
|
||||
"""
|
||||
Input image tensor shape must be [B C H W]
|
||||
the return image numpy array shape is [B H W C]
|
||||
"""
|
||||
res = img_tensor.numpy()
|
||||
res = (res + 1) / 2
|
||||
res = np.clip(res, 0.0, 1.0)
|
||||
res = res * 255
|
||||
res = res.transpose((0,2,3,1))
|
||||
return res
|
||||
|
||||
def img2tensor255(path, max_size=None):
|
||||
|
||||
image = Image.open(path)
|
||||
# Rescale the image
|
||||
if (max_size==None):
|
||||
itot_t = transforms.Compose([
|
||||
#transforms.ToPILImage(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x.mul(255))
|
||||
])
|
||||
else:
|
||||
H, W, C = image.shape
|
||||
image_size = tuple([int((float(max_size) / max([H,W]))*x) for x in [H, W]])
|
||||
itot_t = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize(image_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x.mul(255))
|
||||
])
|
||||
|
||||
# Convert image to tensor
|
||||
tensor = itot_t(image)
|
||||
|
||||
# Add the batch_size dimension
|
||||
tensor = tensor.unsqueeze(dim=0)
|
||||
return tensor
|
||||
|
||||
def img2tensor255crop(path, crop_size=256):
|
||||
|
||||
image = Image.open(path)
|
||||
# Rescale the image
|
||||
itot_t = transforms.Compose([
|
||||
transforms.CenterCrop(crop_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x.mul(255))
|
||||
])
|
||||
|
||||
# Convert image to tensor
|
||||
tensor = itot_t(image)
|
||||
|
||||
# Add the batch_size dimension
|
||||
tensor = tensor.unsqueeze(dim=0)
|
||||
return tensor
|
||||
|
||||
# def img2tensor255(path, crop_size=None):
|
||||
# """
|
||||
# Input image tensor shape must be [B C H W]
|
||||
# the return image numpy array shape is [B H W C]
|
||||
# """
|
||||
# img = cv2.imread(path)
|
||||
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float)
|
||||
# img = torch.from_numpy(img).transpose((2,0,1)).unsqueeze(0)
|
||||
# return img
|
||||
|
||||
def img2tensor1(img_tensor):
|
||||
"""
|
||||
Input image tensor shape must be [B C H W]
|
||||
the return image numpy array shape is [B H W C]
|
||||
"""
|
||||
res = img_tensor.numpy()
|
||||
res = (res + 1) / 2
|
||||
res = np.clip(res, 0.0, 1.0)
|
||||
res = res * 255
|
||||
res = res.transpose((0,2,3,1))
|
||||
return res
|
||||
|
||||
def _convert_input_type_range(img):
|
||||
"""Convert the type and range of the input image.
|
||||
|
||||
It converts the input image to np.float32 type and range of [0, 1].
|
||||
It is mainly used for pre-processing the input image in colorspace
|
||||
convertion functions such as rgb2ycbcr and ycbcr2rgb.
|
||||
|
||||
Args:
|
||||
img (ndarray): The input image. It accepts:
|
||||
1. np.uint8 type with range [0, 255];
|
||||
2. np.float32 type with range [0, 1].
|
||||
|
||||
Returns:
|
||||
(ndarray): The converted image with type of np.float32 and range of
|
||||
[0, 1].
|
||||
"""
|
||||
img_type = img.dtype
|
||||
img = img.astype(np.float32)
|
||||
if img_type == np.float32:
|
||||
pass
|
||||
elif img_type == np.uint8:
|
||||
img /= 255.
|
||||
else:
|
||||
raise TypeError('The img type should be np.float32 or np.uint8, '
|
||||
f'but got {img_type}')
|
||||
return img
|
||||
|
||||
def _convert_output_type_range(img, dst_type):
|
||||
"""Convert the type and range of the image according to dst_type.
|
||||
|
||||
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
||||
images will be converted to np.uint8 type with range [0, 255]. If
|
||||
`dst_type` is np.float32, it converts the image to np.float32 type with
|
||||
range [0, 1].
|
||||
It is mainly used for post-processing images in colorspace convertion
|
||||
functions such as rgb2ycbcr and ycbcr2rgb.
|
||||
|
||||
Args:
|
||||
img (ndarray): The image to be converted with np.float32 type and
|
||||
range [0, 255].
|
||||
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
||||
converts the image to np.uint8 type with range [0, 255]. If
|
||||
dst_type is np.float32, it converts the image to np.float32 type
|
||||
with range [0, 1].
|
||||
|
||||
Returns:
|
||||
(ndarray): The converted image with desired type and range.
|
||||
"""
|
||||
if dst_type not in (np.uint8, np.float32):
|
||||
raise TypeError('The dst_type should be np.float32 or np.uint8, '
|
||||
f'but got {dst_type}')
|
||||
if dst_type == np.uint8:
|
||||
img = img.round()
|
||||
else:
|
||||
img /= 255.
|
||||
return img.astype(dst_type)
|
||||
|
||||
|
||||
def bgr2ycbcr(img, y_only=False):
|
||||
"""Convert a BGR image to YCbCr image.
|
||||
|
||||
The bgr version of rgb2ycbcr.
|
||||
It implements the ITU-R BT.601 conversion for standard-definition
|
||||
television. See more details in
|
||||
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
|
||||
|
||||
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
|
||||
In OpenCV, it implements a JPEG conversion. See more details in
|
||||
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
|
||||
|
||||
Args:
|
||||
img (ndarray): The input image. It accepts:
|
||||
1. np.uint8 type with range [0, 255];
|
||||
2. np.float32 type with range [0, 1].
|
||||
y_only (bool): Whether to only return Y channel. Default: False.
|
||||
|
||||
Returns:
|
||||
ndarray: The converted YCbCr image. The output image has the same type
|
||||
and range as input image.
|
||||
"""
|
||||
img_type = img.dtype
|
||||
img = _convert_input_type_range(img)
|
||||
if y_only:
|
||||
# out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
|
||||
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 #RGB
|
||||
else:
|
||||
out_img = np.matmul(
|
||||
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
||||
[65.481, -37.797, 112.0]]) + [16, 128, 128]
|
||||
out_img = _convert_output_type_range(out_img, img_type)
|
||||
return out_img
|
||||
|
||||
def to_y_channel(img):
|
||||
"""Change to Y channel of YCbCr.
|
||||
|
||||
Args:
|
||||
img (ndarray): Images with range [0, 255].
|
||||
|
||||
Returns:
|
||||
(ndarray): Images with range [0, 255] (float type) without round.
|
||||
"""
|
||||
img = img.astype(np.float32) / 255.
|
||||
if img.ndim == 3 and img.shape[2] == 3:
|
||||
img = bgr2ycbcr(img, y_only=True)
|
||||
img = img[..., None]
|
||||
return img * 255.
|
||||
|
||||
def calculate_psnr(img1,
|
||||
img2,
|
||||
# crop_border=0,
|
||||
test_y_channel=True):
|
||||
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
|
||||
|
||||
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
|
||||
|
||||
Args:
|
||||
img1 (ndarray): Images with range [0, 255].
|
||||
img2 (ndarray): Images with range [0, 255].
|
||||
crop_border (int): Cropped pixels in each edge of an image. These
|
||||
pixels are not involved in the PSNR calculation.
|
||||
input_order (str): Whether the input order is 'HWC' or 'CHW'.
|
||||
Default: 'HWC'.
|
||||
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
|
||||
|
||||
Returns:
|
||||
float: psnr result.
|
||||
"""
|
||||
|
||||
# if crop_border != 0:
|
||||
# img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
|
||||
# img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
|
||||
if test_y_channel:
|
||||
img1 = to_y_channel(img1)
|
||||
img2 = to_y_channel(img2)
|
||||
|
||||
mse = np.mean((img1 - img2)**2)
|
||||
if mse == 0:
|
||||
return float('inf')
|
||||
return 20. * np.log10(255. / np.sqrt(mse))
|
||||
|
||||
|
||||
def _ssim(img1, img2):
|
||||
"""Calculate SSIM (structural similarity) for one channel images.
|
||||
|
||||
It is called by func:`calculate_ssim`.
|
||||
|
||||
Args:
|
||||
img1 (ndarray): Images with range [0, 255] with order 'HWC'.
|
||||
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
|
||||
|
||||
Returns:
|
||||
float: ssim result.
|
||||
"""
|
||||
|
||||
C1 = (0.01 * 255)**2
|
||||
C2 = (0.03 * 255)**2
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1**2
|
||||
mu2_sq = mu2**2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) *
|
||||
(2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
||||
(sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
def calculate_ssim(img1,
|
||||
img2,
|
||||
test_y_channel=True):
|
||||
"""Calculate SSIM (structural similarity).
|
||||
|
||||
Ref:
|
||||
Image quality assessment: From error visibility to structural similarity
|
||||
|
||||
The results are the same as that of the official released MATLAB code in
|
||||
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
|
||||
|
||||
For three-channel images, SSIM is calculated for each channel and then
|
||||
averaged.
|
||||
|
||||
Args:
|
||||
img1 (ndarray): Images with range [0, 255].
|
||||
img2 (ndarray): Images with range [0, 255].
|
||||
crop_border (int): Cropped pixels in each edge of an image. These
|
||||
pixels are not involved in the SSIM calculation.
|
||||
input_order (str): Whether the input order is 'HWC' or 'CHW'.
|
||||
Default: 'HWC'.
|
||||
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
|
||||
|
||||
Returns:
|
||||
float: ssim result.
|
||||
"""
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
|
||||
if test_y_channel:
|
||||
img1 = to_y_channel(img1)
|
||||
img2 = to_y_channel(img2)
|
||||
|
||||
ssims = []
|
||||
for i in range(img1.shape[2]):
|
||||
ssims.append(_ssim(img1[..., i], img2[..., i]))
|
||||
return np.array(ssims).mean()
|
||||
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Config_from_yaml.py
|
||||
# Created Date: Monday February 17th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 28th February 2020 4:30:01 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import yaml
|
||||
|
||||
def getConfigYaml(yaml_file):
|
||||
with open(yaml_file, 'r') as config_file:
|
||||
try:
|
||||
config_dict = yaml.load(config_file, Loader=yaml.FullLoader)
|
||||
return config_dict
|
||||
except ValueError:
|
||||
print('INVALID YAML file format.. Please provide a good yaml file')
|
||||
exit(-1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
a= getConfigYaml("./train_256.yaml")
|
||||
sys_state = {}
|
||||
for item in a.items():
|
||||
sys_state[item[0]] = item[1]
|
||||
Reference in New Issue
Block a user