next
This commit is contained in:
@@ -1,48 +0,0 @@
|
||||
<img src='http://www.albertpumarola.com/images/2018/GANimation/face1_cyc.gif' align="right" width=90>
|
||||
|
||||
# GANimation: Anatomically-aware Facial Animation from a Single Image
|
||||
### [[Project]](http://www.albertpumarola.com/research/GANimation/index.html)[ [Paper]](https://rdcu.be/bPuaJ)
|
||||
Official implementation of [GANimation](http://www.albertpumarola.com/research/GANimation/index.html). In this work we introduce a novel GAN conditioning scheme based on Action Units (AU) annotations, which describe in a continuous manifold the anatomical facial movements defining a human expression. Our approach permits controlling the magnitude of activation of each AU and combine several of them. For more information please refer to the [paper](https://arxiv.org/abs/1807.09251).
|
||||
|
||||
This code was made public to share our research for the benefit of the scientific community. Do NOT use it for immoral purposes.
|
||||
|
||||

|
||||
|
||||
## Prerequisites
|
||||
- Install PyTorch (version 0.3.1), Torch Vision and dependencies from http://pytorch.org
|
||||
- Install requirements.txt (```pip install -r requirements.txt```)
|
||||
|
||||
## Data Preparation
|
||||
The code requires a directory containing the following files:
|
||||
- `imgs/`: folder with all image
|
||||
- `aus_openface.pkl`: dictionary containing the images action units.
|
||||
- `train_ids.csv`: file containing the images names to be used to train.
|
||||
- `test_ids.csv`: file containing the images names to be used to test.
|
||||
|
||||
An example of this directory is shown in `sample_dataset/`.
|
||||
|
||||
To generate the `aus_openface.pkl` extract each image Action Units with [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace/wiki/Action-Units) and store each output in a csv file the same name as the image. Then run:
|
||||
```
|
||||
python data/prepare_au_annotations.py
|
||||
```
|
||||
|
||||
## Run
|
||||
To train:
|
||||
```
|
||||
bash launch/run_train.sh
|
||||
```
|
||||
To test:
|
||||
```
|
||||
python test --input_path path/to/img
|
||||
```
|
||||
|
||||
## Citation
|
||||
If you use this code or ideas from the paper for your research, please cite our paper:
|
||||
```
|
||||
@article{Pumarola_ijcv2019,
|
||||
title={GANimation: One-Shot Anatomically Consistent Facial Animation},
|
||||
author={A. Pumarola and A. Agudo and A.M. Martinez and A. Sanfeliu and F. Moreno-Noguer},
|
||||
booktitle={International Journal of Computer Vision (IJCV)},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
@@ -1,405 +0,0 @@
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from torch.autograd import Variable
|
||||
import utils.util as util
|
||||
import utils.plots as plot_utils
|
||||
from .models import BaseModel
|
||||
from networks.networks import NetworksFactory
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GANimation(BaseModel):
|
||||
def __init__(self, opt):
|
||||
super(GANimation, self).__init__(opt)
|
||||
self._name = 'GANimation'
|
||||
|
||||
# create networks
|
||||
self._init_create_networks()
|
||||
|
||||
# init train variables
|
||||
if self._is_train:
|
||||
self._init_train_vars()
|
||||
|
||||
# load networks and optimizers
|
||||
if not self._is_train or self._opt.load_epoch > 0:
|
||||
self.load()
|
||||
|
||||
# prefetch variables
|
||||
self._init_prefetch_inputs()
|
||||
|
||||
# init
|
||||
self._init_losses()
|
||||
|
||||
def _init_create_networks(self):
|
||||
# generator network
|
||||
self._G = self._create_generator()
|
||||
self._G.init_weights()
|
||||
if len(self._gpu_ids) > 1:
|
||||
self._G = torch.nn.DataParallel(self._G, device_ids=self._gpu_ids)
|
||||
self._G.cuda()
|
||||
|
||||
# discriminator network
|
||||
self._D = self._create_discriminator()
|
||||
self._D.init_weights()
|
||||
if len(self._gpu_ids) > 1:
|
||||
self._D = torch.nn.DataParallel(self._D, device_ids=self._gpu_ids)
|
||||
self._D.cuda()
|
||||
|
||||
def _create_generator(self):
|
||||
return NetworksFactory.get_by_name('generator_wasserstein_gan', c_dim=self._opt.cond_nc)
|
||||
|
||||
def _create_discriminator(self):
|
||||
return NetworksFactory.get_by_name('discriminator_wasserstein_gan', c_dim=self._opt.cond_nc)
|
||||
|
||||
def _init_train_vars(self):
|
||||
self._current_lr_G = self._opt.lr_G
|
||||
self._current_lr_D = self._opt.lr_D
|
||||
|
||||
# initialize optimizers
|
||||
self._optimizer_G = torch.optim.Adam(self._G.parameters(), lr=self._current_lr_G,
|
||||
betas=[self._opt.G_adam_b1, self._opt.G_adam_b2])
|
||||
self._optimizer_D = torch.optim.Adam(self._D.parameters(), lr=self._current_lr_D,
|
||||
betas=[self._opt.D_adam_b1, self._opt.D_adam_b2])
|
||||
|
||||
def _init_prefetch_inputs(self):
|
||||
self._input_real_img = self._Tensor(self._opt.batch_size, 3, self._opt.image_size, self._opt.image_size)
|
||||
self._input_real_cond = self._Tensor(self._opt.batch_size, self._opt.cond_nc)
|
||||
self._input_desired_cond = self._Tensor(self._opt.batch_size, self._opt.cond_nc)
|
||||
self._input_real_img_path = None
|
||||
self._input_real_cond_path = None
|
||||
|
||||
def _init_losses(self):
|
||||
# define loss functions
|
||||
self._criterion_cycle = torch.nn.L1Loss().cuda()
|
||||
self._criterion_D_cond = torch.nn.MSELoss().cuda()
|
||||
|
||||
# init losses G
|
||||
self._loss_g_fake = Variable(self._Tensor([0]))
|
||||
self._loss_g_cond = Variable(self._Tensor([0]))
|
||||
self._loss_g_cyc = Variable(self._Tensor([0]))
|
||||
self._loss_g_mask_1 = Variable(self._Tensor([0]))
|
||||
self._loss_g_mask_2 = Variable(self._Tensor([0]))
|
||||
self._loss_g_idt = Variable(self._Tensor([0]))
|
||||
self._loss_g_masked_fake = Variable(self._Tensor([0]))
|
||||
self._loss_g_masked_cond = Variable(self._Tensor([0]))
|
||||
self._loss_g_mask_1_smooth = Variable(self._Tensor([0]))
|
||||
self._loss_g_mask_2_smooth = Variable(self._Tensor([0]))
|
||||
self._loss_rec_real_img_rgb = Variable(self._Tensor([0]))
|
||||
self._loss_g_fake_imgs_smooth = Variable(self._Tensor([0]))
|
||||
self._loss_g_unmasked_rgb = Variable(self._Tensor([0]))
|
||||
|
||||
# init losses D
|
||||
self._loss_d_real = Variable(self._Tensor([0]))
|
||||
self._loss_d_cond = Variable(self._Tensor([0]))
|
||||
self._loss_d_fake = Variable(self._Tensor([0]))
|
||||
self._loss_d_gp = Variable(self._Tensor([0]))
|
||||
|
||||
def set_input(self, input):
|
||||
self._input_real_img.resize_(input['real_img'].size()).copy_(input['real_img'])
|
||||
self._input_real_cond.resize_(input['real_cond'].size()).copy_(input['real_cond'])
|
||||
self._input_desired_cond.resize_(input['desired_cond'].size()).copy_(input['desired_cond'])
|
||||
self._input_real_id = input['sample_id']
|
||||
self._input_real_img_path = input['real_img_path']
|
||||
|
||||
if len(self._gpu_ids) > 0:
|
||||
self._input_real_img = self._input_real_img.cuda(self._gpu_ids[0], async=True)
|
||||
self._input_real_cond = self._input_real_cond.cuda(self._gpu_ids[0], async=True)
|
||||
self._input_desired_cond = self._input_desired_cond.cuda(self._gpu_ids[0], async=True)
|
||||
|
||||
def set_train(self):
|
||||
self._G.train()
|
||||
self._D.train()
|
||||
self._is_train = True
|
||||
|
||||
def set_eval(self):
|
||||
self._G.eval()
|
||||
self._is_train = False
|
||||
|
||||
# get image paths
|
||||
def get_image_paths(self):
|
||||
return OrderedDict([('real_img', self._input_real_img_path)])
|
||||
|
||||
def forward(self, keep_data_for_visuals=False, return_estimates=False):
|
||||
if not self._is_train:
|
||||
# convert tensor to variables
|
||||
real_img = Variable(self._input_real_img, volatile=True)
|
||||
real_cond = Variable(self._input_real_cond, volatile=True)
|
||||
desired_cond = Variable(self._input_desired_cond, volatile=True)
|
||||
|
||||
# generate fake images
|
||||
fake_imgs, fake_img_mask = self._G.forward(real_img, desired_cond)
|
||||
fake_img_mask = self._do_if_necessary_saturate_mask(fake_img_mask, saturate=self._opt.do_saturate_mask)
|
||||
fake_imgs_masked = fake_img_mask * real_img + (1 - fake_img_mask) * fake_imgs
|
||||
|
||||
rec_real_img_rgb, rec_real_img_mask = self._G.forward(fake_imgs_masked, real_cond)
|
||||
rec_real_img_mask = self._do_if_necessary_saturate_mask(rec_real_img_mask, saturate=self._opt.do_saturate_mask)
|
||||
rec_real_imgs = rec_real_img_mask * fake_imgs_masked + (1 - rec_real_img_mask) * rec_real_img_rgb
|
||||
|
||||
imgs = None
|
||||
data = None
|
||||
if return_estimates:
|
||||
# normalize mask for better visualization
|
||||
fake_img_mask_max = fake_imgs_masked.view(fake_img_mask.size(0), -1).max(-1)[0]
|
||||
fake_img_mask_max = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(fake_img_mask_max, -1), -1), -1)
|
||||
# fake_img_mask_norm = fake_img_mask / fake_img_mask_max
|
||||
fake_img_mask_norm = fake_img_mask
|
||||
|
||||
# generate images
|
||||
im_real_img = util.tensor2im(real_img.data)
|
||||
im_fake_imgs = util.tensor2im(fake_imgs.data)
|
||||
im_fake_img_mask_norm = util.tensor2maskim(fake_img_mask_norm.data)
|
||||
im_fake_imgs_masked = util.tensor2im(fake_imgs_masked.data)
|
||||
im_rec_imgs = util.tensor2im(rec_real_img_rgb.data)
|
||||
im_rec_img_mask_norm = util.tensor2maskim(rec_real_img_mask.data)
|
||||
im_rec_imgs_masked = util.tensor2im(rec_real_imgs.data)
|
||||
im_concat_img = np.concatenate([im_real_img, im_fake_imgs_masked, im_fake_img_mask_norm, im_fake_imgs,
|
||||
im_rec_imgs, im_rec_img_mask_norm, im_rec_imgs_masked],
|
||||
1)
|
||||
|
||||
im_real_img_batch = util.tensor2im(real_img.data, idx=-1, nrows=1)
|
||||
im_fake_imgs_batch = util.tensor2im(fake_imgs.data, idx=-1, nrows=1)
|
||||
im_fake_img_mask_norm_batch = util.tensor2maskim(fake_img_mask_norm.data, idx=-1, nrows=1)
|
||||
im_fake_imgs_masked_batch = util.tensor2im(fake_imgs_masked.data, idx=-1, nrows=1)
|
||||
im_concat_img_batch = np.concatenate([im_real_img_batch, im_fake_imgs_masked_batch,
|
||||
im_fake_img_mask_norm_batch, im_fake_imgs_batch],
|
||||
1)
|
||||
|
||||
imgs = OrderedDict([('real_img', im_real_img),
|
||||
('fake_imgs', im_fake_imgs),
|
||||
('fake_img_mask', im_fake_img_mask_norm),
|
||||
('fake_imgs_masked', im_fake_imgs_masked),
|
||||
('concat', im_concat_img),
|
||||
('real_img_batch', im_real_img_batch),
|
||||
('fake_imgs_batch', im_fake_imgs_batch),
|
||||
('fake_img_mask_batch', im_fake_img_mask_norm_batch),
|
||||
('fake_imgs_masked_batch', im_fake_imgs_masked_batch),
|
||||
('concat_batch', im_concat_img_batch),
|
||||
])
|
||||
|
||||
data = OrderedDict([('real_path', self._input_real_img_path),
|
||||
('desired_cond', desired_cond.data[0, ...].cpu().numpy().astype('str'))
|
||||
])
|
||||
|
||||
# keep data for visualization
|
||||
if keep_data_for_visuals:
|
||||
self._vis_real_img = util.tensor2im(self._input_real_img)
|
||||
self._vis_fake_img_unmasked = util.tensor2im(fake_imgs.data)
|
||||
self._vis_fake_img = util.tensor2im(fake_imgs_masked.data)
|
||||
self._vis_fake_img_mask = util.tensor2maskim(fake_img_mask.data)
|
||||
self._vis_real_cond = self._input_real_cond.cpu()[0, ...].numpy()
|
||||
self._vis_desired_cond = self._input_desired_cond.cpu()[0, ...].numpy()
|
||||
self._vis_batch_real_img = util.tensor2im(self._input_real_img, idx=-1)
|
||||
self._vis_batch_fake_img_mask = util.tensor2maskim(fake_img_mask.data, idx=-1)
|
||||
self._vis_batch_fake_img = util.tensor2im(fake_imgs_masked.data, idx=-1)
|
||||
|
||||
return imgs, data
|
||||
|
||||
def optimize_parameters(self, train_generator=True, keep_data_for_visuals=False):
|
||||
if self._is_train:
|
||||
# convert tensor to variables
|
||||
self._B = self._input_real_img.size(0)
|
||||
self._real_img = Variable(self._input_real_img)
|
||||
self._real_cond = Variable(self._input_real_cond)
|
||||
self._desired_cond = Variable(self._input_desired_cond)
|
||||
|
||||
# train D
|
||||
loss_D, fake_imgs_masked = self._forward_D()
|
||||
self._optimizer_D.zero_grad()
|
||||
loss_D.backward()
|
||||
self._optimizer_D.step()
|
||||
|
||||
loss_D_gp= self._gradinet_penalty_D(fake_imgs_masked)
|
||||
self._optimizer_D.zero_grad()
|
||||
loss_D_gp.backward()
|
||||
self._optimizer_D.step()
|
||||
|
||||
# train G
|
||||
if train_generator:
|
||||
loss_G = self._forward_G(keep_data_for_visuals)
|
||||
self._optimizer_G.zero_grad()
|
||||
loss_G.backward()
|
||||
self._optimizer_G.step()
|
||||
|
||||
def _forward_G(self, keep_data_for_visuals):
|
||||
# generate fake images
|
||||
fake_imgs, fake_img_mask = self._G.forward(self._real_img, self._desired_cond)
|
||||
fake_img_mask = self._do_if_necessary_saturate_mask(fake_img_mask, saturate=self._opt.do_saturate_mask)
|
||||
fake_imgs_masked = fake_img_mask * self._real_img + (1 - fake_img_mask) * fake_imgs
|
||||
|
||||
# D(G(Ic1, c2)*M) masked
|
||||
d_fake_desired_img_masked_prob, d_fake_desired_img_masked_cond = self._D.forward(fake_imgs_masked)
|
||||
self._loss_g_masked_fake = self._compute_loss_D(d_fake_desired_img_masked_prob, True) * self._opt.lambda_D_prob
|
||||
self._loss_g_masked_cond = self._criterion_D_cond(d_fake_desired_img_masked_cond, self._desired_cond) / self._B * self._opt.lambda_D_cond
|
||||
|
||||
# G(G(Ic1,c2), c1)
|
||||
rec_real_img_rgb, rec_real_img_mask = self._G.forward(fake_imgs_masked, self._real_cond)
|
||||
rec_real_img_mask = self._do_if_necessary_saturate_mask(rec_real_img_mask, saturate=self._opt.do_saturate_mask)
|
||||
rec_real_imgs = rec_real_img_mask * fake_imgs_masked + (1 - rec_real_img_mask) * rec_real_img_rgb
|
||||
|
||||
# l_cyc(G(G(Ic1,c2), c1)*M)
|
||||
self._loss_g_cyc = self._criterion_cycle(rec_real_imgs, self._real_img) * self._opt.lambda_cyc
|
||||
|
||||
# loss mask
|
||||
self._loss_g_mask_1 = torch.mean(fake_img_mask) * self._opt.lambda_mask
|
||||
self._loss_g_mask_2 = torch.mean(rec_real_img_mask) * self._opt.lambda_mask
|
||||
self._loss_g_mask_1_smooth = self._compute_loss_smooth(fake_img_mask) * self._opt.lambda_mask_smooth
|
||||
self._loss_g_mask_2_smooth = self._compute_loss_smooth(rec_real_img_mask) * self._opt.lambda_mask_smooth
|
||||
|
||||
# keep data for visualization
|
||||
if keep_data_for_visuals:
|
||||
self._vis_real_img = util.tensor2im(self._input_real_img)
|
||||
self._vis_fake_img_unmasked = util.tensor2im(fake_imgs.data)
|
||||
self._vis_fake_img = util.tensor2im(fake_imgs_masked.data)
|
||||
self._vis_fake_img_mask = util.tensor2maskim(fake_img_mask.data)
|
||||
self._vis_real_cond = self._input_real_cond.cpu()[0, ...].numpy()
|
||||
self._vis_desired_cond = self._input_desired_cond.cpu()[0, ...].numpy()
|
||||
self._vis_batch_real_img = util.tensor2im(self._input_real_img, idx=-1)
|
||||
self._vis_batch_fake_img_mask = util.tensor2maskim(fake_img_mask.data, idx=-1)
|
||||
self._vis_batch_fake_img = util.tensor2im(fake_imgs_masked.data, idx=-1)
|
||||
self._vis_rec_img_unmasked = util.tensor2im(rec_real_img_rgb.data)
|
||||
self._vis_rec_real_img = util.tensor2im(rec_real_imgs.data)
|
||||
self._vis_rec_real_img_mask = util.tensor2maskim(rec_real_img_mask.data)
|
||||
self._vis_batch_rec_real_img = util.tensor2im(rec_real_imgs.data, idx=-1)
|
||||
|
||||
# combine losses
|
||||
return self._loss_g_masked_fake + self._loss_g_masked_cond + \
|
||||
self._loss_g_cyc + \
|
||||
self._loss_g_mask_1 + self._loss_g_mask_2 + \
|
||||
self._loss_g_mask_1_smooth + self._loss_g_mask_2_smooth
|
||||
|
||||
def _forward_D(self):
|
||||
# generate fake images
|
||||
fake_imgs, fake_img_mask = self._G.forward(self._real_img, self._desired_cond)
|
||||
fake_img_mask = self._do_if_necessary_saturate_mask(fake_img_mask, saturate=self._opt.do_saturate_mask)
|
||||
fake_imgs_masked = fake_img_mask * self._real_img + (1 - fake_img_mask) * fake_imgs
|
||||
|
||||
# D(real_I)
|
||||
d_real_img_prob, d_real_img_cond = self._D.forward(self._real_img)
|
||||
self._loss_d_real = self._compute_loss_D(d_real_img_prob, True) * self._opt.lambda_D_prob
|
||||
self._loss_d_cond = self._criterion_D_cond(d_real_img_cond, self._real_cond) / self._B * self._opt.lambda_D_cond
|
||||
|
||||
# D(fake_I)
|
||||
d_fake_desired_img_prob, _ = self._D.forward(fake_imgs_masked.detach())
|
||||
self._loss_d_fake = self._compute_loss_D(d_fake_desired_img_prob, False) * self._opt.lambda_D_prob
|
||||
|
||||
# combine losses
|
||||
return self._loss_d_real + self._loss_d_cond + self._loss_d_fake, fake_imgs_masked
|
||||
|
||||
def _gradinet_penalty_D(self, fake_imgs_masked):
|
||||
# interpolate sample
|
||||
alpha = torch.rand(self._B, 1, 1, 1).cuda().expand_as(self._real_img)
|
||||
interpolated = Variable(alpha * self._real_img.data + (1 - alpha) * fake_imgs_masked.data, requires_grad=True)
|
||||
interpolated_prob, _ = self._D(interpolated)
|
||||
|
||||
# compute gradients
|
||||
grad = torch.autograd.grad(outputs=interpolated_prob,
|
||||
inputs=interpolated,
|
||||
grad_outputs=torch.ones(interpolated_prob.size()).cuda(),
|
||||
retain_graph=True,
|
||||
create_graph=True,
|
||||
only_inputs=True)[0]
|
||||
|
||||
# penalize gradients
|
||||
grad = grad.view(grad.size(0), -1)
|
||||
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
|
||||
self._loss_d_gp = torch.mean((grad_l2norm - 1) ** 2) * self._opt.lambda_D_gp
|
||||
|
||||
return self._loss_d_gp
|
||||
|
||||
def _compute_loss_D(self, estim, is_real):
|
||||
return -torch.mean(estim) if is_real else torch.mean(estim)
|
||||
|
||||
def _compute_loss_smooth(self, mat):
|
||||
return torch.sum(torch.abs(mat[:, :, :, :-1] - mat[:, :, :, 1:])) + \
|
||||
torch.sum(torch.abs(mat[:, :, :-1, :] - mat[:, :, 1:, :]))
|
||||
|
||||
def get_current_errors(self):
|
||||
loss_dict = OrderedDict([('g_fake', self._loss_g_fake.data[0]),
|
||||
('g_cond', self._loss_g_cond.data[0]),
|
||||
('g_mskd_fake', self._loss_g_masked_fake.data[0]),
|
||||
('g_mskd_cond', self._loss_g_masked_cond.data[0]),
|
||||
('g_cyc', self._loss_g_cyc.data[0]),
|
||||
('g_rgb', self._loss_rec_real_img_rgb.data[0]),
|
||||
('g_rgb_un', self._loss_g_unmasked_rgb.data[0]),
|
||||
('g_rgb_s', self._loss_g_fake_imgs_smooth.data[0]),
|
||||
('g_m1', self._loss_g_mask_1.data[0]),
|
||||
('g_m2', self._loss_g_mask_2.data[0]),
|
||||
('g_m1_s', self._loss_g_mask_1_smooth.data[0]),
|
||||
('g_m2_s', self._loss_g_mask_2_smooth.data[0]),
|
||||
('g_idt', self._loss_g_idt.data[0]),
|
||||
('d_real', self._loss_d_real.data[0]),
|
||||
('d_cond', self._loss_d_cond.data[0]),
|
||||
('d_fake', self._loss_d_fake.data[0]),
|
||||
('d_gp', self._loss_d_gp.data[0])])
|
||||
|
||||
return loss_dict
|
||||
|
||||
def get_current_scalars(self):
|
||||
return OrderedDict([('lr_G', self._current_lr_G), ('lr_D', self._current_lr_D)])
|
||||
|
||||
def get_current_visuals(self):
|
||||
# visuals return dictionary
|
||||
visuals = OrderedDict()
|
||||
|
||||
# input visuals
|
||||
title_input_img = os.path.basename(self._input_real_img_path[0])
|
||||
visuals['1_input_img'] = plot_utils.plot_au(self._vis_real_img, self._vis_real_cond, title=title_input_img)
|
||||
visuals['2_fake_img'] = plot_utils.plot_au(self._vis_fake_img, self._vis_desired_cond)
|
||||
visuals['3_rec_real_img'] = plot_utils.plot_au(self._vis_rec_real_img, self._vis_real_cond)
|
||||
visuals['4_fake_img_unmasked'] = self._vis_fake_img_unmasked
|
||||
visuals['5_fake_img_mask'] = self._vis_fake_img_mask
|
||||
visuals['6_rec_real_img_mask'] = self._vis_rec_real_img_mask
|
||||
visuals['7_cyc_img_unmasked'] = self._vis_fake_img_unmasked
|
||||
# visuals['8_fake_img_mask_sat'] = self._vis_fake_img_mask_saturated
|
||||
# visuals['9_rec_real_img_mask_sat'] = self._vis_rec_real_img_mask_saturated
|
||||
visuals['10_batch_real_img'] = self._vis_batch_real_img
|
||||
visuals['11_batch_fake_img'] = self._vis_batch_fake_img
|
||||
visuals['12_batch_fake_img_mask'] = self._vis_batch_fake_img_mask
|
||||
# visuals['11_idt_img'] = self._vis_idt_img
|
||||
|
||||
return visuals
|
||||
|
||||
def save(self, label):
|
||||
# save networks
|
||||
self._save_network(self._G, 'G', label)
|
||||
self._save_network(self._D, 'D', label)
|
||||
|
||||
# save optimizers
|
||||
self._save_optimizer(self._optimizer_G, 'G', label)
|
||||
self._save_optimizer(self._optimizer_D, 'D', label)
|
||||
|
||||
def load(self):
|
||||
load_epoch = self._opt.load_epoch
|
||||
|
||||
# load G
|
||||
self._load_network(self._G, 'G', load_epoch)
|
||||
|
||||
if self._is_train:
|
||||
# load D
|
||||
self._load_network(self._D, 'D', load_epoch)
|
||||
|
||||
# load optimizers
|
||||
self._load_optimizer(self._optimizer_G, 'G', load_epoch)
|
||||
self._load_optimizer(self._optimizer_D, 'D', load_epoch)
|
||||
|
||||
def update_learning_rate(self):
|
||||
# updated learning rate G
|
||||
lr_decay_G = self._opt.lr_G / self._opt.nepochs_decay
|
||||
self._current_lr_G -= lr_decay_G
|
||||
for param_group in self._optimizer_G.param_groups:
|
||||
param_group['lr'] = self._current_lr_G
|
||||
print('update G learning rate: %f -> %f' % (self._current_lr_G + lr_decay_G, self._current_lr_G))
|
||||
|
||||
# update learning rate D
|
||||
lr_decay_D = self._opt.lr_D / self._opt.nepochs_decay
|
||||
self._current_lr_D -= lr_decay_D
|
||||
for param_group in self._optimizer_D.param_groups:
|
||||
param_group['lr'] = self._current_lr_D
|
||||
print('update D learning rate: %f -> %f' % (self._current_lr_D + lr_decay_D, self._current_lr_D))
|
||||
|
||||
def _l1_loss_with_target_gradients(self, input, target):
|
||||
return torch.sum(torch.abs(input - target)) / input.data.nelement()
|
||||
|
||||
def _do_if_necessary_saturate_mask(self, m, saturate=False):
|
||||
return torch.clamp(0.55*torch.tanh(3*(m-0.5))+0.5, 0, 1) if saturate else m
|
||||
@@ -1,132 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
class ModelsFactory:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_by_name(model_name, *args, **kwargs):
|
||||
model = None
|
||||
|
||||
if model_name == 'ganimation':
|
||||
from .ganimation import GANimation
|
||||
model = GANimation(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError("Model %s not recognized." % model_name)
|
||||
|
||||
print("Model %s was created" % model.name)
|
||||
return model
|
||||
|
||||
|
||||
class BaseModel(object):
|
||||
|
||||
def __init__(self, opt):
|
||||
self._name = 'BaseModel'
|
||||
|
||||
self._opt = opt
|
||||
self._gpu_ids = opt.gpu_ids
|
||||
self._is_train = opt.is_train
|
||||
|
||||
self._Tensor = torch.cuda.FloatTensor if self._gpu_ids else torch.Tensor
|
||||
self._save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
||||
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def is_train(self):
|
||||
return self._is_train
|
||||
|
||||
def set_input(self, input):
|
||||
assert False, "set_input not implemented"
|
||||
|
||||
def set_train(self):
|
||||
assert False, "set_train not implemented"
|
||||
|
||||
def set_eval(self):
|
||||
assert False, "set_eval not implemented"
|
||||
|
||||
def forward(self, keep_data_for_visuals=False):
|
||||
assert False, "forward not implemented"
|
||||
|
||||
# used in test time, no backprop
|
||||
def test(self):
|
||||
assert False, "test not implemented"
|
||||
|
||||
def get_image_paths(self):
|
||||
return {}
|
||||
|
||||
def optimize_parameters(self):
|
||||
assert False, "optimize_parameters not implemented"
|
||||
|
||||
def get_current_visuals(self):
|
||||
return {}
|
||||
|
||||
def get_current_errors(self):
|
||||
return {}
|
||||
|
||||
def get_current_scalars(self):
|
||||
return {}
|
||||
|
||||
def save(self, label):
|
||||
assert False, "save not implemented"
|
||||
|
||||
def load(self):
|
||||
assert False, "load not implemented"
|
||||
|
||||
def _save_optimizer(self, optimizer, optimizer_label, epoch_label):
|
||||
save_filename = 'opt_epoch_%s_id_%s.pth' % (epoch_label, optimizer_label)
|
||||
save_path = os.path.join(self._save_dir, save_filename)
|
||||
torch.save(optimizer.state_dict(), save_path)
|
||||
|
||||
def _load_optimizer(self, optimizer, optimizer_label, epoch_label):
|
||||
load_filename = 'opt_epoch_%s_id_%s.pth' % (epoch_label, optimizer_label)
|
||||
load_path = os.path.join(self._save_dir, load_filename)
|
||||
assert os.path.exists(
|
||||
load_path), 'Weights file not found. Have you trained a model!? We are not providing one' % load_path
|
||||
|
||||
optimizer.load_state_dict(torch.load(load_path))
|
||||
print 'loaded optimizer: %s' % load_path
|
||||
|
||||
def _save_network(self, network, network_label, epoch_label):
|
||||
save_filename = 'net_epoch_%s_id_%s.pth' % (epoch_label, network_label)
|
||||
save_path = os.path.join(self._save_dir, save_filename)
|
||||
torch.save(network.state_dict(), save_path)
|
||||
print 'saved net: %s' % save_path
|
||||
|
||||
def _load_network(self, network, network_label, epoch_label):
|
||||
load_filename = 'net_epoch_%s_id_%s.pth' % (epoch_label, network_label)
|
||||
load_path = os.path.join(self._save_dir, load_filename)
|
||||
assert os.path.exists(
|
||||
load_path), 'Weights file not found. Have you trained a model!? We are not providing one' % load_path
|
||||
|
||||
network.load_state_dict(torch.load(load_path))
|
||||
print 'loaded net: %s' % load_path
|
||||
|
||||
def update_learning_rate(self):
|
||||
pass
|
||||
|
||||
def print_network(self, network):
|
||||
num_params = 0
|
||||
for param in network.parameters():
|
||||
num_params += param.numel()
|
||||
print(network)
|
||||
print('Total number of parameters: %d' % num_params)
|
||||
|
||||
def _get_scheduler(self, optimizer, opt):
|
||||
if opt.lr_policy == 'lambda':
|
||||
def lambda_rule(epoch):
|
||||
lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
|
||||
return lr_l
|
||||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
|
||||
elif opt.lr_policy == 'step':
|
||||
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
|
||||
elif opt.lr_policy == 'plateau':
|
||||
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
|
||||
else:
|
||||
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
|
||||
return scheduler
|
||||
@@ -1,102 +0,0 @@
|
||||
# GANimation
|
||||
|
||||
This repository contains an implementation of [GANimation](https://arxiv.org/pdf/1807.09251.pdf) by Pumarola et al. based on [StarGAN code](https://github.com/yunjey/stargan) by @yunjey. With this model they are able to modify in a continuous way facial expressions of single images.
|
||||
|
||||
[Pretrained models](https://www.dropbox.com/sh/108g19dk3gt1l7l/AAB4OJHHrMHlBDbNK8aFQVZSa?dl=0) and the [preprocessed CelebA dataset](https://www.dropbox.com/s/payjdk08292csra/celeba.zip?dl=0) are provided to facilitate the use of this model as well as the process for preparing other datasets for training this model.
|
||||
|
||||
<p align="center">
|
||||
<img width="170" height="170" src="https://github.com/vipermu/ganimation/blob/master/video_results/frida.gif">
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img width="600" height="150" src="https://github.com/vipermu/ganimation/blob/master/video_results/eric_andre.gif">
|
||||
</p>
|
||||
|
||||
|
||||
## Setup
|
||||
|
||||
#### Conda environment
|
||||
Create your conda environment by just running the following command:
|
||||
`conda env create -f environment.yml`
|
||||
|
||||
|
||||
## Datasets
|
||||
|
||||
#### CelebA preprocessed dataset
|
||||
Download and unzip the *CelebA* preprocessed dataset uploaded to [this link](https://www.dropbox.com/s/payjdk08292csra/celeba.zip?dl=0) extracted from [MMLAB](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). Here you can find a folder containing the aligned and resized 128x128 images as well as a _txt_ file containing their respective Action Units vectors computed using [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace). By default, this code assumes that you have these two elements in _`./data/celeba/`_.
|
||||
|
||||
#### Use your own dataset
|
||||
If you want to use other datasets you will need to detect and crop bounding boxes around the face of each image, compute their corresponding Action Unit vectors and resize them to 128x128px.
|
||||
|
||||
You can perform all these steps using [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace). First you will need to setup the project. They provide guides for [linux](https://github.com/TadasBaltrusaitis/OpenFace/wiki/Unix-Installation) and [windows](https://github.com/TadasBaltrusaitis/OpenFace/wiki/Windows-Installation). Once the models are compiled, read their [Action Unit wiki](https://github.com/TadasBaltrusaitis/OpenFace/wiki/Action-Units) and their [documentation](https://github.com/TadasBaltrusaitis/OpenFace/wiki/Command-line-arguments) on these models to find out which is the command that you need to execute.
|
||||
|
||||
In my case the command was the following: `./build/bin/FaceLandmarkImg -fdir datasets/my-dataset/ -out_dir processed/my-processed-dataset/ -aus -simalign -au_static -nobadaligned -simsize 128 -format_aligned jpg -nomask`
|
||||
|
||||
After computing these Action Units, depending on the command that you have used, you will obtain different output formats. With the command that I used, I obtained a _csv_ file for each image containing its corresponding Action Units vector among extra information, a folder for each image containing the resized and cropped image and a _txt_ file with extra details about each image. You can find in _openface_utils_ folder the code that I used to extract all the Action Unit information in a _txt_ file and to group all the images into a single folder.
|
||||
|
||||
After having the Action Unit _txt_ file and the image folder you can move them to the directory of this project. By default, this code assumes that you have these two elements in _`./data/celeba/`_.
|
||||
|
||||
## Generate animations
|
||||
Pretrained models can be downloaded from [this](https://www.dropbox.com/sh/108g19dk3gt1l7l/AAB4OJHHrMHlBDbNK8aFQVZSa?dl=0) link. This folder contains the weights of both models (the Generator and the Discriminator) after training the model for 37 epochs.
|
||||
|
||||
By running `python main.py --mode animation` the default animation will be executed. There are two different types of animations already implemented which can be selected with the parameter 'animation_mode'. It is presuposed that the following folders are present:
|
||||
|
||||
- **attribute_images**: images from which the Action Units that we want to use for the animation were computed.
|
||||
- **images_to_animate**: images that we want to animate.
|
||||
- **pretrained_models**: pretrained models (only the generator is needed, you can download it from [here](https://www.dropbox.com/home/data/pretrained_models)
|
||||
- **results**: folder where the resulting images will be stored.
|
||||
- **attributes.txt**: file with the action units from 'attribute_images' computed.
|
||||
|
||||
The two options already implemented are the following:
|
||||
- **animate_image**: applies the expressions from 'attributes.txt' to the images in 'images_to_animate'.
|
||||
- **animate_random_batch**: applies the expressions from 'attributes.txt' to random batches of images from the training dataset.
|
||||
|
||||
|
||||
## Train the model
|
||||
|
||||
#### Parameters
|
||||
|
||||
You can either modify these parameters in `main.py` or by calling them as command line arguments.
|
||||
|
||||
|
||||
##### Lambdas
|
||||
|
||||
- *lambda_cls*: classification lambda.
|
||||
- *lambda_rec*: lambda for the cycle consistency loss.
|
||||
- *lambda_gp*: gradient penalty lambda.
|
||||
- *lambda_sat*: lambda for attention saturation loss.
|
||||
- *lambda_smooth*: lambda for attention smoothing loss.
|
||||
|
||||
##### Training parameters
|
||||
|
||||
- *c_dim*: number of Action Units to use to train the model.
|
||||
- *batch_size*
|
||||
- *num_epochs*
|
||||
- *num_epochs_decay*: number of epochs to start decaying the learning rate.
|
||||
- *g_lr*: generator's learning rate.
|
||||
- *d_lr*: discriminator's learning rate.
|
||||
|
||||
##### Pretrained models parameters
|
||||
The weights are stored in the following format: `<epoch>-<iteration>-<G/D>.ckpt` where G and D represent the Generator and the Discriminator respectively. We save the state of thoptimizers in the same format and extension but add '_optim'.
|
||||
|
||||
- *resume_iters*: iteration numbre from which we want to start the training. Note that we will need to have a saved model corresponding to that exact iteration number.
|
||||
- *first_epoch*: initial epoch for when we train from pretrained models.
|
||||
|
||||
##### Miscellaneous:
|
||||
- *mode*: train/test.
|
||||
- *image_dir*: path to your image folder.
|
||||
- *attr_path*: path to your attributes _txt_ folder.
|
||||
- *outputs_dir*: name for the output folder.
|
||||
|
||||
#### Virtual
|
||||
- *use_virtual*: this flag activates the use of _cycle consistency loss_ during the training.
|
||||
|
||||
## Virtual Cycle Consistency Loss
|
||||
The aim of this new component is to minimize the noise produced by the Action Unit regression. This idea was extracted from [Label-Noise Robust Multi-Domain Image-to-Image Translation](https://arxiv.org/abs/1905.02185) by Kaneko et al.. It is not proven that this new component improves the outcomes of the model but the masks seem to be darker when it is applied without losing realism on the output images.
|
||||
|
||||
## TODOs
|
||||
|
||||
- Clean Test function. (DONE)
|
||||
- Add an Action Units selector option for training.
|
||||
- Add multi-gpu support.
|
||||
- Smoother video generation.
|
||||
+9
-3
@@ -7,13 +7,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class LinfPGDAttack(object):
|
||||
def __init__(self, model=None, device=None, epsilon=0.05, k=1, a=0.05):
|
||||
def __init__(self, model=None, device=None, epsilon=0.05, k=1, a=0.05, feat = None):
|
||||
self.model = model
|
||||
self.epsilon = epsilon
|
||||
self.k = k
|
||||
self.a = a
|
||||
self.loss_fn = nn.MSELoss().to(device)
|
||||
self.device = device
|
||||
self.feat = feat
|
||||
|
||||
def perturb(self, X_nat, y, c_trg):
|
||||
"""
|
||||
@@ -25,7 +26,12 @@ class LinfPGDAttack(object):
|
||||
for i in range(self.k):
|
||||
# print(i)
|
||||
X.requires_grad = True
|
||||
output, _ = self.model(X, c_trg)
|
||||
output, feats = self.model(X, c_trg)
|
||||
|
||||
if self.feat:
|
||||
output = feats[self.feat]
|
||||
y = np.zeros(output.shape)
|
||||
y = torch.FloatTensor(y).to(self.device)
|
||||
|
||||
self.model.zero_grad()
|
||||
loss = self.loss_fn(output, y)
|
||||
@@ -39,7 +45,7 @@ class LinfPGDAttack(object):
|
||||
|
||||
self.model.zero_grad()
|
||||
|
||||
return X, (X_nat) - X # the eta here might be wrong!
|
||||
return X, eta
|
||||
|
||||
def clip_tensor(X, Y, Z):
|
||||
# Clip X with Y min and Z max
|
||||
|
||||
+9
-1
@@ -61,7 +61,15 @@ class Generator(nn.Module):
|
||||
c = c.view(c.size(0), c.size(1), 1, 1)
|
||||
c = c.repeat(1, 1, x.size(2), x.size(3))
|
||||
x = torch.cat([x, c], dim=1)
|
||||
return self.main(x)
|
||||
|
||||
feature_maps = []
|
||||
|
||||
# Get intermediate feature maps
|
||||
for layer in self.main:
|
||||
x = layer(x)
|
||||
feature_maps.append(x)
|
||||
|
||||
return x, feature_maps
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
|
||||
+11
-4
@@ -78,7 +78,8 @@ class Solver(object):
|
||||
"""Create a generator and a discriminator."""
|
||||
if self.dataset in ['CelebA', 'RaFD']:
|
||||
# self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
|
||||
self.G = AvgBlurGenerator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
|
||||
# self.G = AvgBlurGenerator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
|
||||
self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
|
||||
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
|
||||
elif self.dataset in ['Both']:
|
||||
self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector.
|
||||
@@ -584,6 +585,9 @@ class Solver(object):
|
||||
l2_error = 0.0
|
||||
perceptual_error = 0.0
|
||||
n_samples = 0
|
||||
|
||||
# 11 layers
|
||||
layer_num = 0
|
||||
|
||||
for i, (x_real, c_org) in enumerate(data_loader):
|
||||
# Black image
|
||||
@@ -604,20 +608,23 @@ class Solver(object):
|
||||
|
||||
for c_trg in c_trg_list:
|
||||
# Attack
|
||||
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
|
||||
layer_num = (layer_num + 1) * 3 - 1
|
||||
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg, feat=layer_num)
|
||||
# x_adv = x_real + perturb
|
||||
# x_adv = self.blur_tensor(x_adv)
|
||||
|
||||
# Metrics
|
||||
with torch.no_grad():
|
||||
gen, preproc_x = self.G(x_adv, c_trg)
|
||||
# gen, preproc_x = self.G(x_adv, c_trg)
|
||||
gen, gen_feats = self.G(x_adv, c_trg)
|
||||
|
||||
# Add to lists
|
||||
x_fake_list.append(preproc_x)
|
||||
x_fake_list.append(gen)
|
||||
|
||||
# No Attack
|
||||
gen_noattack, _ = self.G(x_real, c_trg)
|
||||
# gen_noattack, _ = self.G(x_real, c_trg)
|
||||
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
|
||||
|
||||
l1_error += F.l1_loss(gen, gen_noattack)
|
||||
l2_error += F.mse_loss(gen, gen_noattack)
|
||||
|
||||
Reference in New Issue
Block a user