Files
SimSwapPlus/utilities/transfer_checkpoint.py
T
chenxuanhong 736d45ec00 update
2022-01-17 13:26:59 +08:00

145 lines
5.6 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: transfer_checkpoint.py
# Created Date: Wednesday February 3rd 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 17th January 2022 1:25:56 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
import os
class RepSRPlain_pixel(nn.Module):
    """Plain (single-branch) convolutional x4 super-resolution network.

    The network is a straight stack of 3x3 convolutions (stride 1, padding 1):
    six ``rep*`` layers at the input resolution, two nearest-neighbour x2
    upsampling stages (``conv_up1``, ``conv_up2``), then ``conv_hr`` and
    ``conv_last``. Every convolution except ``conv_last`` is followed by
    LeakyReLU with negative slope 0.2.

    The submodule names are part of the checkpoint format: the conversion
    code in this file's ``__main__`` section writes ``rep1`` ... ``rep6``,
    ``conv_up1``, ``conv_up2``, ``conv_hr`` and ``conv_last`` entries into
    this module's ``state_dict``, so they must not be renamed.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Base channel number of intermediate features; the
            middle ``rep`` layers use ``2 * num_feat`` channels. Default: 32.
        num_layer (int): Unused; accepted only to keep the constructor
            signature stable. The number of ``rep`` layers is fixed at six.
        upsampling (int): Stored as ``self.scale`` (and its square as
            ``self.ssqu``) but NOT consulted by ``forward``, which always
            upsamples by a factor of 4 (two x2 stages). Default: 4.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 num_feat=32,
                 num_layer=3,
                 upsampling=4):
        super().__init__()
        # Kept as public attributes for compatibility; forward() ignores them.
        self.scale = upsampling
        self.ssqu = upsampling ** 2

        wide = num_feat * 2
        # Feature extraction at input resolution.
        self.rep1 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.rep2 = nn.Conv2d(num_feat, wide, 3, 1, 1)
        self.rep3 = nn.Conv2d(wide, wide, 3, 1, 1)
        self.rep4 = nn.Conv2d(wide, wide, 3, 1, 1)
        self.rep5 = nn.Conv2d(wide, wide, 3, 1, 1)
        self.rep6 = nn.Conv2d(wide, num_feat, 3, 1, 1)

        # One convolution after each x2 nearest-neighbour upsampling step.
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # High-resolution refinement and projection to the output channels.
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.activator = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Map an (N, num_in_ch, H, W) batch to (N, num_out_ch, 4H, 4W)."""
        feat = x
        for layer in (self.rep1, self.rep2, self.rep3,
                      self.rep4, self.rep5, self.rep6):
            feat = self.activator(layer(feat))

        # Two x2 stages: upsample first, then convolve and activate.
        for layer in (self.conv_up1, self.conv_up2):
            feat = F.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.activator(layer(feat))

        return self.conv_last(self.activator(self.conv_hr(feat)))
