Files
Nataniel Ruiz Gutierrez 21970b730a All
2019-12-21 16:37:10 -05:00

152 lines
6.2 KiB
Python

import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from model import Generator
from model import Discriminator
import os
import re
class Utils:
    """Helper methods for building, training and checkpointing a G/D pair.

    The methods read configuration attributes from ``self`` (for example
    ``device``, ``g_lr``, ``model_save_dir``, ``batch_size``); none of them
    are defined here, so the host object must set them before use.
    """

    # Compiled once: splits a string into alternating text / digit runs.
    _NUMBER_RE = re.compile(r'(\d+)')

    # Checkpoint file-name template: "<iteration>-<epoch>-<kind>.ckpt".
    _CKPT_TEMPLATE = '{}-{}-{}.ckpt'

    def build_model(self):
        """Create a generator and a discriminator with their Adam optimizers."""
        self.G = Generator(self.g_conv_dim, self.c_dim,
                           self.g_repeat_num).to(self.device)
        self.D = Discriminator(
            self.image_size, self.d_conv_dim, self.c_dim,
            self.d_repeat_num).to(self.device)

        self.g_optimizer = torch.optim.Adam(
            self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(
            self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        # TODO: implement data parallelization for multiple gpus
        # self.gpu_ids = torch.cuda.device_count()
        # print("GPUS AVAILABLE: ", self.gpu_ids)
        # if self.gpu_ids > 1:
        #     torch.nn.DataParallel(self.D, device_ids=list(range(self.gpu_ids)))
        #     torch.nn.DataParallel(self.G, device_ids=list(range(self.gpu_ids)))

    def build_tensorboard(self):
        """Build a tensorboard logger and a SummaryWriter on ``self.log_dir``."""
        from logger import Logger
        self.logger = Logger(self.log_dir)
        self.writer = SummaryWriter(logdir=self.log_dir)

    def smooth_loss(self, att):
        """Return the total-variation style smoothness penalty of ``att``.

        The result is the mean absolute difference between neighbours along
        the last axis plus the same quantity along the second-to-last axis.
        ``att`` is indexed with four axes (presumably batch, channel, height,
        width -- confirm against the caller).
        """
        last_axis_diff = torch.abs(att[:, :, :, :-1] - att[:, :, :, 1:])
        prev_axis_diff = torch.abs(att[:, :, :-1, :] - att[:, :, 1:, :])
        # Each mean is already a scalar, so no further reduction is needed.
        return torch.mean(last_axis_diff) + torch.mean(prev_axis_diff)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        # The in-place clamp acts on the freshly computed tensor, not on x.
        return ((x + 1) / 2).clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size(), device=self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        # Flatten everything but the first axis so the norm is per sample.
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def imFromAttReg(self, att, reg, x_real):
        """Mix ``reg`` and ``x_real`` using ``att`` as the blending weight.

        Returns ``(1 - att) * reg + att * x_real``.
        """
        return (1 - att) * reg + att * x_real

    def create_labels(self, data_iter):
        """Return samples for visualization.

        Pulls one ``(x_data, c_data)`` pair from ``data_iter`` and, for each
        of the first ``self.num_sample_targets`` entries, repeats the entry
        ``self.batch_size`` times along a new leading axis.

        Returns:
            A tuple ``(x, c)`` of two lists of tensors on ``self.device``.
        """
        # Built-in next() works with every iterator; the ``.next()`` method
        # is not part of the Python 3 iterator protocol.
        x_data, c_data = next(data_iter)

        x, c = [], []
        for i in range(self.num_sample_targets):
            x.append(x_data[i].repeat(
                self.batch_size, 1, 1, 1).to(self.device))
            c.append(c_data[i].repeat(self.batch_size, 1).to(self.device))
        return x, c

    def save_models(self, iteration, epoch):
        """Save G, D and both optimizer states; drop the previous checkpoint.

        Files are named ``<iteration+1>-<epoch>-<kind>.ckpt`` inside
        ``self.model_save_dir``. The checkpoint written
        ``self.model_save_step`` iterations earlier (same ``epoch``) is
        removed on a best-effort basis.
        """
        targets = (
            ('G', self.G),
            ('D', self.D),
            ('G_optim', self.g_optimizer),
            ('D_optim', self.d_optimizer),
        )

        stale_step = iteration + 1 - self.model_save_step
        for kind, _ in targets:
            stale_path = os.path.join(
                self.model_save_dir,
                self._CKPT_TEMPLATE.format(stale_step, epoch, kind))
            try:
                os.remove(stale_path)
            except OSError:
                # Nothing to delete on the first save (or the file is
                # otherwise unavailable); keep going with the other files.
                pass

        for kind, holder in targets:
            path = os.path.join(
                self.model_save_dir,
                self._CKPT_TEMPLATE.format(iteration + 1, epoch, kind))
            torch.save(holder.state_dict(), path)

        print(f'Saved model checkpoints in {self.model_save_dir}...')

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator.

        Loads ``<resume_iters>-<self.first_epoch>-<kind>.ckpt`` for G, D and
        both optimizers from ``self.model_save_dir``.
        """
        print('Loading the trained models from step {}-{}...'.format(resume_iters, self.first_epoch))

        def load(kind):
            path = os.path.join(
                self.model_save_dir,
                self._CKPT_TEMPLATE.format(resume_iters, self.first_epoch, kind))
            # Keep storages on CPU while deserializing so checkpoints saved
            # on a GPU also load on hosts without one; load_state_dict then
            # copies the values into the existing parameters / state.
            return torch.load(path, map_location=lambda storage, loc: storage)

        self.G.load_state_dict(load('G'))
        self.D.load_state_dict(load('D'))
        self.d_optimizer.load_state_dict(load('D_optim'))
        self.g_optimizer.load_state_dict(load('G_optim'))

    def numericalSort(self, value):
        """Return a sort key that orders embedded digit runs numerically.

        ``value`` is split into alternating non-digit / digit pieces and the
        digit pieces are converted to ``int``, e.g. ``'a10b'`` becomes
        ``['a', 10, 'b']``.
        """
        parts = self._NUMBER_RE.split(value)
        # With one capture group, odd indices hold the captured digit runs.
        parts[1::2] = map(int, parts[1::2])
        return parts