# disrupting-deepfakes/ganimation/solver.py
# Snapshot: 2020-01-09 17:57:43 -04:00 (514 lines, 23 KiB, Python)
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from model import Generator, Discriminator
from utils import Utils
import numpy as np
import os
import time
import datetime
import random
import glob
import attacks
class Solver(Utils):
    """Training and animation/attack driver for the GANimation model.

    Every entry of ``config_dict`` is copied onto the instance as an
    attribute, so the methods below read their hyper-parameters
    (``batch_size``, ``n_critic``, ``lambda_*``, directories, ...) directly
    from ``self``. Helpers such as ``build_model``, ``imFromAttReg``,
    ``denorm``, ``smooth_loss``, ``gradient_penalty``, ``save_models`` and
    ``update_lr`` are not defined in this class; presumably they come from
    ``Utils`` (not shown here).
    """

    def __init__(self, data_loader, config_dict):
        """Store the loader and config, pick a device and build the models.

        Args:
            data_loader: loader yielding ``(images, labels)`` batches; its
                ``dataset`` is also indexed directly to sample target labels.
            config_dict: mapping whose entries become instance attributes
                (e.g. ``gpu_id``, ``use_tensorboard``).
        """
        # NOTE: the following line creates new instance attributes with the
        # values in config_dict.
        self.__dict__.update(**config_dict)
        self.data_loader = data_loader
        self.device = ('cuda:' + str(self.gpu_id)
                       if torch.cuda.is_available() else 'cpu')
        print(f"Model running on {self.device}")
        if self.use_tensorboard:
            self.build_tensorboard()
        self.loss_visualization = {}
        self.build_model()

    def train(self):
        """Run the WGAN-GP training loop.

        One critic step is taken per batch and one generator step every
        ``n_critic`` iterations. Samples, checkpoints and log lines are written
        every ``sample_step`` / ``model_save_step`` / ``log_step`` iterations.
        After each epoch beyond ``num_epochs_decay`` both learning rates are
        reduced by 10%, and a checkpoint is saved at the end of every epoch.
        """
        print('Training...')
        self.global_counter = 0
        if self.resume_iters:
            # A resumed run starts its first epoch at the restored iteration.
            self.first_iteration = self.resume_iters
            self.restore_model(self.resume_iters)
        else:
            self.first_iteration = 0
        self.start_time = time.time()
        for epoch in range(self.first_epoch, self.num_epochs):
            print(f"EPOCH {epoch} WITH {len(self.data_loader)} STEPS")
            # Weight of the virtual-cycle terms in train_generator
            # (only used when self.use_virtual is set).
            self.alpha_rec = 1
            self.epoch = epoch
            for iteration in range(self.first_iteration, len(self.data_loader)):
                self.iteration = iteration
                self.get_training_data()
                self.train_discriminator()
                if (self.iteration + 1) % self.n_critic == 0:
                    generation_outputs = self.train_generator()
                    # Nested here so generation_outputs always comes from
                    # the generator step that was just taken.
                    if (self.iteration + 1) % self.sample_step == 0:
                        self.print_generations(generation_outputs)
                if self.iteration % self.model_save_step == 0:
                    self.save_models(self.iteration, self.epoch)
                if self.iteration % self.log_step == 0:
                    self.update_tensorboard()
                self.global_counter += 1
            # Decay learning rates.
            if (self.epoch + 1) > self.num_epochs_decay:
                self.g_lr -= (self.g_lr / 10.0)
                self.d_lr -= (self.d_lr / 10.0)
                self.update_lr(self.g_lr, self.d_lr)
                print('Decayed learning rates, self.g_lr: {}, self.d_lr: {}.'.format(
                    self.g_lr, self.d_lr))
            # Save the last model
            self.save_models()
            self.first_iteration = 0  # Next epochs start from 0

    def get_training_data(self):
        """Load the next training batch and sample random target labels.

        Sets ``self.x_real`` and ``self.label_org`` (moved to ``self.device``)
        and ``self.label_trg``. When ``self.use_virtual`` is set, it also sets
        ``self.label_trg_virtual``. The loader iterator is created on first use
        and restarted once exhausted.
        """
        if getattr(self, 'data_iter', None) is None:
            self.data_iter = iter(self.data_loader)
        try:
            self.x_real, self.label_org = next(self.data_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the data. Other errors
            # raised while loading are deliberately not caught.
            self.data_iter = iter(self.data_loader)
            self.x_real, self.label_org = next(self.data_iter)
        self.x_real = self.x_real.to(self.device)  # Input images.
        # Labels for computing classification loss.
        self.label_org = self.label_org.to(self.device)

        def sample_labels():
            # One jittered random dataset label per batch element, clamped to
            # [0, 1]. np.stack avoids building a tensor from a list of ndarrays.
            labels = np.stack(self.get_random_labels_list())
            return torch.as_tensor(labels, dtype=torch.float32).clamp(0, 1).to(self.device)

        # Get random targets for training
        self.label_trg = sample_labels()
        if self.use_virtual:
            self.label_trg_virtual = sample_labels()
            assert not torch.equal(
                self.label_trg_virtual, self.label_trg), "Target label and virtual label are the same"

    def get_random_labels_list(self):
        """Sample ``self.batch_size`` target AU vectors from the training set.

        Each vector is the label (item ``[1]``) of a uniformly chosen dataset
        entry, plus independent uniform noise in [-0.1, 0.1] per component.
        No clamping is done here.

        Returns:
            list of numpy arrays, one per batch element.
        """
        dataset = self.data_loader.dataset
        # Index by the real dataset size: len(loader) * batch_size overshoots
        # when the last batch is partial (IndexError) and undershoots with
        # drop_last.
        num_items = len(dataset)
        trg_list = []
        for _ in range(self.batch_size):
            random_num = random.randrange(num_items)
            # Select a random AU vector from the dataset.
            # NOTE(review): dataset[i] presumably also loads the image, which is
            # wasted work here; confirm against the dataset class.
            trg_list_aux = dataset[random_num][1]
            # Jitter every component by uniform noise in [-0.1, 0.1].
            trg_list.append(trg_list_aux.numpy() +
                            np.random.uniform(-0.1, 0.1, trg_list_aux.shape))
        return trg_list

    def train_discriminator(self):
        """Run one WGAN-GP optimisation step of the critic / AU regressor D.

        The loss is -mean(D(real)) + mean(D(fake)), plus
        lambda_cls * MSE(AU(real), label_org) and lambda_gp * gradient_penalty.
        Fakes are produced by G(x_real, label_trg). The weighted loss
        components are stored in ``self.loss_visualization``.
        """
        # Compute loss with real images.
        critic_output, classification_output = self.D(self.x_real)
        d_loss_critic_real = -torch.mean(critic_output)
        d_loss_classification = F.mse_loss(classification_output, self.label_org)
        # Compute loss with fake images. G is not updated in this step and its
        # output is only consumed detached, so no autograd graph is needed.
        with torch.no_grad():
            attention_mask, color_regression = self.G(self.x_real, self.label_trg)
            x_fake = self.imFromAttReg(
                attention_mask, color_regression, self.x_real)
        critic_output, _ = self.D(x_fake)
        d_loss_critic_fake = torch.mean(critic_output)
        # Gradient penalty on random per-sample interpolations between real
        # and fake images (one alpha ~ U(0, 1) per sample).
        alpha = torch.rand(self.x_real.size(0), 1, 1, 1).to(self.device)
        x_hat = (alpha * self.x_real.detach()
                 + (1 - alpha) * x_fake.detach()).requires_grad_(True)
        critic_output, _ = self.D(x_hat)
        d_loss_gp = self.gradient_penalty(critic_output, x_hat)
        # Backward and optimize.
        d_loss = d_loss_critic_real + d_loss_critic_fake + self.lambda_cls * \
            d_loss_classification + self.lambda_gp * d_loss_gp
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()
        # Logging.
        self.loss_visualization['D/loss'] = d_loss.item()
        self.loss_visualization['D/loss_real'] = d_loss_critic_real.item()
        self.loss_visualization['D/loss_fake'] = d_loss_critic_fake.item()
        self.loss_visualization['D/loss_cls'] = self.lambda_cls * \
            d_loss_classification.item()
        self.loss_visualization['D/loss_gp'] = self.lambda_gp * \
            d_loss_gp.item()

    def train_generator(self):
        """Run one generator optimisation step and return tensors for sampling.

        The loss combines four parts:
        - an adversarial term and an AU-regression term on
          x_fake = G(x_real, label_trg);
        - a cycle-reconstruction term;
        - attention-mask smoothness penalties;
        - attention-mask saturation penalties.

        With ``self.use_virtual`` the second-pass terms blend two cycles. The
        direct cycle x_fake -> label_org gets weight ``1 - alpha_rec``. The
        virtual cycle x_fake -> label_trg_virtual -> label_trg gets weight
        ``alpha_rec``.

        Returns:
            dict of tensors consumed by ``print_generations``.
        """
        # Original-to-target domain.
        attention_mask, color_regression = self.G(self.x_real, self.label_trg)
        x_fake = self.imFromAttReg(
            attention_mask, color_regression, self.x_real)
        critic_output, classification_output = self.D(x_fake)
        g_loss_fake = -torch.mean(critic_output)
        g_loss_cls = F.mse_loss(classification_output, self.label_trg)
        # Target-to-original domain (computed in both modes).
        reconstructed_attention_mask, reconstructed_color_regression = self.G(
            x_fake, self.label_org)
        x_rec = self.imFromAttReg(
            reconstructed_attention_mask, reconstructed_color_regression, x_fake)
        if self.use_virtual:
            # Virtual cycle: fake -> virtual target -> back to the target.
            virtual_attention_mask, virtual_color_regression = self.G(
                x_fake, self.label_trg_virtual)
            x_fake_virtual = self.imFromAttReg(
                virtual_attention_mask, virtual_color_regression, x_fake)
            reconstructed_virtual_attention_mask, reconstructed_virtual_color_regression = self.G(
                x_fake_virtual, self.label_trg)
            x_rec_virtual = self.imFromAttReg(
                reconstructed_virtual_attention_mask, reconstructed_virtual_color_regression,
                x_fake_virtual.detach())
        # Compute losses
        g_loss_saturation_1 = attention_mask.mean()
        g_loss_smooth1 = self.smooth_loss(attention_mask)
        if not self.use_virtual:
            g_loss_rec = F.l1_loss(self.x_real, x_rec)
            g_loss_saturation_2 = reconstructed_attention_mask.mean()
            g_loss_smooth2 = self.smooth_loss(reconstructed_attention_mask)
        else:
            w_direct = 1 - self.alpha_rec
            w_virtual = self.alpha_rec
            g_loss_rec = w_direct * F.l1_loss(self.x_real, x_rec) + \
                w_virtual * F.l1_loss(x_fake, x_rec_virtual)
            g_loss_saturation_2 = w_direct * reconstructed_attention_mask.mean() + \
                w_virtual * reconstructed_virtual_attention_mask.mean()
            # Same weighting as the rec/saturation terms: the direct-cycle mask
            # gets (1 - alpha_rec), the virtual-cycle mask gets alpha_rec.
            g_loss_smooth2 = w_direct * self.smooth_loss(reconstructed_attention_mask) + \
                w_virtual * self.smooth_loss(reconstructed_virtual_attention_mask)
        g_attention_loss = self.lambda_smooth * g_loss_smooth1 + self.lambda_smooth * g_loss_smooth2 \
            + self.lambda_sat * g_loss_saturation_1 + self.lambda_sat * g_loss_saturation_2
        g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + \
            self.lambda_cls * g_loss_cls + g_attention_loss
        self.reset_grad()
        g_loss.backward()
        self.g_optimizer.step()
        # Logging.
        self.loss_visualization['G/loss'] = g_loss.item()
        self.loss_visualization['G/loss_fake'] = g_loss_fake.item()
        self.loss_visualization['G/loss_rec'] = self.lambda_rec * \
            g_loss_rec.item()
        self.loss_visualization['G/loss_cls'] = self.lambda_cls * \
            g_loss_cls.item()
        self.loss_visualization['G/attention_loss'] = g_attention_loss.item()
        self.loss_visualization['G/loss_smooth1'] = self.lambda_smooth * \
            g_loss_smooth1.item()
        self.loss_visualization['G/loss_smooth2'] = self.lambda_smooth * \
            g_loss_smooth2.item()
        self.loss_visualization['G/loss_sat1'] = self.lambda_sat * \
            g_loss_saturation_1.item()
        self.loss_visualization['G/loss_sat2'] = self.lambda_sat * \
            g_loss_saturation_2.item()
        self.loss_visualization['G/alpha'] = self.alpha_rec
        outputs = {
            "color_regression": color_regression,
            "x_fake": x_fake,
            "attention_mask": attention_mask,
            "x_rec": x_rec,
            "reconstructed_attention_mask": reconstructed_attention_mask,
            "reconstructed_color_regression": reconstructed_color_regression,
        }
        if self.use_virtual:
            outputs.update({
                "reconstructed_virtual_attention_mask": reconstructed_virtual_attention_mask,
                "reconstructed_virtual_color_regression": reconstructed_virtual_color_regression,
                "x_rec_virtual": x_rec_virtual,
            })
        return outputs

    def print_generations(self, generator_outputs_dict):
        """Save the current batch and the generator outputs into ``self.sample_dir``.

        File names are prefixed only with the epoch, so later calls within the
        same epoch overwrite earlier ones.

        Args:
            generator_outputs_dict: dict returned by ``train_generator``. The
                ``reconstructed_virtual_*`` and ``x_rec_virtual`` entries are
                read only when ``self.use_virtual`` is set.
        """
        outputs = generator_outputs_dict
        # (file-name suffix, image tensor) in save order. The suffixes are kept
        # exactly as before, including '_6rec_attention' vs '_6rec_attention_'.
        to_save = [
            ('_4real_.png', self.denorm(self.x_real)),
            ('_2reg_.png', (outputs["color_regression"] + 1) / 2),
            ('_3res_.png', self.denorm(outputs["x_fake"])),
            ('_1attention_.png', outputs["attention_mask"]),
            ('_5rec_.png', self.denorm(outputs["x_rec"])),
        ]
        if not self.use_virtual:
            to_save += [
                ('_6rec_attention.png', outputs["reconstructed_attention_mask"]),
                ('_7rec_reg.png', self.denorm(outputs["reconstructed_color_regression"])),
            ]
        else:
            to_save += [
                ('_6rec_attention_.png', outputs["reconstructed_attention_mask"]),
                ('_7rec_reg.png', self.denorm(outputs["reconstructed_color_regression"])),
                ('_8rec_virtual_attention.png', outputs["reconstructed_virtual_attention_mask"]),
                ('_91rec_virtual_reg.png', self.denorm(outputs["reconstructed_virtual_color_regression"])),
                ('_92rec_epoch_.png', self.denorm(outputs["x_rec_virtual"])),
            ]
        for suffix, image in to_save:
            save_image(image, self.sample_dir + '/{}'.format(self.epoch) + suffix)

    def update_tensorboard(self):
        """Print progress plus every entry of ``self.loss_visualization``.

        When ``self.use_tensorboard`` is set, the same values are also written
        as TensorBoard scalars at step ``self.global_counter``.
        """
        # Elapsed time truncated to whole seconds. int() is used because
        # slicing str(timedelta)[:-7] returns '' whenever the duration has no
        # fractional part.
        et = str(datetime.timedelta(seconds=int(time.time() - self.start_time)))
        header = "Elapsed [{}], [{}/{}], Epoch [{}/{}]".format(
            et, self.iteration + 1, len(self.data_loader), self.epoch + 1, self.num_epochs)
        print(", ".join([header] + ["{}: {:.4f}".format(tag, value)
                                    for tag, value in self.loss_visualization.items()]))
        if self.use_tensorboard:
            for tag, value in self.loss_visualization.items():
                self.writer.add_scalar(
                    tag, value, global_step=self.global_counter)

    def animation(self, mode='animate_image'):
        """Run the trained generator for animation or adversarial evaluation.

        Loads the first ``*G.ckpt`` in ``self.animation_models_dir`` (ordered
        by ``self.numericalSort``). It then reads reference expressions from
        ``self.animation_attributes_path``. Each non-empty line there is
        ``<image file name> <AU_1> <AU_2> ...``, and AU values are divided by 5.

        Args:
            mode: one of two modes.
                ``'animate_random_batch'`` animates up to 7 images of one
                loader batch towards every reference expression and saves one
                grid per expression.
                ``'animate_image'`` runs ``attacks.LinfPGDAttack`` on every
                file in ``self.animation_images_dir``. It saves G's output for
                the perturbed input and prints mean distortion metrics with
                respect to the clean input.

        Raises:
            FileNotFoundError: if no generator checkpoint is found.
        """
        from PIL import Image
        from torchvision import transforms as T

        regular_image_transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

        def load_image(path):
            # Close the file promptly and force 3 channels: RGBA or greyscale
            # inputs would not match the 3-channel Normalize above.
            with Image.open(path) as img:
                return regular_image_transform(img.convert('RGB'))

        g_checkpoints = sorted(glob.glob(os.path.join(
            self.animation_models_dir, '*G.ckpt')), key=self.numericalSort)
        if not g_checkpoints:
            raise FileNotFoundError('No *G.ckpt checkpoint found in {}'.format(
                self.animation_models_dir))
        G_path = g_checkpoints[0]
        # Load onto self.device rather than a hard-coded GPU, so that G, the
        # PGD attack (built with device=self.device) and all input tensors
        # share one device.
        self.G.load_state_dict(torch.load(G_path, map_location=self.device))
        self.G = self.G.to(self.device)

        with open(self.animation_attributes_path, 'r') as txt_file:
            # Skip blank lines (e.g. a trailing empty line at EOF).
            csv_lines = [line for line in txt_file if line.strip()]
        reference_expression_images = []
        targets = torch.zeros(len(csv_lines), self.c_dim)
        input_images = torch.zeros(len(csv_lines), 3, 128, 128)
        for idx, line in enumerate(csv_lines):
            # split() also tolerates repeated spaces and the trailing newline.
            splitted_lines = line.split()
            image_path = os.path.join(
                self.animation_attribute_images_dir, splitted_lines[0])
            input_images[idx, :] = load_image(image_path)
            reference_expression_images.append(splitted_lines[0])
            # Presumably rescales AU intensities from a 0-5 range to [0, 1];
            # TODO confirm against the attribute file format.
            targets[idx, :] = torch.tensor(
                [float(value) / 5. for value in splitted_lines[1:]])

        if mode == 'animate_random_batch':
            self.data_iter = iter(self.data_loader)
            images_to_animate, _ = next(self.data_iter)
            # Up to 7 images; fewer if the loader's batches are smaller.
            images_to_animate = images_to_animate[0:7].to(self.device)
            animation_batch_size = images_to_animate.size(0)
            with torch.no_grad():
                for target_idx in range(targets.size(0)):
                    targets_au = targets[target_idx, :].unsqueeze(
                        0).repeat(animation_batch_size, 1).to(self.device)
                    resulting_images_att, resulting_images_reg = self.G(
                        images_to_animate, targets_au)
                    resulting_images = self.imFromAttReg(
                        resulting_images_att, resulting_images_reg, images_to_animate)
                    # Grid with animation_batch_size + 1 columns. The first
                    # half is a blank tile (-1) followed by the source images.
                    # The second half is the reference expression image
                    # followed by the animated results.
                    save_images = -torch.ones(
                        (animation_batch_size + 1) * 2, 3, 128, 128, device=self.device)
                    save_images[1:animation_batch_size + 1] = images_to_animate
                    save_images[animation_batch_size + 1] = input_images[target_idx]
                    save_images[animation_batch_size + 2:] = resulting_images
                    save_image((save_images + 1) / 2,
                               os.path.join(self.animation_results_dir,
                                            reference_expression_images[target_idx]),
                               nrow=animation_batch_size + 1)

        if mode == 'animate_image':
            # All-black image handed to the attack; presumably the output the
            # perturbation tries to push G towards (confirm in attacks.py).
            black = torch.zeros((1, 3, 128, 128), device=self.device)
            # Initialize Metrics: running sums over (image, target) pairs
            # comparing G's output on the perturbed input with the clean input.
            l1_error = 0.0
            l2_error = 0.0
            min_dist = 0.0
            l0_error = 0.0
            perceptual_error = 0.0  # never updated below; printed as 0
            n_samples = 0
            pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device)
            images_to_animate_path = sorted(glob.glob(
                self.animation_images_dir + '/*'))
            # `all_images` is only needed by the commented-out "Iterative Data"
            # attack below. If that variant is re-enabled, build it once here
            # rather than once per input image:
            # all_images = torch.cat([load_image(path).unsqueeze(0)
            #                         for path in images_to_animate_path],
            #                        dim=0).to(self.device)
            x_advs = []
            for idx, image_path in enumerate(images_to_animate_path):
                image_to_animate = load_image(image_path).unsqueeze(0).to(self.device)
                image_name = image_path.split('/')[-1].split('.')[0]
                # Transfer to different images
                # if idx == 0:
                #     for target_idx in range(targets.size(0)):
                #         x_adv, perturb = pgd_attack.perturb(image_to_animate, black, targets[target_idx, :].unsqueeze(0).to(self.device))
                #         x_advs.append((x_adv, perturb))
                for target_idx in range(targets.size(0)):
                    # Transfer to different classes: the perturbation is
                    # computed once per image (at target 0) and then reused
                    # for every target expression.
                    if target_idx == 0:
                        # img = load_image(images_to_animate_path[idx]).unsqueeze(0).to(self.device)
                        # Wrong Class
                        # x_adv, perturb = pgd_attack.perturb(image_to_animate, black, targets[0, :].unsqueeze(0).to(self.device))
                        # Joint Class Conditional
                        # x_adv, perturb = pgd_attack.perturb_joint_class(image_to_animate, black, targets[:, :].to(self.device))
                        # Iterative Class Conditional
                        x_adv, perturb = pgd_attack.perturb_iter_class(
                            image_to_animate, black, targets[:, :].to(self.device))
                        # Iterative Data (needs `all_images`, see above)
                        # _, perturb = pgd_attack.perturb_iter_data(image_to_animate, all_images, black, targets[68, :].unsqueeze(0).to(self.device))
                    targets_au = targets[target_idx, :].unsqueeze(0).to(self.device)
                    # Normal Attack
                    # x_adv, perturb = pgd_attack.perturb(image_to_animate, black, targets_au)
                    # x_adv, perturb = x_advs[target_idx]
                    x_adv = image_to_animate + perturb
                    # No Attack
                    # x_adv = image_to_animate
                    with torch.no_grad():
                        resulting_images_att, resulting_images_reg = self.G(
                            x_adv, targets_au)
                        resulting_image = self.imFromAttReg(
                            resulting_images_att, resulting_images_reg, x_adv)
                    # with torch.no_grad():
                    #     resulting_images_att_noattack, resulting_images_reg_noattack = self.G(
                    #         image_to_animate, targets_au)
                    #     resulting_image_noattack = self.imFromAttReg(
                    #         resulting_images_att_noattack, resulting_images_reg_noattack, image_to_animate)
                    save_image((resulting_image + 1) / 2, os.path.join(
                        self.animation_results_dir,
                        image_name + '_' + reference_expression_images[target_idx]))
                    if target_idx == 0:
                        save_image((x_adv + 1) / 2, os.path.join(
                            self.animation_results_dir, image_name + '_ref.jpg'))
                    # l1_error += F.l1_loss(resulting_image, resulting_image_noattack)
                    # l2_error += F.mse_loss(resulting_image, resulting_image_noattack)
                    # l0_error += (resulting_image - resulting_image_noattack).norm(0)
                    # min_dist += (resulting_image - resulting_image_noattack).norm(float('-inf'))
                    # Compare to input image (.item() keeps plain Python floats
                    # instead of accumulating device tensors).
                    difference = resulting_image - image_to_animate
                    l1_error += F.l1_loss(resulting_image, image_to_animate).item()
                    l2_error += F.mse_loss(resulting_image, image_to_animate).item()
                    l0_error += difference.norm(0).item()
                    min_dist += difference.norm(float('-inf')).item()
                    n_samples += 1
            # Print metrics (n_samples counts (image, target) pairs).
            if n_samples == 0:
                print('No samples evaluated: found {} image(s) and {} target expression(s).'.format(
                    len(images_to_animate_path), targets.size(0)))
            else:
                print('{} images. L1 error: {}. L2 error: {}. L0 error: {}. L_-inf error: {}. Perceptual error: {}.'.format(
                    n_samples, l1_error / n_samples, l2_error / n_samples,
                    l0_error / n_samples, min_dist / n_samples,
                    perceptual_error / n_samples))
        # """ Code to modify single Action Units """
        # Set data loader.
        # self.data_loader = self.data_loader
        # with torch.no_grad():
        #     for i, (self.x_real, c_org) in enumerate(self.data_loader):
        #         # Prepare input images and target domain labels.
        #         self.x_real = self.x_real.to(self.device)
        #         c_org = c_org.to(self.device)
        #         # c_trg_list = self.create_labels(self.data_loader)
        #         crit, cl_regression = self.D(self.x_real)
        #         # print(crit)
        #         print("ORIGINAL", c_org[0])
        #         print("REGRESSION", cl_regression[0])
        #         for au in range(17):
        #             alpha = np.linspace(-0.3, 0.3, 10)
        #             for j, a in enumerate(alpha):
        #                 new_emotion = c_org.clone()
        #                 new_emotion[:, au] = torch.clamp(new_emotion[:, au] + a, 0, 1)
        #                 attention, reg = self.G(self.x_real, new_emotion)
        #                 x_fake = self.imFromAttReg(attention, reg, self.x_real)
        #                 save_image((x_fake + 1) / 2, os.path.join(self.result_dir, '{}-{}-{}-images.jpg'.format(i, au, j)))
        #         if i >= 3:
        #             break