def create_identity_conv(dim, kernel_size=3, device=None):
    """Build a convolution weight that passes every channel through unchanged.

    The result has shape ``(dim, dim, kernel_size, kernel_size)`` and is zero
    everywhere except for a 1.0 at spatial index
    ``(kernel_size // 2, kernel_size // 2)`` of each ``[i, i]`` channel pair.

    Args:
        dim (int): Number of input (= output) channels.
        kernel_size (int): Spatial size of the kernel. Default: 3.
        device: Device to allocate the tensor on. When ``None`` the tensor is
            placed on CUDA if it is available and on the CPU otherwise, so
            the function no longer crashes on hosts without a GPU.

    Returns:
        torch.Tensor: The identity kernel, in the default floating dtype.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    kernel = torch.zeros((dim, dim, kernel_size, kernel_size), device=device)
    center = kernel_size // 2
    # The centre tap across all channel pairs is exactly an identity matrix.
    kernel[:, :, center, center] = torch.eye(dim, device=device)
    return kernel
def fill_conv_kernel(in_tensor, kernel_size=3):
    """Zero-pad a 1x1 convolution weight to ``kernel_size`` x ``kernel_size``.

    Only spatial position ``[0, 0]`` of ``in_tensor`` is read; it is written
    to the tap at ``(kernel_size // 2, kernel_size // 2)`` of an otherwise
    all-zero kernel.

    Args:
        in_tensor (torch.Tensor): Weight of shape ``(out_ch, in_ch, 1, 1)``.
        kernel_size (int): Spatial size of the returned kernel. Default: 3.

    Returns:
        torch.Tensor: Tensor of shape
        ``(out_ch, in_ch, kernel_size, kernel_size)``, allocated on the same
        device and with the same dtype as ``in_tensor`` (rather than
        unconditionally on CUDA), so it can be added to sibling weights from
        the same checkpoint on any host.
    """
    out_ch, in_ch = in_tensor.shape[0], in_tensor.shape[1]
    kernel = in_tensor.new_zeros((out_ch, in_ch, kernel_size, kernel_size))
    center = kernel_size // 2
    # One slice assignment replaces the per-output-channel Python loop.
    kernel[:, :, center, center] = in_tensor[:, :, 0, 0]
    return kernel
if __name__ == "__main__":
    # Convert a trained multi-branch RepSR checkpoint into the weights of the
    # plain RepSRPlain_pixel network: for every ``rep*`` layer the 3x3 branch,
    # the 1x1 branch (zero-padded to 3x3) and, where applicable, an identity
    # kernel are summed into a single 3x3 convolution.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(0)

    base_path = "H:\\Multi Scale Kernel Prediction Networks\\Mobile_Oriented_KPN\\train_logs\\"
    version = "repsr_pixel_0"
    epoch = 73
    path_ckp = os.path.join(base_path, version, "checkpoints\\epoch%d_RepSR.pth" % epoch)
    path_plain_ckp = os.path.join(base_path, version, "checkpoints\\epoch%d_RepSR_Plain.pth" % epoch)

    # The freshly built network only provides the key names and shapes of the
    # target state dict; its random weights are all overwritten below.
    network = RepSRPlain_pixel(3, 3, 64, 3, 4)
    network = network.cuda()
    plain_state = network.state_dict()

    # Remap every stored tensor onto the visible GPU. Without map_location a
    # checkpoint saved on another GPU index (or on the CPU) either fails to
    # load or produces device-mismatch errors in the additions below.
    rep_state = torch.load(path_ckp, map_location="cuda")

    rep_layers = ["rep1", "rep2", "rep3", "rep4", "rep5", "rep6"]
    plain_layers = ["conv_up1", "conv_up2", "conv_hr", "conv_last"]

    converted = {}
    for name in rep_layers:
        conv3_weight = rep_state[name + ".conv3.weight"]
        fused_weight = conv3_weight + fill_conv_kernel(rep_state[name + ".conv1x1.weight"])
        # NOTE(review): an identity kernel is added whenever in/out channels
        # match - presumably the training-time block has a skip connection in
        # exactly that case; confirm against the RepSR block definition.
        if conv3_weight.shape[0] == conv3_weight.shape[1]:
            print("include identity")
            fused_weight = fused_weight + create_identity_conv(conv3_weight.shape[0])
        converted[name + ".weight"] = fused_weight
        converted[name + ".bias"] = rep_state[name + ".conv3.bias"] + rep_state[name + ".conv1x1.bias"]

    # The remaining layers are ordinary convolutions and are copied verbatim.
    for name in plain_layers:
        for suffix in (".weight", ".bias"):
            converted[name + suffix] = rep_state[name + suffix]

    # Refuse to write a checkpoint the plain network could not load.
    for key, tensor in converted.items():
        if tensor.shape != plain_state[key].shape:
            raise ValueError(
                "shape mismatch for %s: checkpoint gives %s, plain network expects %s"
                % (key, tuple(tensor.shape), tuple(plain_state[key].shape)))
        plain_state[key] = tensor

    torch.save(plain_state, path_plain_ckp)
# wocao = torch.load(path_plain_ckp)
# for data_key in wocao.keys():
# result1 = wocao[data_key].cpu().numpy()
# # np.savetxt(i_name+"_conv3_weight.txt",result1)
# str_temp = ("%s"%data_key).replace(".","_")
# io.savemat(str_temp+".mat",{str_temp:result1})
# for data_key in wocao.keys():
# print(data_key)
# print(wocao[data_key].shape)