Compare commits

44 commits: `512_beta` ... `StevenCyb/`

| Author | SHA1 | Date |
|---|---|---|
|  | 45473fa490 |  |
|  | dd1ecdd2a7 |  |
|  | 399367ee1d |  |
|  | 54aed6ce21 |  |
|  | b364910d1d |  |
|  | 4c31e16841 |  |
|  | 75d2662414 |  |
|  | 0f4ea44655 |  |
|  | 5ef13b2863 |  |
|  | b76dc0a54b |  |
|  | 2632de96e0 |  |
|  | 47e5c25351 |  |
|  | b93603b899 |  |
|  | 3bd2c74388 |  |
|  | 35040b2c39 |  |
|  | 8112ca1549 |  |
|  | b4ba93300b |  |
|  | 6a2e03d798 |  |
|  | e48f32d872 |  |
|  | bc5ac1ef22 |  |
|  | 7c44bc4b9a |  |
|  | 9f3daca179 |  |
|  | 7ed12d218f |  |
|  | b893316e41 |  |
|  | 1bbc1eff67 |  |
|  | 5181115399 |  |
|  | 3bf03c8136 |  |
|  | 8617bea87c |  |
|  | fe9fede9c5 |  |
|  | f48dc8cf62 |  |
|  | 9492873690 |  |
|  | 6c5c0db17a |  |
|  | 44191913fc |  |
|  | b3aef1ac3e |  |
|  | a8cd175706 |  |
|  | 2008091ea6 |  |
|  | 25cd3f3399 |  |
|  | 793c1e26de |  |
|  | 9e620e250b |  |
|  | 5a1137c9c7 |  |
|  | 4b551af16c |  |
|  | 8b329f8fc2 |  |
|  | 050ed8a00a |  |
|  | 589e31ad9c |  |
.gitignore (vendored, 6 changes)

```diff
@@ -135,4 +135,8 @@ checkpoints/
 *.zip
 *.avi
 *.pdf
-*.pptx
+*.pptx
+
+*.pth
+*.onnx
+wandb/
```
.vscode/launch.json (vendored, 15 changes)

```diff
@@ -1,15 +0,0 @@
-{
-    // Use IntelliSense to learn about possible attributes.
-    // Hover to view descriptions of existing attributes.
-    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
-    "version": "0.2.0",
-    "configurations": [
-        {
-            "type": "pwa-chrome",
-            "request": "launch",
-            "name": "Launch Chrome against localhost",
-            "url": "http://localhost:8080",
-            "webRoot": "${workspaceFolder}"
-        }
-    ]
-}
```
File diff suppressed because one or more lines are too long
README.md (66 changes)

```diff
@@ -2,9 +2,14 @@
 ## Proceedings of the 28th ACM International Conference on Multimedia
 **The official repository with Pytorch**
 
-*Our method can realize **arbitrary face swapping** on images and videos with **one single trained model**.*
+**Our method can realize *arbitrary face swapping* on images and videos with *one single trained model*.**
 
-Currently, only the test code is available. Training scripts are coming soon.
+Training and test code are now available!
+[<a href="https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/train.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/train.ipynb)
+
+We are working on our upcoming paper SimSwap++, stay tuned!
+
+The high resolution version of ***SimSwap-HQ*** is supported!
 
 [](https://github.com/neuralchen/SimSwap)
 
```
```diff
@@ -16,12 +21,20 @@ Our paper can be downloaded from [[Arxiv]](https://arxiv.org/pdf/2106.06340v1.pdf)
 ## Attention
 ***This project is for technical and academic use only. Please do not apply it to illegal or unethical scenarios.***
 
 ***In the event of a violation of the legal and ethical requirements of the user's country or region, this code repository is exempt from liability.***
 
+***Please do not ignore the content at the end of this README!***
+
 If you find this project useful, please star it. It is the greatest appreciation of our work.
 
 ## Top News <img width=8% src="./docs/img/new.gif"/>
 
+**`2022-04-21`**: For resource-limited users, we provide the cropped VGGFace2-224 dataset [[Google Drive] VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing) [[Baidu Drive]](https://pan.baidu.com/s/1OiwLJHVBSYB4AY2vEcfN0A) [Password: lrod].
+
+**`2022-04-20`**: Training scripts are now available. We highly recommend training the SimSwap model with our released high-quality dataset [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ).
+
 **`2021-11-24`**: We have trained a beta version of ***SimSwap-HQ*** on [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ) and open sourced the checkpoint of this model (if you think SimSwap 512 is cool, please star our [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ) repo). Please don't forget to go to [Preparation](./docs/guidance/preparation.md) and [Inference for image or video face swapping](./docs/guidance/usage.md) to check the latest setup.
 
 **`2021-11-23`**: The Google Drive link for [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ) is released.
 
 **`2021-11-17`**: We released a high-resolution face dataset, [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ), and the method used to generate it. This dataset is for research purposes only.
```
```diff
@@ -32,16 +45,6 @@ If you find this project useful, please star it. It is the greatest appreciation
-**`2021-07-19`**: ***Obvious border abruptness has been resolved.*** We added the ability to use a mask and upgraded the old algorithm for a better visual effect; please go to [Inference for image or video face swapping](./docs/guidance/usage.md) for details. Please don't forget to go to [Preparation](./docs/guidance/preparation.md) to check the latest setup. (Thanks for the help from [@woctezuma](https://github.com/woctezuma) and [@instant-high](https://github.com/instant-high))
-
-**`2021-07-04`**: A new Colab performing **multi specific** face video swapping has been added. You can check it out [here](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/MultiSpecific.ipynb).
-
-**`2021-07-03`**: We added the scripts for **multi specific** face swapping; please go to [Inference for image or video face swapping](./docs/guidance/usage.md) for details.
-
-**`2021-07-02`**: We added the scripts for designating a **specific** person in an arbitrary video or image for face changing; please go to [Inference for image or video face swapping](./docs/guidance/usage.md) for details.
-
-**`2021-07-02`**: We added a hyperparameter to let users choose whether to add the SimSwap logo as a watermark; please go to the section "About watermark of simswap logo" in [Inference for image or video face swapping](./docs/guidance/usage.md) for details.
-
 **`2021-06-20`**: We release the scripts for arbitrary video and image processing, and a Colab demo.
 
-## The first open source high resolution dataset for face swapping!!!
+## High Resolution Dataset [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ)
 
```
````diff
@@ -58,7 +61,39 @@ If you find this project useful, please star it. It is the greatest appreciation
 - moviepy
 - insightface
 
-## Usage
+## Training
+
+[Preparation](./docs/guidance/preparation.md)
+
+The training script is slightly different from the original version; e.g., we replaced the patch discriminator with the projected discriminator, which saves a lot of hardware overhead and achieves slightly better results.
+
+To ensure normal training, the batch size must be greater than 1.
+
+Friendly reminder: due to differences in training settings, a user-trained model will show subtle visual differences from the pre-trained model we provide.
+
+- Train 224 models with VGGFace2 224*224 [[Google Drive] VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing) [[Baidu Drive]](https://pan.baidu.com/s/1OiwLJHVBSYB4AY2vEcfN0A) [Password: lrod]
+
+For faster convergence and better results, a large batch size (more than 16) is recommended!
+
+***We recommend training for more than 400K iterations (with batch size 16); 600K~800K is better, and more iterations are not recommended.***
+
+```
+python train.py --name simswap224_test --batchSize 8 --gpu_ids 0 --dataset /path/to/VGGFace2HQ --Gdeep False
+```
+
+[Colab demo for training 224 model](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/train.ipynb)
+
+- Train 512 models with VGGFace2-HQ 512*512 [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ).
+```
+python train.py --name simswap512_test --batchSize 16 --gpu_ids 0 --dataset /path/to/VGGFace2HQ --Gdeep True
+```
+
+
+## Inference with a pretrained SimSwap model
+
 [Preparation](./docs/guidance/preparation.md)
 
 [Inference for image or video face swapping](./docs/guidance/usage.md)
````
```diff
@@ -67,11 +102,10 @@ If you find this project useful, please star it. It is the greatest appreciation
 <div style="background: yellow; width:140px; font-weight:bold;font-family: sans-serif;">Stronger feature</div>
 
-[Colab fo switching specific faces in multi-face videos](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/MultiSpecific.ipynb)
+[Colab for switching specific faces in multi-face videos](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/MultiSpecific.ipynb)
 
 [Image face swapping demo & Docker image on Replicate](https://replicate.ai/neuralchen/simswap-image)
 
-Training: **coming soon**
-
 ## Video
```
````diff
@@ -131,8 +165,6 @@ For academic and non-commercial use only. The whole project is under the CC-BY-NC
               Yanhao Ge},
   title     = {SimSwap: An Efficient Framework For High Fidelity Face Swapping},
   booktitle = {{MM} '20: The 28th {ACM} International Conference on Multimedia},
   pages     = {2003--2011},
   publisher = {{ACM}},
   year      = {2020}
 }
 ```
````
```diff
@@ -398,7 +398,7 @@
     "opt.isTrain = False\n",
     "opt.use_mask = True ## new feature up-to-date\n",
     "\n",
-    "crop_size = 224\n",
+    "crop_size = opt.crop_size\n",
     "\n",
     "torch.nn.Module.dump_patches = True\n",
     "model = create_model(opt)\n",
@@ -420,7 +420,7 @@
     " img_id = img_id.cuda()\n",
     "\n",
     " #create latent id\n",
-    " img_id_downsample = F.interpolate(img_id, scale_factor=0.5)\n",
+    " img_id_downsample = F.interpolate(img_id, size=(112,112))\n",
     " latend_id = model.netArc(img_id_downsample)\n",
     " latend_id = latend_id.detach().to('cpu')\n",
     " latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True)\n",
```
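The switch above from `scale_factor=0.5` to `size=(112,112)` matters once `crop_size` is no longer hard-coded to 224: the ArcFace identity encoder expects a 112x112 input regardless of the swap resolution. A minimal sketch of the output-size arithmetic (`arcface_input_size` is a hypothetical helper, not part of the repo):

```python
def arcface_input_size(crop_size, scale_factor=None, size=None):
    # Mirrors F.interpolate's sizing rule for square inputs:
    # an explicit `size` wins; otherwise the input edge is scaled.
    if size is not None:
        return size
    return int(crop_size * scale_factor)

# scale_factor=0.5 only yields ArcFace's 112 when crop_size is 224...
print(arcface_input_size(224, scale_factor=0.5))  # 112
print(arcface_input_size(512, scale_factor=0.5))  # 256
# ...while an explicit size is resolution-independent.
print(arcface_input_size(512, size=112))          # 112
```

This is why the notebook change is required for the 512 checkpoint to work at all.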
@@ -1,94 +0,0 @@

```python
import torch
from torch.utils.data import Dataset
import os
import numpy as np
import random
from torchvision import transforms
from PIL import Image
import cv2


class FaceDataSet(Dataset):
    def __init__(self, dataset_path, batch_size):
        super(FaceDataSet, self).__init__()

        '''picture_dir_list = []
        for i in range(self.people_num):
            picture_dir_list.append('/data/home/renwangchen/vgg_align_224/'+self.people_list[i])

        self.people_pic_list = []
        for i in range(self.people_num):
            pic_list = os.listdir(picture_dir_list[i])
            person_pic_list = []
            for j in range(len(pic_list)):
                pic_dir = os.path.join(picture_dir_list[i], pic_list[j])
                person_pic_list.append(pic_dir)
            self.people_pic_list.append(person_pic_list)'''

        pic_dir = '/data/home/renwangchen/CelebA_224/'
        latent_dir = '/data/home/renwangchen/CelebA_latent/'

        tmp_list = os.listdir(pic_dir)
        self.pic_list = []
        self.latent_list = []
        for i in range(len(tmp_list)):
            self.pic_list.append(pic_dir + tmp_list[i])
            self.latent_list.append(latent_dir + tmp_list[i][:-3] + 'npy')

        self.pic_list = self.pic_list[:29984]
        '''for i in range(29984):
            print(self.pic_list[i])'''
        self.latent_list = self.latent_list[:29984]

        self.people_num = len(self.pic_list)

        self.type = 1
        self.bs = batch_size
        self.count = 0

        self.transformer = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __getitem__(self, index):
        p1 = random.randint(0, self.people_num - 1)
        p2 = p1

        if self.type == 0:
            # load pictures from the same folder
            pass
        else:
            # load pictures from different folders
            p2 = p1
            while p2 == p1:
                p2 = random.randint(0, self.people_num - 1)

        pic_id_dir = self.pic_list[p1]
        pic_att_dir = self.pic_list[p2]
        latent_id_dir = self.latent_list[p1]
        latent_att_dir = self.latent_list[p2]

        img_id = Image.open(pic_id_dir).convert('RGB')
        img_id = self.transformer(img_id)
        latent_id = np.load(latent_id_dir)
        latent_id = latent_id / np.linalg.norm(latent_id)
        latent_id = torch.from_numpy(latent_id)

        img_att = Image.open(pic_att_dir).convert('RGB')
        img_att = self.transformer(img_att)
        latent_att = np.load(latent_att_dir)
        latent_att = latent_att / np.linalg.norm(latent_att)
        latent_att = torch.from_numpy(latent_att)

        self.count += 1
        data_type = self.type
        if self.count == self.bs:
            self.type = 1 - self.type
            self.count = 0

        return img_id, img_att, latent_id, latent_att, data_type

    def __len__(self):
        return len(self.pic_list)
```
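The deleted `FaceDataSet` flips `self.type` once per batch, so consecutive batches alternate between same-identity and different-identity sampling. The counter logic in isolation (a standalone sketch, not repo code):

```python
class BatchTypeToggler:
    """Standalone sketch of FaceDataSet's per-batch type flip."""

    def __init__(self, batch_size):
        self.bs = batch_size
        self.type = 1   # 1: sample two different identities; 0: same identity
        self.count = 0

    def step(self):
        data_type = self.type
        self.count += 1
        if self.count == self.bs:   # a full batch consumed: flip the mode
            self.type = 1 - self.type
            self.count = 0
        return data_type


toggler = BatchTypeToggler(batch_size=2)
print([toggler.step() for _ in range(6)])  # [1, 1, 0, 0, 1, 1]
```

Every sample in a batch therefore shares the same `data_type`, which the training loop can use to decide whether the reconstruction or the swap objective applies.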
@@ -1,76 +0,0 @@

```python
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image


class AlignedDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot

        ### input A (label maps)
        dir_A = '_A' if self.opt.label_nc == 0 else '_label'
        self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
        self.A_paths = sorted(make_dataset(self.dir_A))

        ### input B (real images)
        if opt.isTrain or opt.use_encoded_image:
            dir_B = '_B' if self.opt.label_nc == 0 else '_img'
            self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
            self.B_paths = sorted(make_dataset(self.dir_B))

        ### instance maps
        if not opt.no_instance:
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))

        ### load precomputed instance-wise encoded features
        if opt.load_features:
            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
            print('----------- loading features from %s ----------' % self.dir_feat)
            self.feat_paths = sorted(make_dataset(self.dir_feat))

        self.dataset_size = len(self.A_paths)

    def __getitem__(self, index):
        ### input A (label maps)
        A_path = self.A_paths[index]
        A = Image.open(A_path)
        params = get_params(self.opt, A.size)
        if self.opt.label_nc == 0:
            transform_A = get_transform(self.opt, params)
            A_tensor = transform_A(A.convert('RGB'))
        else:
            transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
            A_tensor = transform_A(A) * 255.0

        B_tensor = inst_tensor = feat_tensor = 0
        ### input B (real images)
        if self.opt.isTrain or self.opt.use_encoded_image:
            B_path = self.B_paths[index]
            B = Image.open(B_path).convert('RGB')
            transform_B = get_transform(self.opt, params)
            B_tensor = transform_B(B)

        ### if using instance maps
        if not self.opt.no_instance:
            inst_path = self.inst_paths[index]
            inst = Image.open(inst_path)
            inst_tensor = transform_A(inst)

            if self.opt.load_features:
                feat_path = self.feat_paths[index]
                feat = Image.open(feat_path).convert('RGB')
                norm = normalize()
                feat_tensor = norm(transform_A(feat))

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                      'feat': feat_tensor, 'path': A_path}

        return input_dict

    def __len__(self):
        return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize

    def name(self):
        return 'AlignedDataset'
```
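Note the `__len__` of the deleted `AlignedDataset`: it rounds the dataset size down to a multiple of `batchSize`, silently dropping the tail so every batch is full. The arithmetic in isolation (`usable_length` is a hypothetical name for illustration):

```python
def usable_length(n_items, batch_size):
    # Floor-divide then re-multiply: the largest multiple of
    # batch_size that is <= n_items.
    return n_items // batch_size * batch_size


print(usable_length(103, 8))  # 96: the last 7 items are never served
```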
@@ -1,90 +0,0 @@

```python
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        new_h = new_w = opt.loadSize
    elif opt.resize_or_crop == 'scale_width_and_crop':
        new_w = opt.loadSize
        new_h = opt.loadSize * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))

    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img
```
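`get_params` above picks a random top-left corner so that a `fineSize` crop stays inside the (possibly resized) image; the `np.maximum(0, ...)` guard handles images smaller than the crop. The same bound, dependency-free (`crop_origin` is a hypothetical name for illustration):

```python
import random


def crop_origin(width, height, fine_size, rng=random):
    # The upper bound is clamped at 0 for images smaller than the crop,
    # matching np.maximum(0, new_w - opt.fineSize) in get_params.
    x = rng.randint(0, max(0, width - fine_size))
    y = rng.randint(0, max(0, height - fine_size))
    return x, y


x, y = crop_origin(1024, 768, 512)
assert 0 <= x <= 512 and 0 <= y <= 256
```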
@@ -1,7 +0,0 @@

```python
def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
```
data/data_loader_Swapping.py (new file, 125 lines)

@@ -0,0 +1,125 @@

```python
import os
import glob
import torch
import random
from PIL import Image
from torch.utils import data
from torchvision import transforms as T


class data_prefetcher():
    def __init__(self, loader):
        self.loader = loader
        self.dataiter = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.num_images = len(loader)
        self.preload()

    def preload(self):
        try:
            self.src_image1, self.src_image2 = next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self.loader)
            self.src_image1, self.src_image2 = next(self.dataiter)

        with torch.cuda.stream(self.stream):
            self.src_image1 = self.src_image1.cuda(non_blocking=True)
            self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
            self.src_image2 = self.src_image2.cuda(non_blocking=True)
            self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        src_image1 = self.src_image1
        src_image2 = self.src_image2
        self.preload()
        return src_image1, src_image2

    def __len__(self):
        """Return the number of images."""
        return self.num_images


class SwappingDataset(data.Dataset):
    """Dataset class for the Artworks dataset and content dataset."""

    def __init__(self,
                 image_dir,
                 img_transform,
                 subffix='jpg',
                 random_seed=1234):
        """Initialize and preprocess the Swapping dataset."""
        self.image_dir = image_dir
        self.img_transform = img_transform
        self.subffix = subffix
        self.dataset = []
        self.random_seed = random_seed
        self.preprocess()
        self.num_images = len(self.dataset)

    def preprocess(self):
        """Preprocess the Swapping dataset."""
        print("processing Swapping dataset images...")

        temp_path = os.path.join(self.image_dir, '*/')
        pathes = glob.glob(temp_path)
        self.dataset = []
        for dir_item in pathes:
            join_path = glob.glob(os.path.join(dir_item, '*.jpg'))
            print("processing %s" % dir_item, end='\r')
            temp_list = []
            for item in join_path:
                temp_list.append(item)
            self.dataset.append(temp_list)
        random.seed(self.random_seed)
        random.shuffle(self.dataset)
        print('Finished preprocessing the Swapping dataset, total dirs number: %d...' % len(self.dataset))

    def __getitem__(self, index):
        """Return two src domain images and two dst domain images."""
        dir_tmp1 = self.dataset[index]
        dir_tmp1_len = len(dir_tmp1)

        filename1 = dir_tmp1[random.randint(0, dir_tmp1_len - 1)]
        filename2 = dir_tmp1[random.randint(0, dir_tmp1_len - 1)]
        image1 = self.img_transform(Image.open(filename1))
        image2 = self.img_transform(Image.open(filename2))
        return image1, image2

    def __len__(self):
        """Return the number of images."""
        return self.num_images


def GetLoader(dataset_roots,
              batch_size=16,
              dataloader_workers=8,
              random_seed=1234
              ):
    """Build and return a data loader."""

    num_workers = dataloader_workers
    data_root = dataset_roots
    random_seed = random_seed

    c_transforms = []

    c_transforms.append(T.ToTensor())
    c_transforms = T.Compose(c_transforms)

    content_dataset = SwappingDataset(
        data_root,
        c_transforms,
        "jpg",
        random_seed)
    content_data_loader = data.DataLoader(dataset=content_dataset, batch_size=batch_size,
                                          drop_last=True, shuffle=True, num_workers=num_workers, pin_memory=True)
    prefetcher = data_prefetcher(content_data_loader)
    return prefetcher


def denorm(x):
    out = (x + 1) / 2
    return out.clamp_(0, 1)
```
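The `denorm` helper in the new loader assumes generator outputs in [-1, 1] and maps them back to [0, 1] for saving, clamping any overflow. The same mapping element-wise, torch-free (a scalar sketch, not the repo function):

```python
def denorm_scalar(x):
    # (x + 1) / 2 maps [-1, 1] -> [0, 1]; the clamp handles values
    # that stray outside the expected range.
    return min(1.0, max(0.0, (x + 1) / 2))


print([denorm_scalar(v) for v in (-1.0, 0.0, 1.0, 3.0)])  # [0.0, 0.5, 1.0, 1.0]
```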
```diff
@@ -16,7 +16,7 @@ pip install insightface==0.2.1 onnxruntime moviepy
 - We use the face parsing from **[face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch)** for image postprocessing. Please download the relevant file from [this link](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view) and place it in ./parsing_model/checkpoint.
 - The pytorch and cuda versions above are the most recommended. They may vary.
 - Using insightface with different versions is not recommended. Please use this specific version.
-- These settings are tested valid on both Windows and Ununtu.
+- These settings are tested valid on both Windows and Ubuntu.
 
 ### Pretrained model
 There are two archive files in the drive: **checkpoints.zip** and **arcface_checkpoint.tar**
@@ -27,5 +27,11 @@ There are two archive files in the drive: **checkpoints.zip** and **arcface_checkpoint.tar**
 [[Google Drive]](https://drive.google.com/drive/folders/1jV6_0FIMPC53FZ2HzZNJZGMe55bbu17R?usp=sharing)
 [[Baidu Drive]](https://pan.baidu.com/s/1wFV11RVZMHqd-ky4YpLdcA) Password: ```jd2v```
 
+**Simswap 512 (optional)**
+
+The checkpoint of the **Simswap 512 beta version** has been uploaded as a [GitHub release](https://github.com/neuralchen/SimSwap/releases/download/512_beta/512.zip). If you want to experience Simswap 512, feel free to try it.
+- **Unzip 512.zip and place it in the root dir ./checkpoints**.
+
+
 ### Note
 We expect users to have a GPU with at least 3G memory. For those who do not, we provide a [[Colab Notebook implementation]](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/SimSwap%20colab.ipynb).
```
@@ -14,28 +14,28 @@
|
||||
|
||||
### Simple face swapping for already face-aligned images
|
||||
```
|
||||
python test_one_image.py --isTrain false --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path crop_224/6.jpg --pic_b_path crop_224/ds.jpg --output_path output/
|
||||
python test_one_image.py --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path crop_224/6.jpg --pic_b_path crop_224/ds.jpg --output_path output/
|
||||
```
|
||||
|
||||
### Face swapping for video
|
||||
|
||||
- Swap only one face within the video(the one with highest confidence by face detection).
|
||||
```
|
||||
python test_video_swapsingle.py --isTrain false --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --video_path ./demo_file/multi_people_1080p.mp4 --output_path ./output/multi_test_swapsingle.mp4 --temp_path ./temp_results
|
||||
python test_video_swapsingle.py --crop_size 224 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --video_path ./demo_file/multi_people_1080p.mp4 --output_path ./output/multi_test_swapsingle.mp4 --temp_path ./temp_results
|
||||
```
|
||||
- Swap all faces within the video.
|
||||
```
|
||||
python test_video_swapmulti.py --isTrain false --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --video_path ./demo_file/multi_people_1080p.mp4 --output_path ./output/multi_test_swapmulti.mp4 --temp_path ./temp_results
|
||||
python test_video_swapmulti.py --crop_size 224 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --video_path ./demo_file/multi_people_1080p.mp4 --output_path ./output/multi_test_swapmulti.mp4 --temp_path ./temp_results
|
||||
```
|
||||
- Swap the ***specific*** face within the video.
|
||||
```
|
||||
python test_video_swapspecific.py --use_mask --pic_specific_path ./demo_file/specific1.png --isTrain false --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --video_path ./demo_file/multi_people_1080p.mp4 --output_path ./output/multi_test_specific.mp4 --temp_path ./temp_results
|
||||
python test_video_swapspecific.py --crop_size 224 --use_mask --pic_specific_path ./demo_file/specific1.png --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --video_path ./demo_file/multi_people_1080p.mp4 --output_path ./output/multi_test_specific.mp4 --temp_path ./temp_results
|
||||
```
|
||||
When changing the specified face, you need to give a picture of the person whose face is to be changed. Then assign the picture path to the argument "***--pic_specific_path***". This picture should be a front face and show the entire head and neck, which can help accurately change the face (if you still don’t know how to choose the picture, you can refer to the specific*.png of [./demo_file/](https://github.com/neuralchen/SimSwap/tree/main/demo_file)). It would be better if this picture was taken from the video to be changed.
|
||||
|
||||
- Swap ***multi specific*** face with **multi specific id** within the video.
|
||||
```
|
||||
python test_video_swap_multispecific.py --isTrain false --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --video_path ./demo_file/multi_people_1080p.mp4 --output_path ./output/multi_test_multispecific.mp4 --temp_path ./temp_results --multisepcific_dir ./demo_file/multispecific
|
||||
python test_video_swap_multispecific.py --crop_size 224 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --video_path ./demo_file/multi_people_1080p.mp4 --output_path ./output/multi_test_multispecific.mp4 --temp_path ./temp_results --multisepcific_dir ./demo_file/multispecific
|
||||
```
|
||||
The folder you assign to ***"--multisepcific_dir"*** should be looked like:
|
||||
```
|
||||
@@ -56,26 +56,36 @@ The result is that the face corresponding to SRC_01.jpg (png) in the video will

- Swap only one face within one image (the one with the highest confidence from face detection). The result will be saved to ./output/result_whole_swapsingle.jpg
```
-python test_wholeimage_swapsingle.py --isTrain false --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/
+python test_wholeimage_swapsingle.py --crop_size 224 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/
```
- Swap all faces within one image. The result will be saved to ./output/result_whole_swapmulti.jpg
```
-python test_wholeimage_swapmulti.py --isTrain false --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/
+python test_wholeimage_swapmulti.py --crop_size 224 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/
```
- Swap a **specific** face within one image. The result will be saved to ./output/result_whole_swapspecific.jpg
```
-python test_wholeimage_swapspecific.py --isTrain false --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/ --pic_specific_path ./demo_file/specific2.png
+python test_wholeimage_swapspecific.py --crop_size 224 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/ --pic_specific_path ./demo_file/specific2.png
```
- Swap **multi specific** faces with **multi specific ids** within one image. The result will be saved to ./output/result_whole_swap_multispecific.jpg
```
-python test_wholeimage_swap_multispecific.py --isTrain false --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/ --multisepcific_dir ./demo_file/multispecific
+python test_wholeimage_swap_multispecific.py --crop_size 224 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/ --multisepcific_dir ./demo_file/multispecific
```
### About using Simswap 512 (beta version)

We trained a beta version of Simswap 512 on [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ) and open-sourced the model (if you think Simswap 512 is cool, please star our [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ) repo).

To apply Simswap 512, set the argument "***--crop_size***" to 512. Taking the command line of "Swap **multi specific** face with **multi specific id** within one image" as an example, the following command line produces the 512 result:
```
python test_wholeimage_swap_multispecific.py --crop_size 512 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/ --multisepcific_dir ./demo_file/multispecific
```
The effect of Simswap 512 is shown below.
<img src="../img/result_whole_swap_multispecific_512.jpg"/>

### About watermark of simswap logo

The example command lines above add the simswap logo as a watermark by default. After some discussion, we added a hyperparameter to control whether the watermark is removed.

To remove the watermark, add the argument "***--no_simswaplogo***" to the command line. Taking the command line of "Swap all faces within one image" as an example, the following command line produces the result without the watermark:
```
-python test_wholeimage_swapmulti.py --no_simswaplogo --isTrain false --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/
+python test_wholeimage_swapmulti.py --no_simswaplogo --crop_size 224 --use_mask --name people --Arc_path arcface_model/arcface_checkpoint.tar --pic_a_path ./demo_file/Iron_man.jpg --pic_b_path ./demo_file/multi_people.jpg --output_path ./output/
```
### About using mask for better result

We provide two methods to paste the swapped face back into the original image: using a mask or using a bounding box. At present, using a mask gives the best results; all the code examples above use it. If you want to use the bounding box instead, simply remove --use_mask from the command line.
@@ -88,18 +98,19 @@ Difference between using mask and not using mask can be found [here](https://img

### Parameters
-| Parameters | Function |
-| :---- | :---- |
-| --name | The SimSwap training logs name |
-| --pic_a_path | Path of image with the target face |
-| --pic_b_path | Path of image with the source face to swap |
-| --pic_specific_path | Path of image with the specific face to be swapped |
-|--multisepcific_dir |Path of image folder for multi specific face swapping|
-| --video_path | Path of video with the source face to swap |
-| --temp_path | Path to store intermediate files |
-| --output_path | Path of directory to store the face swapping result |
-| --no_simswaplogo |The hyper parameter to control whether to remove watermark |
-| --use_mask |The hyper parameter to control whether to use face parsing for the better visual effects(I recommend to use)|
+| Parameters | Function |
+| :---- | :---- |
+| --name | The SimSwap training logs name |
+| --pic_a_path | Path of image with the target face |
+| --pic_b_path | Path of image with the source face to swap |
+| --pic_specific_path | Path of image with the specific face to be swapped |
+| --multisepcific_dir | Path of image folder for multi specific face swapping |
+| --video_path | Path of video with the source face to swap |
+| --temp_path | Path to store intermediate files |
+| --output_path | Path of directory to store the face swapping result |
+| --no_simswaplogo | The hyperparameter to control whether to remove the watermark |
+| --use_mask | The hyperparameter to control whether to use face parsing for better visual effects (recommended) |
+| --skip_existing_frames | Skip a frame index if it already exists in temp_path (does not check whether it came from the same video) |

### Note
We expect users to have a GPU with at least 3 GB of memory. For those who do not, we will provide a Colab Notebook implementation in the future.
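The new ***--skip_existing_frames*** flag resumes interrupted runs by skipping frames whose intermediate files already exist in temp_path. A minimal stdlib-only sketch of that behaviour (the frame file-name pattern here is an assumption for illustration, not SimSwap's actual naming):

```python
from pathlib import Path

def frames_to_process(total_frames, temp_path, pattern="frame_{:07d}.jpg"):
    # Skip any frame index whose output file already exists in temp_path,
    # mirroring the idea behind --skip_existing_frames. Note it matches by
    # index only; it cannot tell whether the files came from the same video.
    existing = {p.name for p in Path(temp_path).glob("*.jpg")}
    return [i for i in range(total_frames) if pattern.format(i) not in existing]
```

This is why the parameter table warns that it "will not compare if the same video": only file names are checked, never their content.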
BIN docs/img/result_whole_swap_multispecific_512.jpg (new file, 335 KiB; binary file not shown)
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 16:45:41
+Description:
+'''
 from __future__ import division
 import collections
 import numpy as np
@@ -6,7 +14,7 @@ import os
 import os.path as osp
 import cv2
 from insightface.model_zoo import model_zoo
-from insightface.utils import face_align
+from insightface_func.utils import face_align_ffhqandnewarc as face_align

 __all__ = ['Face_detect_crop', 'Face']

@@ -40,8 +48,9 @@ class Face_detect_crop:
         self.det_model = self.models['detection']

-    def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
+    def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640), mode='None'):
         self.det_thresh = det_thresh
+        self.mode = mode
         assert det_size is not None
         print('set det-size:', det_size)
         self.det_size = det_size
@@ -73,7 +82,7 @@ class Face_detect_crop:
             kps = None
             if kpss is not None:
                 kps = kpss[i]
-            M, _ = face_align.estimate_norm(kps, crop_size, mode='None')
+            M, _ = face_align.estimate_norm(kps, crop_size, mode=self.mode)
             align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
             align_img_list.append(align_img)
             M_list.append(M)
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 16:46:04
+Description:
+'''
 from __future__ import division
 import collections
 import numpy as np
@@ -6,7 +14,7 @@ import os
 import os.path as osp
 import cv2
 from insightface.model_zoo import model_zoo
-from insightface.utils import face_align
+from insightface_func.utils import face_align_ffhqandnewarc as face_align

 __all__ = ['Face_detect_crop', 'Face']

@@ -40,8 +48,9 @@ class Face_detect_crop:
         self.det_model = self.models['detection']

-    def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
+    def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640), mode='None'):
         self.det_thresh = det_thresh
+        self.mode = mode
         assert det_size is not None
         print('set det-size:', det_size)
         self.det_size = det_size
@@ -82,7 +91,7 @@ class Face_detect_crop:
         kps = None
         if kpss is not None:
             kps = kpss[best_index]
-        M, _ = face_align.estimate_norm(kps, crop_size, mode='None')
+        M, _ = face_align.estimate_norm(kps, crop_size, mode=self.mode)
         align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)

         return [align_img], [M]
159 insightface_func/utils/face_align_ffhqandnewarc.py (new file)
@@ -0,0 +1,159 @@
'''
Author: Naiyuan liu
Github: https://github.com/NNNNAI
Date: 2021-11-15 19:42:42
LastEditors: Naiyuan liu
LastEditTime: 2021-11-15 20:01:47
Description:
'''

import cv2
import numpy as np
from skimage import transform as trans

src1 = np.array([[51.642, 50.115], [57.617, 49.990], [35.740, 69.007],
                 [51.157, 89.050], [57.025, 89.702]],
                dtype=np.float32)
# <--left
src2 = np.array([[45.031, 50.118], [65.568, 50.872], [39.677, 68.111],
                 [45.177, 86.190], [64.246, 86.758]],
                dtype=np.float32)

# ---frontal
src3 = np.array([[39.730, 51.138], [72.270, 51.138], [56.000, 68.493],
                 [42.463, 87.010], [69.537, 87.010]],
                dtype=np.float32)

# -->right
src4 = np.array([[46.845, 50.872], [67.382, 50.118], [72.737, 68.111],
                 [48.167, 86.758], [67.236, 86.190]],
                dtype=np.float32)

# -->right profile
src5 = np.array([[54.796, 49.990], [60.771, 50.115], [76.673, 69.007],
                 [55.388, 89.702], [61.257, 89.050]],
                dtype=np.float32)

src = np.array([src1, src2, src3, src4, src5])
src_map = src

ffhq_src = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                     [201.26117, 371.41043], [313.08905, 371.15118]])
ffhq_src = np.expand_dims(ffhq_src, axis=0)

# arcface_src = np.array(
#     [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
#      [41.5493, 92.3655], [70.7299, 92.2041]],
#     dtype=np.float32)

# arcface_src = np.expand_dims(arcface_src, axis=0)

# In[66]:


# lmk is prediction; src is template
def estimate_norm(lmk, image_size=112, mode='ffhq'):
    assert lmk.shape == (5, 2)
    tform = trans.SimilarityTransform()
    lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1)
    min_M = []
    min_index = []
    min_error = float('inf')
    if mode == 'ffhq':
        # assert image_size == 112
        src = ffhq_src * image_size / 512
    else:
        src = src_map * image_size / 112
    for i in np.arange(src.shape[0]):
        tform.estimate(lmk, src[i])
        M = tform.params[0:2, :]
        results = np.dot(M, lmk_tran.T)
        results = results.T
        error = np.sum(np.sqrt(np.sum((results - src[i])**2, axis=1)))
        # print(error)
        if error < min_error:
            min_error = error
            min_M = M
            min_index = i
    return min_M, min_index


def norm_crop(img, landmark, image_size=112, mode='ffhq'):
    if mode == 'Both':
        M_None, _ = estimate_norm(landmark, image_size, mode='newarc')
        M_ffhq, _ = estimate_norm(landmark, image_size, mode='ffhq')
        warped_None = cv2.warpAffine(img, M_None, (image_size, image_size), borderValue=0.0)
        warped_ffhq = cv2.warpAffine(img, M_ffhq, (image_size, image_size), borderValue=0.0)
        return warped_ffhq, warped_None
    else:
        M, pose_index = estimate_norm(landmark, image_size, mode)
        warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
        return warped


def square_crop(im, S):
    if im.shape[0] > im.shape[1]:
        height = S
        width = int(float(im.shape[1]) / im.shape[0] * S)
        scale = float(S) / im.shape[0]
    else:
        width = S
        height = int(float(im.shape[0]) / im.shape[1] * S)
        scale = float(S) / im.shape[1]
    resized_im = cv2.resize(im, (width, height))
    det_im = np.zeros((S, S, 3), dtype=np.uint8)
    det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
    return det_im, scale


def transform(data, center, output_size, scale, rotation):
    scale_ratio = scale
    rot = float(rotation) * np.pi / 180.0
    # translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
    t1 = trans.SimilarityTransform(scale=scale_ratio)
    cx = center[0] * scale_ratio
    cy = center[1] * scale_ratio
    t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
    t3 = trans.SimilarityTransform(rotation=rot)
    t4 = trans.SimilarityTransform(translation=(output_size / 2,
                                                output_size / 2))
    t = t1 + t2 + t3 + t4
    M = t.params[0:2]
    cropped = cv2.warpAffine(data,
                             M, (output_size, output_size),
                             borderValue=0.0)
    return cropped, M


def trans_points2d(pts, M):
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        # print('new_pt', new_pt.shape, new_pt)
        new_pts[i] = new_pt[0:2]

    return new_pts


def trans_points3d(pts, M):
    scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
    # print(scale)
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        # print('new_pt', new_pt.shape, new_pt)
        new_pts[i][0:2] = new_pt[0:2]
        new_pts[i][2] = pts[i][2] * scale

    return new_pts


def trans_points(pts, M):
    if pts.shape[1] == 2:
        return trans_points2d(pts, M)
    else:
        return trans_points3d(pts, M)
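`estimate_norm` above tries every landmark template, fits a similarity transform to each, and keeps the one with the smallest re-projection error. The same selection idea can be sketched self-contained in numpy, using the closed-form (Umeyama) similarity fit in place of skimage's `SimilarityTransform` — illustrative only, not the repo's code:

```python
import numpy as np

def fit_similarity(src, dst):
    # Closed-form (Umeyama) least-squares similarity transform mapping
    # src points onto dst points; returns a 2x3 affine matrix.
    src_mean, dst_mean = src.mean(0), dst.mean(0)
    src_c, dst_c = src - src_mean, dst - dst_mean
    cov = dst_c.T @ src_c / len(src)
    U, S, Vt = np.linalg.svd(cov)
    d = np.sign(np.linalg.det(U) * np.linalg.det(Vt))
    D = np.diag([1.0, d])          # reflection guard
    R = U @ D @ Vt
    scale = np.trace(np.diag(S) @ D) / src_c.var(0).sum()
    t = dst_mean - scale * R @ src_mean
    return np.hstack([scale * R, t[:, None]])

def best_template(lmk, templates):
    # Mirror of estimate_norm's loop: fit each template and keep the
    # transform whose re-projected landmarks have the smallest error.
    lmk_h = np.hstack([lmk, np.ones((len(lmk), 1))])
    err, idx = min(
        (np.linalg.norm(lmk_h @ fit_similarity(lmk, t).T - t, axis=1).sum(), i)
        for i, t in enumerate(templates)
    )
    return fit_similarity(lmk, templates[idx]), idx
```

Picking the minimum-error template is what makes the alignment pose-aware: a profile face lands on the profile template rather than being forced onto the frontal one.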
@@ -37,15 +37,21 @@ class BaseModel(torch.nn.Module):

     def save(self, label):
         pass

     # helper saving function that can be used by subclasses
-    def save_network(self, network, network_label, epoch_label, gpu_ids):
-        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+    def save_network(self, network, network_label, epoch_label, gpu_ids=None):
+        save_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
         save_path = os.path.join(self.save_dir, save_filename)
         torch.save(network.cpu().state_dict(), save_path)
-        if len(gpu_ids) and torch.cuda.is_available():
+        if torch.cuda.is_available():
             network.cuda()

+    def save_optim(self, network, network_label, epoch_label, gpu_ids=None):
+        save_filename = '{}_optim_{}.pth'.format(epoch_label, network_label)
+        save_path = os.path.join(self.save_dir, save_filename)
+        torch.save(network.state_dict(), save_path)
+
     # helper loading function that can be used by subclasses
     def load_network(self, network, network_label, epoch_label, save_dir=''):
         save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
@@ -63,6 +69,47 @@ class BaseModel(torch.nn.Module):
             except:
                 pretrained_dict = torch.load(save_path)
                 model_dict = network.state_dict()
                 try:
                     pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                     network.load_state_dict(pretrained_dict)
                     if self.opt.verbose:
                         print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
                 except:
                     print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                     for k, v in pretrained_dict.items():
                         if v.size() == model_dict[k].size():
                             model_dict[k] = v

                     if sys.version_info >= (3, 0):
                         not_initialized = set()
                     else:
                         from sets import Set
                         not_initialized = Set()

                     for k, v in model_dict.items():
                         if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                             not_initialized.add(k.split('.')[0])

                     print(sorted(not_initialized))
                     network.load_state_dict(model_dict)

+    # helper loading function that can be used by subclasses
+    def load_optim(self, network, network_label, epoch_label, save_dir=''):
+        save_filename = '%s_optim_%s.pth' % (epoch_label, network_label)
+        if not save_dir:
+            save_dir = self.save_dir
+        save_path = os.path.join(save_dir, save_filename)
+        if not os.path.isfile(save_path):
+            print('%s not exists yet!' % save_path)
+            if network_label == 'G':
+                raise('Generator must exist!')
+        else:
+            # network.load_state_dict(torch.load(save_path))
+            try:
+                network.load_state_dict(torch.load(save_path, map_location=torch.device("cpu")))
+            except:
+                pretrained_dict = torch.load(save_path, map_location=torch.device("cpu"))
+                model_dict = network.state_dict()
+                try:
+                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+                    network.load_state_dict(pretrained_dict)
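The `save_optim`/`load_optim` helpers added above let optimizer state be checkpointed alongside the networks, so training can resume with momentum and learning-rate state intact. The naming-convention pattern can be sketched with stdlib pickle standing in for `torch.save`/`torch.load` (purely illustrative):

```python
import os
import pickle

def save_optim(state_dict, epoch_label, network_label, save_dir):
    # Same filename convention as the diff: '{epoch}_optim_{net}.pth'.
    path = os.path.join(save_dir, '{}_optim_{}.pth'.format(epoch_label, network_label))
    with open(path, 'wb') as f:
        pickle.dump(state_dict, f)
    return path

def load_optim(epoch_label, network_label, save_dir):
    path = os.path.join(save_dir, '{}_optim_{}.pth'.format(epoch_label, network_label))
    if not os.path.isfile(path):
        raise FileNotFoundError('%s not exists yet!' % path)
    with open(path, 'rb') as f:
        return pickle.load(f)
```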
@@ -4,10 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import os
 from torch.autograd import Variable
 from util.image_pool import ImagePool
 from .base_model import BaseModel
 from . import networks
-from .fs_networks import Generator_Adain_Upsample, Discriminator

 class SpecificNorm(nn.Module):
     def __init__(self, epsilon=1e-8):
@@ -52,6 +50,11 @@ class fsModel(BaseModel):

         device = torch.device("cuda:0")

+        if opt.crop_size == 224:
+            from .fs_networks import Generator_Adain_Upsample, Discriminator
+        elif opt.crop_size == 512:
+            from .fs_networks_512 import Generator_Adain_Upsample, Discriminator
+
         # Generator network
         self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False)
         self.netG.to(device)
@@ -197,7 +200,7 @@ class fsModel(BaseModel):

         #G_ID
-        img_fake_down = F.interpolate(img_fake, scale_factor=0.5)
+        img_fake_down = F.interpolate(img_fake, size=(112,112))
         img_fake_down = self.spNorm(img_fake_down)
         latent_fake = self.netArc(img_fake_down)
         loss_G_ID = (1 - self.cosin_metric(latent_fake, latent_id))
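The G_ID term in the diff above penalizes identity drift: the fake image is resized to ArcFace's expected 112x112 input (which is why `scale_factor=0.5` was replaced with an explicit `size=(112,112)` — it must hold for any crop size), embedded, and compared to the source identity by cosine similarity. A numpy sketch of the loss itself:

```python
import numpy as np

def id_loss(latent_fake, latent_id):
    # 1 - cosine similarity between the ArcFace embedding of the swapped
    # image and the source identity embedding; 0 when the embeddings point
    # in the same direction, up to 2 when they are opposite.
    cos = (latent_fake * latent_id).sum(-1) / (
        np.linalg.norm(latent_fake, axis=-1) * np.linalg.norm(latent_id, axis=-1))
    return 1.0 - cos
```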
232 models/fs_networks_512.py (new file)
@@ -0,0 +1,232 @@
'''
Author: Naiyuan liu
Github: https://github.com/NNNNAI
Date: 2021-11-23 16:55:48
LastEditors: Naiyuan liu
LastEditTime: 2021-11-24 16:58:06
Description:
'''
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn


class InstanceNorm(nn.Module):
    def __init__(self, epsilon=1e-8):
        """
        @notice: avoid in-place ops.
        https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super(InstanceNorm, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        x = x - torch.mean(x, (2, 3), True)
        tmp = torch.mul(x, x)  # or x ** 2
        tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
        return x * tmp


class ApplyStyle(nn.Module):
    """
    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """
    def __init__(self, latent_size, channels):
        super(ApplyStyle, self).__init__()
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        style = self.linear(latent)  # style => [batch_size, n_channels*2]
        shape = [-1, 2, x.size(1), 1, 1]
        style = style.view(shape)    # [batch_size, 2, n_channels, ...]
        # x = x * (style[:, 0] + 1.) + style[:, 1]
        x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
        return x


class ResnetBlock_Adain(nn.Module):
    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super(ResnetBlock_Adain, self).__init__()

        p = 0
        conv1 = []
        if padding_type == 'reflect':
            conv1 += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv1 += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
        self.conv1 = nn.Sequential(*conv1)
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation

        p = 0
        conv2 = []
        if padding_type == 'reflect':
            conv2 += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv2 += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
        self.conv2 = nn.Sequential(*conv2)
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        y = self.conv1(x)
        y = self.style1(y, dlatents_in_slice)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.style2(y, dlatents_in_slice)
        out = x + y
        return out


class Generator_Adain_Upsample(nn.Module):
    def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert (n_blocks >= 0)
        super(Generator_Adain_Upsample, self).__init__()
        activation = nn.ReLU(True)
        self.deep = deep

        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 32, kernel_size=7, padding=0),
                                         norm_layer(32), activation)
        ### downsample
        self.down0 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
                                   norm_layer(64), activation)
        self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                                   norm_layer(128), activation)
        self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                                   norm_layer(256), activation)
        self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
                                   norm_layer(512), activation)
        if self.deep:
            self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                                       norm_layer(512), activation)

        ### resnet blocks
        BN = []
        for i in range(n_blocks):
            BN += [
                ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)]
        self.BottleNeck = nn.Sequential(*BN)

        if self.deep:
            self.up4 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(512), activation
            )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256), activation
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128), activation
        )
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64), activation
        )
        self.up0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32), activation
        )
        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(32, output_nc, kernel_size=7, padding=0),
                                        nn.Tanh())

    def forward(self, input, dlatents):
        x = input  # 3*224*224

        skip0 = self.first_layer(x)
        skip1 = self.down0(skip0)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        if self.deep:
            skip4 = self.down3(skip3)
            x = self.down4(skip4)
        else:
            x = self.down3(skip3)

        for i in range(len(self.BottleNeck)):
            x = self.BottleNeck[i](x, dlatents)

        if self.deep:
            x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        x = self.up0(x)
        x = self.last_layer(x)
        x = (x + 1) / 2

        return x


class Discriminator(nn.Module):
    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(Discriminator, self).__init__()

        kw = 4
        padw = 1
        self.down1 = nn.Sequential(
            nn.Conv2d(input_nc, 64, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=kw, stride=2, padding=padw),
            norm_layer(128), nn.LeakyReLU(0.2, True)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=kw, stride=2, padding=padw),
            norm_layer(256), nn.LeakyReLU(0.2, True)
        )
        self.down4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=kw, stride=2, padding=padw),
            norm_layer(512), nn.LeakyReLU(0.2, True)
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=kw, stride=1, padding=padw),
            norm_layer(512),
            nn.LeakyReLU(0.2, True)
        )

        if use_sigmoid:
            self.conv2 = nn.Sequential(
                nn.Conv2d(512, 1, kernel_size=kw, stride=1, padding=padw), nn.Sigmoid()
            )
        else:
            self.conv2 = nn.Sequential(
                nn.Conv2d(512, 1, kernel_size=kw, stride=1, padding=padw)
            )

    def forward(self, input):
        out = []
        x = self.down1(input)
        out.append(x)
        x = self.down2(x)
        out.append(x)
        x = self.down3(x)
        out.append(x)
        x = self.down4(x)
        out.append(x)
        x = self.conv1(x)
        out.append(x)
        x = self.conv2(x)
        out.append(x)

        return out
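The `InstanceNorm` module in the file above normalizes each feature map per sample and per channel over its spatial dimensions, with no learnable affine parameters. An equivalent numpy sketch of its `forward`:

```python
import numpy as np

def instance_norm(x, eps=1e-8):
    # x: (N, C, H, W). Subtract the per-(sample, channel) spatial mean,
    # then divide by the spatial RMS of the centered values, matching the
    # module's subtract-mean / multiply-by-rsqrt sequence.
    x = x - x.mean(axis=(2, 3), keepdims=True)
    return x / np.sqrt((x * x).mean(axis=(2, 3), keepdims=True) + eps)
```

Normalizing away per-channel statistics is what lets the following `ApplyStyle` layers re-inject identity-specific statistics cleanly.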
172
models/fs_networks_fix.py
Normal file
172
models/fs_networks_fix.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
|
||||
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class InstanceNorm(nn.Module):
|
||||
def __init__(self, epsilon=1e-8):
|
||||
"""
|
||||
@notice: avoid in-place ops.
|
||||
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
|
||||
"""
|
||||
super(InstanceNorm, self).__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, x):
|
||||
x = x - torch.mean(x, (2, 3), True)
|
||||
tmp = torch.mul(x, x) # or x ** 2
|
||||
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
|
||||
return x * tmp
|
||||
|
||||
class ApplyStyle(nn.Module):
|
||||
"""
|
||||
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
|
||||
"""
|
||||
def __init__(self, latent_size, channels):
|
||||
super(ApplyStyle, self).__init__()
|
||||
self.linear = nn.Linear(latent_size, channels * 2)
|
||||
|
||||
def forward(self, x, latent):
|
||||
style = self.linear(latent) # style => [batch_size, n_channels*2]
|
||||
shape = [-1, 2, x.size(1), 1, 1]
|
||||
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
||||
#x = x * (style[:, 0] + 1.) + style[:, 1]
|
||||
x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
|
||||
return x
|
||||
|
||||
class ResnetBlock_Adain(nn.Module):
|
||||
def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
|
||||
super(ResnetBlock_Adain, self).__init__()
|
||||
|
||||
p = 0
|
||||
conv1 = []
|
||||
if padding_type == 'reflect':
|
||||
conv1 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv1 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
|
||||
self.conv1 = nn.Sequential(*conv1)
|
||||
self.style1 = ApplyStyle(latent_size, dim)
|
||||
self.act1 = activation
|
||||
|
||||
p = 0
|
||||
conv2 = []
|
||||
if padding_type == 'reflect':
|
||||
conv2 += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv2 += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
||||
self.conv2 = nn.Sequential(*conv2)
|
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        y = self.conv1(x)
        y = self.style1(y, dlatents_in_slice)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.style2(y, dlatents_in_slice)
        out = x + y
        return out


class Generator_Adain_Upsample(nn.Module):
    def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert (n_blocks >= 0)
        super(Generator_Adain_Upsample, self).__init__()

        activation = nn.ReLU(True)

        self.deep = deep

        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
                                         norm_layer(64), activation)
        ### downsample
        self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                                   norm_layer(128), activation)
        self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                                   norm_layer(256), activation)
        self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
                                   norm_layer(512), activation)

        if self.deep:
            self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                                       norm_layer(512), activation)

        ### resnet blocks
        BN = []
        for i in range(n_blocks):
            BN += [
                ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)]
        self.BottleNeck = nn.Sequential(*BN)

        if self.deep:
            self.up4 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(512), activation
            )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256), activation
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128), activation
        )
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64), activation
        )
        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0))

    def forward(self, input, dlatents):
        x = input  # 3*224*224

        skip1 = self.first_layer(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        if self.deep:
            skip4 = self.down3(skip3)
            x = self.down4(skip4)
        else:
            x = self.down3(skip3)
        bot = []
        bot.append(x)
        features = []
        for i in range(len(self.BottleNeck)):
            x = self.BottleNeck[i](x, dlatents)
            bot.append(x)

        if self.deep:
            x = self.up4(x)
            features.append(x)
        x = self.up3(x)
        features.append(x)
        x = self.up2(x)
        features.append(x)
        x = self.up1(x)
        features.append(x)
        x = self.last_layer(x)
        # x = (x + 1) / 2

        # return x, bot, features, dlatents
        return x
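The `ResnetBlock_Adain` above injects the source identity via AdaIN-style modulation: each feature map is instance-normalized, then scaled and shifted by parameters derived from the identity latent. A minimal stdlib sketch of that modulation on a 1-D "feature map" (the helper names `instance_norm`/`adain` are illustrative, not from the repo):

```python
import math

def instance_norm(xs, eps=1e-5):
    # normalize to zero mean, unit variance (per-channel statistics in the real model)
    mean = sum(xs) / len(xs)
    var = sum((v - mean) ** 2 for v in xs) / len(xs)
    return [(v - mean) / math.sqrt(var + eps) for v in xs]

def adain(xs, gamma, beta):
    # scale/shift the normalized features with latent-derived parameters
    return [gamma * v + beta for v in instance_norm(xs)]

feats = adain([1.0, 2.0, 3.0], gamma=2.0, beta=0.5)
```

In the real block, `gamma` and `beta` come from a linear layer applied to the 512-d identity embedding (`ApplyStyle`), so the same content features are re-styled per source face.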
122 models/projected_model.py Normal file
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: fs_model_fix_idnorm_donggp_saveoptim copy.py
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 8:13:37 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################


import torch
import torch.nn as nn

from .base_model import BaseModel
from .fs_networks_fix import Generator_Adain_Upsample

from pg_modules.projected_discriminator import ProjectedDiscriminator


def compute_grad2(d_out, x_in):
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert (grad_dout2.size() == x_in.size())
    reg = grad_dout2.view(batch_size, -1).sum(1)
    return reg


class fsModel(BaseModel):
    def name(self):
        return 'fsModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
        self.isTrain = opt.isTrain

        # Generator network
        self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep)
        self.netG.cuda()

        # Id network
        netArc_checkpoint = opt.Arc_path
        netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
        self.netArc = netArc_checkpoint['model'].module
        self.netArc = self.netArc.cuda()
        self.netArc.eval()
        self.netArc.requires_grad_(False)
        if not self.isTrain:
            pretrained_path = opt.checkpoints_dir
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            return
        self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
        # self.netD.feature_network.requires_grad_(False)
        self.netD.cuda()

        if self.isTrain:
            # define loss functions
            self.criterionFeat = nn.L1Loss()
            self.criterionRec = nn.L1Loss()

            # initialize optimizers

            # optimizer G
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # load networks
            if opt.continue_train:
                pretrained_path = '' if not self.isTrain else opt.load_pretrain
                # print(pretrained_path)
                self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
                self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
                self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
        torch.cuda.empty_cache()

    def cosin_metric(self, x1, x2):
        # return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
        return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch)
        self.save_network(self.netD, 'D', which_epoch)
        self.save_optim(self.optimizer_G, 'G', which_epoch)
        self.save_optim(self.optimizer_D, 'D', which_epoch)
        '''if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
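`cosin_metric` above computes a row-wise cosine similarity between ArcFace identity embeddings; an identity loss is then typically `1 - similarity`. The same formula for a single pair of vectors, as a stdlib sketch:

```python
import math

def cosine_similarity(x1, x2):
    # dot(x1, x2) / (||x1|| * ||x2||)
    dot = sum(a * b for a, b in zip(x1, x2))
    n1 = math.sqrt(sum(a * a for a in x1))
    n2 = math.sqrt(sum(b * b for b in x2))
    return dot / (n1 * n2)
```

The torch version in `cosin_metric` does exactly this per batch row, which is why it sums over `dim=1` and divides by per-row norms.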
14 models/projectionhead.py Normal file
@@ -0,0 +1,14 @@
import torch.nn as nn

class ProjectionHead(nn.Module):
    def __init__(self, proj_dim=256):
        super(ProjectionHead, self).__init__()

        self.proj = nn.Sequential(
            nn.Linear(proj_dim, proj_dim),
            nn.ReLU(),
            nn.Linear(proj_dim, proj_dim),
        )

    def forward(self, x):
        return self.proj(x)
@@ -1,3 +1,11 @@
'''
Author: Naiyuan liu
Github: https://github.com/NNNNAI
Date: 2021-11-23 17:03:58
LastEditors: Naiyuan liu
LastEditTime: 2021-11-23 17:08:08
Description:
'''
from .base_options import BaseOptions


class TestOptions(BaseOptions):
@@ -25,6 +33,7 @@ class TestOptions(BaseOptions):
        self.parser.add_argument('--id_thres', type=float, default=0.03, help='identity similarity threshold')
        self.parser.add_argument('--no_simswaplogo', action='store_true', help='Remove the watermark')
        self.parser.add_argument('--use_mask', action='store_true', help='Use mask for better result')
        self.parser.add_argument('--crop_size', type=int, default=224, help='Crop size of the input image')
        self.parser.add_argument('--skip_existing_frames', action='store_true', help='Skip a frame index if it already exists in temp_path')

        self.isTrain = False
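The `action='store_true'` flags above default to `False` and flip to `True` only when passed on the command line. A self-contained `argparse` sketch of that behavior, reusing two of the option names from the diff:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no_simswaplogo', action='store_true', help='Remove the watermark')
parser.add_argument('--crop_size', type=int, default=224, help='Crop size of the input image')

# parse an explicit argv list instead of sys.argv, so this runs anywhere
opt = parser.parse_args(['--no_simswaplogo'])
```

Here `opt.no_simswaplogo` is `True` because the flag was passed, while `opt.crop_size` keeps its default of 224.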
325 pg_modules/blocks.py Normal file
@@ -0,0 +1,325 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


### single layers


def conv2d(*args, **kwargs):
    return spectral_norm(nn.Conv2d(*args, **kwargs))


def convTranspose2d(*args, **kwargs):
    return spectral_norm(nn.ConvTranspose2d(*args, **kwargs))


def embedding(*args, **kwargs):
    return spectral_norm(nn.Embedding(*args, **kwargs))


def linear(*args, **kwargs):
    return spectral_norm(nn.Linear(*args, **kwargs))


def NormLayer(c, mode='batch'):
    if mode == 'group':
        return nn.GroupNorm(c // 2, c)
    elif mode == 'batch':
        return nn.BatchNorm2d(c)


### Activations


class GLU(nn.Module):
    def forward(self, x):
        nc = x.size(1)
        assert nc % 2 == 0, "channels don't divide by 2!"
        nc = int(nc / 2)
        return x[:, :nc] * torch.sigmoid(x[:, nc:])


class Swish(nn.Module):
    def forward(self, feat):
        return feat * torch.sigmoid(feat)


### Upblocks


class InitLayer(nn.Module):
    def __init__(self, nz, channel, sz=4):
        super().__init__()

        self.init = nn.Sequential(
            convTranspose2d(nz, channel * 2, sz, 1, 0, bias=False),
            NormLayer(channel * 2),
            GLU(),
        )

    def forward(self, noise):
        noise = noise.view(noise.shape[0], -1, 1, 1)
        return self.init(noise)


def UpBlockSmall(in_planes, out_planes):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes * 2, 3, 1, 1, bias=False),
        NormLayer(out_planes * 2), GLU())
    return block


class UpBlockSmallCond(nn.Module):
    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = conv2d(in_planes, out_planes * 2, 3, 1, 1, bias=False)

        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
        self.bn = which_bn(2 * out_planes)
        self.act = GLU()

    def forward(self, x, c):
        x = self.up(x)
        x = self.conv(x)
        x = self.bn(x, c)
        x = self.act(x)
        return x


def UpBlockBig(in_planes, out_planes):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes * 2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes * 2), GLU(),
        conv2d(out_planes, out_planes * 2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes * 2), GLU()
    )
    return block


class UpBlockBigCond(nn.Module):
    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = conv2d(in_planes, out_planes * 2, 3, 1, 1, bias=False)
        self.conv2 = conv2d(out_planes, out_planes * 2, 3, 1, 1, bias=False)

        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
        self.bn1 = which_bn(2 * out_planes)
        self.bn2 = which_bn(2 * out_planes)
        self.act = GLU()
        self.noise = NoiseInjection()

    def forward(self, x, c):
        # block 1
        x = self.up(x)
        x = self.conv1(x)
        x = self.noise(x)
        x = self.bn1(x, c)
        x = self.act(x)

        # block 2
        x = self.conv2(x)
        x = self.noise(x)
        x = self.bn2(x, c)
        x = self.act(x)

        return x


class SEBlock(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.main = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            conv2d(ch_in, ch_out, 4, 1, 0, bias=False),
            Swish(),
            conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, feat_small, feat_big):
        return feat_big * self.main(feat_small)


### Downblocks


class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=False):
        super(SeparableConv2d, self).__init__()
        self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                groups=in_channels, bias=bias, padding=1)
        self.pointwise = conv2d(in_channels, out_channels,
                                kernel_size=1, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class DownBlock(nn.Module):
    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        if not separable:
            self.main = nn.Sequential(
                conv2d(in_planes, out_planes, 4, 2, 1),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            self.main = nn.Sequential(
                SeparableConv2d(in_planes, out_planes, 3),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
                nn.AvgPool2d(2, 2),
            )

    def forward(self, feat):
        return self.main(feat)


class DownBlockPatch(nn.Module):
    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        self.main = nn.Sequential(
            DownBlock(in_planes, out_planes, separable),
            conv2d(out_planes, out_planes, 1, 1, 0, bias=False),
            NormLayer(out_planes),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, feat):
        return self.main(feat)


### CSM


class ResidualConvUnit(nn.Module):
    def __init__(self, cin, activation, bn):
        super().__init__()
        self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        return self.skip_add.add(self.conv(x), x)


class FeatureFusionBlock(nn.Module):
    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        output = xs[0]

        if len(xs) == 2:
            output = self.skip_add.add(output, xs[1])

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output


### Misc


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, feat, noise=None):
        if noise is None:
            batch, _, height, width = feat.shape
            noise = torch.randn(batch, 1, height, width).to(feat.device)

        return feat + self.weight * noise


class CCBN(nn.Module):
    ''' conditional batchnorm '''
    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1):
        super().__init__()
        self.output_size, self.input_size = output_size, input_size

        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)

        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum

        self.register_buffer('stored_mean', torch.zeros(output_size))
        self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, 0.1, self.eps)
        return out * gain + bias


class Interpolate(nn.Module):
    """Interpolation module."""

    def __init__(self, size, mode='bilinear', align_corners=False):
        """Init.
        Args:
            size (int or tuple): output spatial size
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x,
            size=self.size,
            mode=self.mode,
            align_corners=self.align_corners,
        )

        return x
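`GLU` halves the channel dimension: the first half of the channels is gated by a sigmoid of the second half, which is why every block feeding a GLU produces `out_planes*2` channels. The gating in plain Python over a flat channel list:

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def glu(xs):
    # first half gated by sigmoid of second half; output has half the channels
    assert len(xs) % 2 == 0, "channel count must be even"
    nc = len(xs) // 2
    return [a * sigmoid(b) for a, b in zip(xs[:nc], xs[nc:])]
```

With gate inputs of 0, `sigmoid(0) = 0.5`, so `glu([1.0, 2.0, 0.0, 0.0])` yields `[0.5, 1.0]`.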
76 pg_modules/diffaug.py Normal file
@@ -0,0 +1,76 @@
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F


def DiffAugment(x, policy='', channels_first=True):
    if policy:
        if not channels_first:
            x = x.permute(0, 3, 1, 2)
        for p in policy.split(','):
            for f in AUGMENT_FNS[p]:
                x = f(x)
        if not channels_first:
            x = x.permute(0, 2, 3, 1)
        x = x.contiguous()
    return x


def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x


def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x


def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x


def rand_translation(x, ratio=0.125):
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.2):
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x


AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
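`DiffAugment` applies every function registered for each comma-separated policy name, in order, to the image batch. The dispatch logic in isolation, with scalar toy functions standing in for the tensor augmentations:

```python
# toy augmentations standing in for the tensor ops above
AUG_FNS = {
    'color': [lambda x: x + 1.0],
    'cutout': [lambda x: x * 2.0],
}

def diff_augment(x, policy=''):
    # empty policy means identity; otherwise apply each named group in order
    if policy:
        for p in policy.split(','):
            for f in AUG_FNS[p]:
                x = f(x)
    return x
```

Note the order matters: `'color,cutout'` maps 3.0 to (3+1)*2 = 8.0, while `'cutout,color'` maps it to 3*2+1 = 7.0.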
191 pg_modules/projected_discriminator.py Normal file
@@ -0,0 +1,191 @@
from functools import partial
import numpy as np
import torch
import torch.nn as nn

from pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
from pg_modules.projector import F_RandomProj
from pg_modules.diffaug import DiffAugment


class SingleDisc(nn.Module):
    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False):
        super().__init__()
        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # interpolate for start sz that are not powers of two
        if start_sz not in channel_dict.keys():
            sizes = np.array(list(channel_dict.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if given ndf, allocate all layers with the same ndf
        if ndf is None:
            nfc = channel_dict
        else:
            nfc = {k: ndf for k, v in channel_dict.items()}

        # for feature map discriminators with nfc not in channel_dict
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # Head if the initial input is the full modality
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # Down Blocks
        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2

        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        return self.main(x)


class SingleDiscCond(nn.Module):
    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128):
        super().__init__()
        self.cmap_dim = cmap_dim

        # midas channels
        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # interpolate for start sz that are not powers of two
        if start_sz not in channel_dict.keys():
            sizes = np.array(list(channel_dict.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if given ndf, allocate all layers with the same ndf
        if ndf is None:
            nfc = channel_dict
        else:
            nfc = {k: ndf for k, v in channel_dict.items()}

        # for feature map discriminators with nfc not in channel_dict
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # Head if the initial input is the full modality
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # Down Blocks
        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2
        self.main = nn.Sequential(*layers)

        # additions for conditioning on class information
        self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False)
        self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim)
        self.embed_proj = nn.Sequential(
            nn.Linear(self.embed.embedding_dim, self.cmap_dim),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x, c):
        h = self.main(x)
        out = self.cls(h)

        # conditioning via projection
        cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1)
        out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        return out


class MultiScaleD(nn.Module):
    def __init__(
        self,
        channels,
        resolutions,
        num_discs=4,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=0,
        separable=False,
        patch=False,
        **kwargs,
    ):
        super().__init__()

        assert num_discs in [1, 2, 3, 4]

        # the first disc is on the lowest level of the backbone
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]
        Disc = SingleDiscCond if cond else SingleDisc

        mini_discs = []
        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            start_sz = res if not patch else 16
            mini_discs += [str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)],
        self.mini_discs = nn.ModuleDict(mini_discs)

    def forward(self, features, c):
        all_logits = []
        for k, disc in self.mini_discs.items():
            res = disc(features[k], c).view(features[k].size(0), -1)
            all_logits.append(res)

        all_logits = torch.cat(all_logits, dim=1)
        return all_logits


class ProjectedDiscriminator(torch.nn.Module):
    def __init__(
        self,
        diffaug=True,
        interp224=True,
        backbone_kwargs={},
        **kwargs
    ):
        super().__init__()
        self.diffaug = diffaug
        self.interp224 = interp224
        self.feature_network = F_RandomProj(**backbone_kwargs)
        self.discriminator = MultiScaleD(
            channels=self.feature_network.CHANNELS,
            resolutions=self.feature_network.RESOLUTIONS,
            **backbone_kwargs,
        )

    def train(self, mode=True):
        # the pretrained feature network always stays in eval mode
        self.feature_network = self.feature_network.train(False)
        self.discriminator = self.discriminator.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def get_feature(self, x):
        features = self.feature_network(x, get_features=True)
        return features

    def forward(self, x, c):
        # if self.diffaug:
        #     x = DiffAugment(x, policy='color,translation,cutout')

        # if self.interp224:
        #     x = F.interpolate(x, 224, mode='bilinear', align_corners=False)

        features, backbone_features = self.feature_network(x)
        logits = self.discriminator(features, c)

        return logits, backbone_features
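`SingleDisc` first snaps a non-power-of-two `start_sz` to the nearest key of `channel_dict`, then halves the resolution until `end_sz`, adding one down block per halving. The same arithmetic isolated as stdlib helpers (names are illustrative):

```python
KNOWN_SIZES = (4, 8, 16, 32, 64, 128, 256, 512, 1024)

def snap_size(start_sz):
    # nearest known resolution, mirroring SingleDisc's argmin over |sizes - start_sz|
    return min(KNOWN_SIZES, key=lambda s: abs(s - start_sz))

def num_down_blocks(start_sz, end_sz=8):
    # one DownBlock per halving until the feature map reaches end_sz
    n = 0
    while start_sz > end_sz:
        start_sz //= 2
        n += 1
    return n
```

So a 224-pixel input is treated as 256 and gets five down blocks (256 → 128 → 64 → 32 → 16 → 8) before the final 4x4 conv that produces the logit map.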
158 pg_modules/projector.py Normal file
@@ -0,0 +1,158 @@
import torch
import torch.nn as nn
import timm
from pg_modules.blocks import FeatureFusionBlock


def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
    # shapes
    out_channels = [cout, cout * 2, cout * 4, cout * 8] if expand else [cout] * 4

    scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
    scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
    scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
    scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)

    scratch.CHANNELS = out_channels

    return scratch


def _make_scratch_csm(scratch, in_channels, cout, expand):
    scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
    scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
    scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
    scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))

    # last refinenet does not expand to save channels in higher dimensions
    scratch.CHANNELS = [cout, cout, cout * 2, cout * 4] if expand else [cout] * 4

    return scratch


def _make_efficientnet(model):
    pretrained = nn.Module()
    pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2])
    pretrained.layer1 = nn.Sequential(*model.blocks[2:3])
    pretrained.layer2 = nn.Sequential(*model.blocks[3:5])
    pretrained.layer3 = nn.Sequential(*model.blocks[5:9])
    return pretrained


def calc_channels(pretrained, inp_res=224):
    channels = []
    tmp = torch.zeros(1, 3, inp_res, inp_res)

    # forward pass
    tmp = pretrained.layer0(tmp)
    channels.append(tmp.shape[1])
    tmp = pretrained.layer1(tmp)
    channels.append(tmp.shape[1])
    tmp = pretrained.layer2(tmp)
    channels.append(tmp.shape[1])
    tmp = pretrained.layer3(tmp)
    channels.append(tmp.shape[1])

    return channels


def _make_projector(im_res, cout, proj_type, expand=False):
    assert proj_type in [0, 1, 2], "Invalid projection type"

    ### Build pretrained feature network
    model = timm.create_model('tf_efficientnet_lite0', pretrained=True)
    pretrained = _make_efficientnet(model)

    # determine resolution of feature maps, this is later used to calculate the number
    # of down blocks in the discriminators. Interestingly, the best results are achieved
    # by fixing this to 256, i.e., we use the same number of down blocks per discriminator
    # independent of the dataset resolution
    im_res = 256
    pretrained.RESOLUTIONS = [im_res // 4, im_res // 8, im_res // 16, im_res // 32]
    pretrained.CHANNELS = calc_channels(pretrained)

    if proj_type == 0: return pretrained, None

    ### Build CCM
    scratch = nn.Module()
    scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
    pretrained.CHANNELS = scratch.CHANNELS

    if proj_type == 1: return pretrained, scratch

    ### build CSM
    scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)

    # CSM upsamples x2 so the feature map resolution doubles
    pretrained.RESOLUTIONS = [res * 2 for res in pretrained.RESOLUTIONS]
    pretrained.CHANNELS = scratch.CHANNELS

    return pretrained, scratch


class F_RandomProj(nn.Module):
    def __init__(
        self,
        im_res=256,
        cout=64,
        expand=True,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        **kwargs,
    ):
        super().__init__()
        self.proj_type = proj_type
        self.cout = cout
        self.expand = expand

        # build pretrained feature network and random decoder (scratch)
        self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand)
        self.CHANNELS = self.pretrained.CHANNELS
        self.RESOLUTIONS = self.pretrained.RESOLUTIONS

    def forward(self, x, get_features=False):
        # predict feature maps
        out0 = self.pretrained.layer0(x)
        out1 = self.pretrained.layer1(out0)
        out2 = self.pretrained.layer2(out1)
        out3 = self.pretrained.layer3(out2)

        # start enumerating at the lowest layer (this is where we put the first discriminator)
        backbone_features = {
            '0': out0,
            '1': out1,
            '2': out2,
            '3': out3,
        }
        if get_features:
            return backbone_features

        if self.proj_type == 0: return backbone_features

        out0_channel_mixed = self.scratch.layer0_ccm(backbone_features['0'])
        out1_channel_mixed = self.scratch.layer1_ccm(backbone_features['1'])
        out2_channel_mixed = self.scratch.layer2_ccm(backbone_features['2'])
        out3_channel_mixed = self.scratch.layer3_ccm(backbone_features['3'])

        out = {
            '0': out0_channel_mixed,
            '1': out1_channel_mixed,
            '2': out2_channel_mixed,
            '3': out3_channel_mixed,
        }

        if self.proj_type == 1: return out

        # from bottom to top
        out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed)
        out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed)
        out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed)
        out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed)

        out = {
            '0': out0_scale_mixed,
            '1': out1_scale_mixed,
            '2': out2_scale_mixed,
            '3': out3_scale_mixed,
|
||||
}
|
||||
|
||||
return out, backbone_features
|
||||
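The resolution bookkeeping above is plain integer arithmetic; a minimal sketch (hypothetical helper name, no timm/torch required) of what `_make_projector` computes for the fixed `im_res = 256`, with and without the CSM doubling:

```python
def projector_resolutions(im_res=256, use_csm=True):
    # The four EfficientNet stages downsample the input by 4, 8, 16, and 32
    resolutions = [im_res // 4, im_res // 8, im_res // 16, im_res // 32]
    if use_csm:
        # CSM upsamples x2, so every feature-map resolution doubles
        resolutions = [r * 2 for r in resolutions]
    return resolutions

print(projector_resolutions(use_csm=False))  # [64, 32, 16, 8]
print(projector_resolutions(use_csm=True))   # [128, 64, 32, 16]
```

This is why `pretrained.RESOLUTIONS` is rewritten in place after `_make_scratch_csm`: the discriminators downstream size their down-block stacks from these numbers.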
@@ -56,7 +56,7 @@ class Predictor(cog.Predictor):
         model = create_model(opt)
         model.eval()
 
-        crop_size = 224
+        crop_size = opt.crop_size
         spNorm = SpecificNorm()
 
         with torch.no_grad():
@@ -71,7 +71,7 @@ class Predictor(cog.Predictor):
             img_id = img_id.cuda()
 
             # create latent id
-            img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+            img_id_downsample = F.interpolate(img_id, size=(112,112))
             latend_id = model.netArc(img_id_downsample)
             latend_id = F.normalize(latend_id, p=2, dim=1)
 
@@ -53,7 +53,7 @@ if __name__ == '__main__':
     img_att = img_att.cuda()
 
     #create latent id
-    img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+    img_id_downsample = F.interpolate(img_id, size=(112,112))
     latend_id = model.netArc(img_id_downsample)
     latend_id = latend_id.detach().to('cpu')
     latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True)
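The `scale_factor=0.5` → `size=(112,112)` change in these hunks matters once `crop_size` is configurable: the ArcFace identity network expects a 112x112 input, and halving the crop only yields 112 for a 224 crop. A small arithmetic sketch (hypothetical helper, for illustration only):

```python
def arcface_input_side(crop_size, fixed_size=True):
    # old behaviour: halve the crop, which is correct only for crop_size == 224;
    # new behaviour: always resize to ArcFace's expected 112
    return 112 if fixed_size else crop_size // 2

print(arcface_input_side(224, fixed_size=False))  # 112 -- happened to work
print(arcface_input_side(512, fixed_size=False))  # 256 -- wrong for ArcFace
print(arcface_input_side(512, fixed_size=True))   # 112
```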
@@ -2,7 +2,6 @@
 import cv2
 import torch
 import fractions
 import numpy as np
 from PIL import Image
 import torch.nn.functional as F
 from torchvision import transforms
@@ -35,16 +34,22 @@ if __name__ == '__main__':
     opt = TestOptions().parse()
     pic_specific = opt.pic_specific_path
     start_epoch, epoch_iter = 1, 0
-    crop_size = 224
+    crop_size = opt.crop_size
 
     multisepcific_dir = opt.multisepcific_dir
     torch.nn.Module.dump_patches = True
+    if crop_size == 512:
+        opt.which_epoch = 550000
+        opt.name = '512'
+        mode = 'ffhq'
+    else:
+        mode = 'None'
     model = create_model(opt)
     model.eval()
 
 
     app = Face_detect_crop(name='antelope', root='./insightface_func/models')
-    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
+    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
 
     # The specific person to be swapped(source)
 
@@ -61,7 +66,7 @@ if __name__ == '__main__':
         # convert numpy to tensor
         specific_person = specific_person.cuda()
         #create latent id
-        specific_person_downsample = F.interpolate(specific_person, scale_factor=0.5)
+        specific_person_downsample = F.interpolate(specific_person, size=(112,112))
         specific_person_id_nonorm = model.netArc(specific_person_downsample)
         source_specific_id_nonorm_list.append(specific_person_id_nonorm.clone())
 
@@ -80,7 +85,7 @@ if __name__ == '__main__':
         # convert numpy to tensor
         img_id = img_id.cuda()
         #create latent id
-        img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+        img_id_downsample = F.interpolate(img_id, size=(112,112))
         latend_id = model.netArc(img_id_downsample)
         latend_id = F.normalize(latend_id, p=2, dim=1)
         target_id_norm_list.append(latend_id.clone())
@@ -90,5 +95,5 @@ if __name__ == '__main__':
 
 
     video_swap(opt.video_path, target_id_norm_list,source_specific_id_nonorm_list, opt.id_thres, \
-        model, app, opt.output_path,temp_results_dir=opt.temp_path,no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask)
+        model, app, opt.output_path,temp_results_dir=opt.temp_path,no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size,skip_existing_frames=opt.skip_existing_frames)
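The `crop_size == 512` branch that this change repeats across the test scripts simply swaps in the 512 checkpoint with ffhq-style alignment. Its selection logic can be summarized as (hypothetical helper name, extracted only for illustration):

```python
def checkpoint_config(crop_size):
    # 512 crops use the '512' checkpoint (epoch 550000) with ffhq alignment;
    # every other crop size keeps the default model and 'None' alignment mode
    if crop_size == 512:
        return {'which_epoch': 550000, 'name': '512', 'mode': 'ffhq'}
    return {'mode': 'None'}

print(checkpoint_config(512))  # {'which_epoch': 550000, 'name': '512', 'mode': 'ffhq'}
print(checkpoint_config(224))  # {'mode': 'None'}
```

Pulling this into one function would avoid the copy-paste drift visible across the eight scripts below.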
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 19:00:34
+Description:
+'''
 
 import cv2
 import torch
@@ -34,15 +42,21 @@ if __name__ == '__main__':
     opt = TestOptions().parse()
 
     start_epoch, epoch_iter = 1, 0
-    crop_size = 224
+    crop_size = opt.crop_size
 
     torch.nn.Module.dump_patches = True
 
+    if crop_size == 512:
+        opt.which_epoch = 550000
+        opt.name = '512'
+        mode = 'ffhq'
+    else:
+        mode = 'None'
     model = create_model(opt)
     model.eval()
 
 
     app = Face_detect_crop(name='antelope', root='./insightface_func/models')
-    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
+    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
 
     with torch.no_grad():
         pic_a = opt.pic_a_path
@@ -65,10 +79,10 @@ if __name__ == '__main__':
         # img_att = img_att.cuda()
 
         #create latent id
-        img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+        img_id_downsample = F.interpolate(img_id, size=(112,112))
         latend_id = model.netArc(img_id_downsample)
         latend_id = F.normalize(latend_id, p=2, dim=1)
 
         video_swap(opt.video_path, latend_id, model, app, opt.output_path,temp_results_dir=opt.temp_path,\
-            no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask)
+            no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size,skip_existing_frames=opt.skip_existing_frames)
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 19:00:38
+Description:
+'''
 
 import cv2
 import torch
@@ -34,15 +42,21 @@ if __name__ == '__main__':
     opt = TestOptions().parse()
 
     start_epoch, epoch_iter = 1, 0
-    crop_size = 224
+    crop_size = opt.crop_size
 
     torch.nn.Module.dump_patches = True
+    if crop_size == 512:
+        opt.which_epoch = 550000
+        opt.name = '512'
+        mode = 'ffhq'
+    else:
+        mode = 'None'
     model = create_model(opt)
     model.eval()
 
 
     app = Face_detect_crop(name='antelope', root='./insightface_func/models')
-    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
+    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
     with torch.no_grad():
         pic_a = opt.pic_a_path
         # img_a = Image.open(pic_a).convert('RGB')
@@ -64,10 +78,10 @@ if __name__ == '__main__':
         # img_att = img_att.cuda()
 
         #create latent id
-        img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+        img_id_downsample = F.interpolate(img_id, size=(112,112))
         latend_id = model.netArc(img_id_downsample)
         latend_id = F.normalize(latend_id, p=2, dim=1)
 
         video_swap(opt.video_path, latend_id, model, app, opt.output_path,temp_results_dir=opt.temp_path,\
-            no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask)
+            no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size,skip_existing_frames=opt.skip_existing_frames)
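The `skip_existing_frames=opt.skip_existing_frames` argument threaded into `video_swap` here is new in this branch. The internals of `video_swap` are not shown in the diff, but the flag's name suggests resumable frame rendering; a hypothetical sketch of that behaviour (helper name and structure are assumptions, not the repository's code):

```python
import os

def frames_to_process(frame_names, temp_dir, skip_existing_frames=False):
    # Hypothetical sketch: frames already written to temp_dir are not
    # rendered again, so an interrupted video swap can resume where it
    # stopped instead of starting over.
    if not skip_existing_frames:
        return list(frame_names)
    return [f for f in frame_names
            if not os.path.exists(os.path.join(temp_dir, f))]
```

With the flag off, every frame is re-rendered; with it on, only missing frames are.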
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 19:00:42
+Description:
+'''
 
 import cv2
 import torch
@@ -34,15 +42,21 @@ if __name__ == '__main__':
     opt = TestOptions().parse()
     pic_specific = opt.pic_specific_path
     start_epoch, epoch_iter = 1, 0
-    crop_size = 224
+    crop_size = opt.crop_size
 
     torch.nn.Module.dump_patches = True
+    if crop_size == 512:
+        opt.which_epoch = 550000
+        opt.name = '512'
+        mode = 'ffhq'
+    else:
+        mode = 'None'
     model = create_model(opt)
     model.eval()
 
 
     app = Face_detect_crop(name='antelope', root='./insightface_func/models')
-    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
+    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
     with torch.no_grad():
         pic_a = opt.pic_a_path
         # img_a = Image.open(pic_a).convert('RGB')
@@ -64,7 +78,7 @@ if __name__ == '__main__':
         # img_att = img_att.cuda()
 
         #create latent id
-        img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+        img_id_downsample = F.interpolate(img_id, size=(112,112))
         latend_id = model.netArc(img_id_downsample)
         latend_id = F.normalize(latend_id, p=2, dim=1)
 
@@ -76,9 +90,9 @@ if __name__ == '__main__':
         specific_person = transformer_Arcface(specific_person_align_crop_pil)
         specific_person = specific_person.view(-1, specific_person.shape[0], specific_person.shape[1], specific_person.shape[2])
         specific_person = specific_person.cuda()
-        specific_person_downsample = F.interpolate(specific_person, scale_factor=0.5)
+        specific_person_downsample = F.interpolate(specific_person, size=(112,112))
         specific_person_id_nonorm = model.netArc(specific_person_downsample)
 
         video_swap(opt.video_path, latend_id,specific_person_id_nonorm, opt.id_thres, \
-            model, app, opt.output_path,temp_results_dir=opt.temp_path,no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask)
+            model, app, opt.output_path,temp_results_dir=opt.temp_path,no_simswaplogo=opt.no_simswaplogo,use_mask=opt.use_mask,crop_size=crop_size,skip_existing_frames=opt.skip_existing_frames)
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 19:19:22
+Description:
+'''
 
 import cv2
 import torch
@@ -38,11 +46,19 @@ if __name__ == '__main__':
     opt = TestOptions().parse()
 
     start_epoch, epoch_iter = 1, 0
-    crop_size = 224
+    crop_size = opt.crop_size
 
     multisepcific_dir = opt.multisepcific_dir
 
     torch.nn.Module.dump_patches = True
 
+    if crop_size == 512:
+        opt.which_epoch = 550000
+        opt.name = '512'
+        mode = 'ffhq'
+    else:
+        mode = 'None'
 
     logoclass = watermark_image('./simswaplogo/simswaplogo.png')
     model = create_model(opt)
     model.eval()
@@ -52,7 +68,7 @@ if __name__ == '__main__':
 
 
     app = Face_detect_crop(name='antelope', root='./insightface_func/models')
-    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
+    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode = mode)
 
     with torch.no_grad():
         # The specific person to be swapped(source)
@@ -70,7 +86,7 @@ if __name__ == '__main__':
         # convert numpy to tensor
         specific_person = specific_person.cuda()
         #create latent id
-        specific_person_downsample = F.interpolate(specific_person, scale_factor=0.5)
+        specific_person_downsample = F.interpolate(specific_person, size=(112,112))
         specific_person_id_nonorm = model.netArc(specific_person_downsample)
         source_specific_id_nonorm_list.append(specific_person_id_nonorm.clone())
 
@@ -89,7 +105,7 @@ if __name__ == '__main__':
         # convert numpy to tensor
         img_id = img_id.cuda()
         #create latent id
-        img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+        img_id_downsample = F.interpolate(img_id, size=(112,112))
         latend_id = model.netArc(img_id_downsample)
         latend_id = F.normalize(latend_id, p=2, dim=1)
         target_id_norm_list.append(latend_id.clone())
@@ -112,7 +128,7 @@ if __name__ == '__main__':
             b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
 
             b_align_crop_tenor_arcnorm = spNorm(b_align_crop_tenor)
-            b_align_crop_tenor_arcnorm_downsample = F.interpolate(b_align_crop_tenor_arcnorm, scale_factor=0.5)
+            b_align_crop_tenor_arcnorm_downsample = F.interpolate(b_align_crop_tenor_arcnorm, size=(112,112))
            b_align_crop_id_nonorm = model.netArc(b_align_crop_tenor_arcnorm_downsample)
 
             id_compare_values.append([])
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 19:19:26
+Description:
+'''
 
 import cv2
 import torch
@@ -31,16 +39,22 @@ if __name__ == '__main__':
     opt = TestOptions().parse()
 
     start_epoch, epoch_iter = 1, 0
-    crop_size = 224
+    crop_size = opt.crop_size
 
     torch.nn.Module.dump_patches = True
+    if crop_size == 512:
+        opt.which_epoch = 550000
+        opt.name = '512'
+        mode = 'ffhq'
+    else:
+        mode = 'None'
     logoclass = watermark_image('./simswaplogo/simswaplogo.png')
     model = create_model(opt)
     model.eval()
     spNorm =SpecificNorm()
 
     app = Face_detect_crop(name='antelope', root='./insightface_func/models')
-    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
+    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
 
     with torch.no_grad():
         pic_a = opt.pic_a_path
@@ -55,7 +69,7 @@ if __name__ == '__main__':
         img_id = img_id.cuda()
 
         #create latent id
-        img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+        img_id_downsample = F.interpolate(img_id, size=(112,112))
         latend_id = model.netArc(img_id_downsample)
         latend_id = F.normalize(latend_id, p=2, dim=1)
 
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 19:19:43
+Description:
+'''
 
 import cv2
 import torch
@@ -30,16 +38,22 @@ if __name__ == '__main__':
     opt = TestOptions().parse()
 
     start_epoch, epoch_iter = 1, 0
-    crop_size = 224
+    crop_size = opt.crop_size
 
     torch.nn.Module.dump_patches = True
+    if crop_size == 512:
+        opt.which_epoch = 550000
+        opt.name = '512'
+        mode = 'ffhq'
+    else:
+        mode = 'None'
     logoclass = watermark_image('./simswaplogo/simswaplogo.png')
     model = create_model(opt)
     model.eval()
 
     spNorm =SpecificNorm()
     app = Face_detect_crop(name='antelope', root='./insightface_func/models')
-    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
+    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
 
     with torch.no_grad():
         pic_a = opt.pic_a_path
@@ -54,7 +68,7 @@ if __name__ == '__main__':
         img_id = img_id.cuda()
 
         #create latent id
-        img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+        img_id_downsample = F.interpolate(img_id, size=(112,112))
         latend_id = model.netArc(img_id_downsample)
         latend_id = F.normalize(latend_id, p=2, dim=1)
 
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 19:19:47
+Description:
+'''
 
 import cv2
 import torch
@@ -37,9 +45,15 @@ if __name__ == '__main__':
     opt = TestOptions().parse()
 
     start_epoch, epoch_iter = 1, 0
-    crop_size = 224
+    crop_size = opt.crop_size
 
     torch.nn.Module.dump_patches = True
+    if crop_size == 512:
+        opt.which_epoch = 550000
+        opt.name = '512'
+        mode = 'ffhq'
+    else:
+        mode = 'None'
     logoclass = watermark_image('./simswaplogo/simswaplogo.png')
     model = create_model(opt)
     model.eval()
@@ -49,7 +63,7 @@ if __name__ == '__main__':
 
 
     app = Face_detect_crop(name='antelope', root='./insightface_func/models')
-    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
+    app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640),mode=mode)
 
     pic_a = opt.pic_a_path
     pic_specific = opt.pic_specific_path
@@ -65,7 +79,7 @@ if __name__ == '__main__':
         img_id = img_id.cuda()
 
         #create latent id
-        img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+        img_id_downsample = F.interpolate(img_id, size=(112,112))
         latend_id = model.netArc(img_id_downsample)
         latend_id = F.normalize(latend_id, p=2, dim=1)
 
@@ -81,7 +95,7 @@ if __name__ == '__main__':
         specific_person = specific_person.cuda()
 
         #create latent id
-        specific_person_downsample = F.interpolate(specific_person, scale_factor=0.5)
+        specific_person_downsample = F.interpolate(specific_person, size=(112,112))
         specific_person_id_nonorm = model.netArc(specific_person_downsample)
         # specific_person_id_norm = F.normalize(specific_person_id_nonorm, p=2, dim=1)
 
@@ -101,7 +115,7 @@ if __name__ == '__main__':
            b_align_crop_tenor = _totensor(cv2.cvtColor(b_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
 
             b_align_crop_tenor_arcnorm = spNorm(b_align_crop_tenor)
-            b_align_crop_tenor_arcnorm_downsample = F.interpolate(b_align_crop_tenor_arcnorm, scale_factor=0.5)
+            b_align_crop_tenor_arcnorm_downsample = F.interpolate(b_align_crop_tenor_arcnorm, size=(112,112))
             b_align_crop_id_nonorm = model.netArc(b_align_crop_tenor_arcnorm_downsample)
 
             id_compare_values.append(mse(b_align_crop_id_nonorm,specific_person_id_nonorm).detach().cpu().numpy())
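In the specific-person scripts above, `mse(b_align_crop_id_nonorm, specific_person_id_nonorm)` scores each detected face against the target identity, and `opt.id_thres` gates whether any swap happens. A sketch of that selection idea (hypothetical helper working on plain distance values; the actual scripts iterate over tensors):

```python
def pick_face_to_swap(id_distances, id_thres):
    # The detected face whose ArcFace embedding distance to the specific
    # person is smallest gets swapped, but only if the distance is below
    # the threshold; otherwise no face in the frame is touched.
    best = min(range(len(id_distances)), key=lambda i: id_distances[i])
    return best if id_distances[best] < id_thres else None

print(pick_face_to_swap([0.9, 0.02, 0.5], id_thres=0.1))  # 1
print(pick_face_to_swap([0.9, 0.8], id_thres=0.1))        # None
```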
191	train.ipynb	Normal file
@@ -0,0 +1,191 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fC7QoKePuJWu"
   },
   "source": [
    "# Training Demo\n",
    "This is a simple example for training the SimSwap 224*224 model with VGGFace2-224.\n",
    "\n",
    "Code path: https://github.com/neuralchen/SimSwap\n",
    "If you like the SimSwap project, please star it!\n",
    "Paper path: https://arxiv.org/pdf/2106.06340v1.pdf or https://dl.acm.org/doi/10.1145/3394171.3413630"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "J8WrNaQHuUGC"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fri Apr 22 12:19:42 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 456.71       Driver Version: 456.71       CUDA Version: 11.1     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  TITAN Xp            WDDM | 00000000:01:00.0  On |                  N/A |\n",
      "| 23%   36C    P8    15W / 250W |   1135MiB / 12288MiB |      4%      Default |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|    0   N/A  N/A      1232    C+G   Insufficient Permissions        N/A      |\n",
      "|    0   N/A  N/A      1240    C+G   Insufficient Permissions        N/A      |\n",
      "|    0   N/A  N/A      1528    C+G   ...y\\ShellExperienceHost.exe    N/A      |\n",
      "|    0   N/A  N/A      7296    C+G   Insufficient Permissions        N/A      |\n",
      "|    0   N/A  N/A      8280    C+G   C:\\Windows\\explorer.exe         N/A      |\n",
      "|    0   N/A  N/A      9532    C+G   ...artMenuExperienceHost.exe    N/A      |\n",
      "|    0   N/A  N/A      9896    C+G   ...5n1h2txyewy\\SearchApp.exe    N/A      |\n",
      "|    0   N/A  N/A     11040    C+G   ...2txyewy\\TextInputHost.exe    N/A      |\n",
      "|    0   N/A  N/A     11424    C+G   Insufficient Permissions        N/A      |\n",
      "|    0   N/A  N/A     13112    C+G   ...icrosoft VS Code\\Code.exe    N/A      |\n",
      "|    0   N/A  N/A     18720    C+G   ...-2.9.15\\GitHubDesktop.exe    N/A      |\n",
      "|    0   N/A  N/A     22996    C+G   ...bbwe\\Microsoft.Photos.exe    N/A      |\n",
      "|    0   N/A  N/A     23512    C+G   ...me\\Application\\chrome.exe    N/A      |\n",
      "|    0   N/A  N/A     25892    C+G   Insufficient Permissions        N/A      |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z6BtQIgWuoqt"
   },
   "source": [
    "Installation\n",
    "All file changes made by this notebook are temporary. You can try to mount your own Google Drive to store files if you want."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wdQJ9d8N8Tnf"
   },
   "source": [
    "# Get Scripts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9jZWwt97uvIe"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/neuralchen/SimSwap\n",
    "!cd SimSwap && git pull"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ATLrrbso8Y-Y"
   },
   "source": [
    "# Install Blocks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rwvbPhtOvZAL"
   },
   "outputs": [],
   "source": [
    "!pip install googledrivedownloader\n",
    "!pip install timm\n",
    "!wget -P SimSwap/arcface_model https://github.com/neuralchen/SimSwap/releases/download/1.0/arcface_checkpoint.tar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hleVtHIJ_QUK"
   },
   "source": [
    "# Download the Training Dataset\n",
    "We employ the cropped VGGFace2-224 dataset for this toy training demo.\n",
    "\n",
    "You can download the dataset from our Google Drive: https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing\n",
    "\n",
    "***Please check the dataset in dir /content/TrainingData***\n",
    "\n",
    "***If the dataset already exists in /content/TrainingData, please do not run the scripts below!***\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "h2tyjBl0Llxp"
   },
   "outputs": [],
   "source": [
    "!wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\\\1\\\\n/p')&id=19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc\" -O /content/TrainingData/vggface2_crop_arcfacealign_224.tar && rm -rf /tmp/cookies.txt\n",
    "%cd /content/\n",
    "!tar -xzvf /content/TrainingData/vggface2_crop_arcfacealign_224.tar\n",
    "!rm /content/TrainingData/vggface2_crop_arcfacealign_224.tar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "o5SNDWzA8LjJ"
   },
   "source": [
    "# Training\n",
    "Batch size must be larger than 1!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XCxHa4oW507s"
   },
   "outputs": [],
   "source": [
    "%cd /content/SimSwap\n",
    "!ls\n",
    "!python train.py --name simswap224_test --gpu_ids 0 --dataset /content/TrainingData/vggface2_crop_arcfacealign_224 --Gdeep False"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "train.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
290	train.py	Normal file
@@ -0,0 +1,290 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: train.py
# Created Date: Monday December 27th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 22nd April 2022 10:49:26 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################

import os
import time
import random
import argparse
import numpy as np

import torch
import torch.nn.functional as F
from torch.backends import cudnn
import torch.utils.tensorboard as tensorboard

from util import util
from util.plot import plot_batch

from models.projected_model import fsModel
from data.data_loader_Swapping import GetLoader

def str2bool(v):
    return v.lower() in ('true')

class TrainOptions:
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
        self.parser.add_argument('--name', type=str, default='simswap', help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--gpu_ids', default='0')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        self.parser.add_argument('--isTrain', type=str2bool, default='True')

        # input/output sizes
        self.parser.add_argument('--batchSize', type=int, default=4, help='input batch size')

        # for displays
        self.parser.add_argument('--use_tensorboard', type=str2bool, default='False')

        # for training
        self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2", help='path to the face swapping dataset')
        self.parser.add_argument('--continue_train', type=str2bool, default='False', help='continue training: load the latest model')
        self.parser.add_argument('--load_pretrain', type=str, default='./checkpoints/simswap224_test', help='load the pretrained model from the specified location')
        self.parser.add_argument('--which_epoch', type=str, default='10000', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--niter', type=int, default=10000, help='# of iter at starting learning rate')
        self.parser.add_argument('--niter_decay', type=int, default=10000, help='# of iter to linearly decay learning rate to zero')
        self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam')
        self.parser.add_argument('--Gdeep', type=str2bool, default='False')

        # for discriminators
        self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
        self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss')
        self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')

        self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="path to the ArcFace identity checkpoint")
        self.parser.add_argument("--total_step", type=int, default=1000000, help='total training step')
        self.parser.add_argument("--log_frep", type=int, default=200, help='frequency of printing log information')
        self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequency of sampling')
        self.parser.add_argument("--model_freq", type=int, default=10000, help='frequency of saving the model')

        self.isTrain = True

    def parse(self, save=True):
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        self.opt.isTrain = self.isTrain  # train or test

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        # save to the disk
        if self.opt.isTrain:
            expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
            util.mkdirs(expr_dir)
            if save and not self.opt.continue_train:
                file_name = os.path.join(expr_dir, 'opt.txt')
                with open(file_name, 'wt') as opt_file:
                    opt_file.write('------------ Options -------------\n')
                    for k, v in sorted(args.items()):
                        opt_file.write('%s: %s\n' % (str(k), str(v)))
                    opt_file.write('-------------- End ----------------\n')
        return self.opt
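The `str2bool` helper above has a subtle quirk worth knowing when passing flags such as `--Gdeep False`: `('true')` is a plain string, not a one-element tuple, so the `in` test is substring membership rather than equality. Reproducing the function exactly as written:

```python
def str2bool(v):
    # ('true') is just the string 'true', so this checks whether
    # v.lower() is a substring of 'true' -- not equality
    return v.lower() in ('true')

print(str2bool('True'))   # True
print(str2bool('False'))  # False
print(str2bool('rue'))    # True -- surprising substring match
print(str2bool('1'))      # False -- '1' is not treated as true
```

Writing `('true',)` (a tuple) or `v.lower() == 'true'` would give strict matching; as it stands, any substring of "true" parses as `True`.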
if __name__ == '__main__':

    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    sample_path = os.path.join(opt.checkpoints_dir, opt.name, 'samples')
    if not os.path.exists(sample_path):
        os.makedirs(sample_path)

    log_path = os.path.join(opt.checkpoints_dir, opt.name, 'summary')
    if not os.path.exists(log_path):
        os.makedirs(log_path)

    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
        except (OSError, ValueError):
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids)
    print("GPU used : ", str(opt.gpu_ids))

    cudnn.benchmark = True

    model = fsModel()
    model.initialize(opt)

    #####################################################
    if opt.use_tensorboard:
        tensorboard_writer = tensorboard.SummaryWriter(log_path)
        logger = tensorboard_writer

    log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
    with open(log_name, "a") as log_file:
        now = time.strftime("%c")
        log_file.write('================ Training Loss (%s) ================\n' % now)

    optimizer_G, optimizer_D = model.optimizer_G, model.optimizer_D

    loss_avg = 0
    refresh_count = 0
    imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)

    train_loader = GetLoader(opt.dataset, opt.batchSize, 8, 1234)

    randindex = [i for i in range(opt.batchSize)]
    random.shuffle(randindex)

    if not opt.continue_train:
        start = 0
    else:
        start = int(opt.which_epoch)
    total_step = opt.total_step
    import datetime
    print("Start to train at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

    from util.logo_class import logo_class
    logo_class.print_start_training()
    model.netD.feature_network.requires_grad_(False)

    # Training Cycle
    for step in range(start, total_step):
        model.netG.train()
        for interval in range(2):
            random.shuffle(randindex)
            src_image1, src_image2 = train_loader.next()

            if step % 2 == 0:
                img_id = src_image2
            else:
                img_id = src_image2[randindex]

            img_id_112 = F.interpolate(img_id, size=(112, 112), mode='bicubic')
            latent_id = model.netArc(img_id_112)
            latent_id = F.normalize(latent_id, p=2, dim=1)
            if interval:
                # D update: hinge loss on fake and real logits
                img_fake = model.netG(src_image1, latent_id)
                gen_logits, _ = model.netD(img_fake.detach(), None)
                loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()

                real_logits, _ = model.netD(src_image2, None)
                loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()

                loss_D = loss_Dgen + loss_Dreal
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
            else:
                # model.netD.requires_grad_(True)
                img_fake = model.netG(src_image1, latent_id)
                # G loss
                gen_logits, feat = model.netD(img_fake, None)

                loss_Gmain = (-gen_logits).mean()
                img_fake_down = F.interpolate(img_fake, size=(112, 112), mode='bicubic')
                latent_fake = model.netArc(img_fake_down)
                latent_fake = F.normalize(latent_fake, p=2, dim=1)
                loss_G_ID = (1 - model.cosin_metric(latent_fake, latent_id)).mean()
                real_feat = model.netD.get_feature(src_image1)
                feat_match_loss = model.criterionFeat(feat["3"], real_feat["3"])
                loss_G = loss_Gmain + loss_G_ID * opt.lambda_id + feat_match_loss * opt.lambda_feat

                if step % 2 == 0:
                    # G_Rec
                    loss_G_Rec = model.criterionRec(img_fake, src_image1) * opt.lambda_rec
                    loss_G += loss_G_Rec

                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

        ############## Display results and errors ##########
        ### print out errors
        # Print out log info
        if (step + 1) % opt.log_frep == 0:
            # errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
            errors = {
                "G_Loss": loss_Gmain.item(),
                "G_ID": loss_G_ID.item(),
                "G_Rec": loss_G_Rec.item(),  # most recent even-step value
                "G_feat_match": feat_match_loss.item(),
                "D_fake": loss_Dgen.item(),
                "D_real": loss_Dreal.item(),
                "D_loss": loss_D.item()
            }
            if opt.use_tensorboard:
                for tag, value in errors.items():
                    logger.add_scalar(tag, value, step)
            message = '( step: %d, ) ' % (step)
            for k, v in errors.items():
                message += '%s: %.3f ' % (k, v)

            print(message)
            with open(log_name, "a") as log_file:
                log_file.write('%s\n' % message)

        ### display output images
        if (step + 1) % opt.sample_freq == 0:
            model.netG.eval()
            with torch.no_grad():
                imgs = list()
                zero_img = (torch.zeros_like(src_image1[0, ...]))
                imgs.append(zero_img.cpu().numpy())
                save_img = ((src_image1.cpu()) * imagenet_std + imagenet_mean).numpy()
                for r in range(opt.batchSize):
                    imgs.append(save_img[r, ...])
                arcface_112 = F.interpolate(src_image2, size=(112, 112), mode='bicubic')
                id_vector_src1 = model.netArc(arcface_112)
                id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)

                for i in range(opt.batchSize):
                    imgs.append(save_img[i, ...])
                    image_infer = src_image1[i, ...].repeat(opt.batchSize, 1, 1, 1)
                    img_fake = model.netG(image_infer, id_vector_src1).cpu()

                    img_fake = img_fake * imagenet_std
                    img_fake = img_fake + imagenet_mean
                    img_fake = img_fake.numpy()
                    for j in range(opt.batchSize):
                        imgs.append(img_fake[j, ...])
                print("Save test data")
                imgs = np.stack(imgs, axis=0).transpose(0, 2, 3, 1)
                plot_batch(imgs, os.path.join(sample_path, 'step_' + str(step + 1) + '.jpg'))

        ### save latest model
        if (step + 1) % opt.model_freq == 0:
            print('saving the latest model (steps %d)' % (step + 1))
            model.save(step + 1)
            np.savetxt(iter_path, (step + 1, total_step), delimiter=',', fmt='%d')
    wandb.finish()
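The discriminator update in the training loop above uses the hinge formulation: `relu(1 + D(fake))` for fakes and `relu(1 - D(real))` for reals, with `(-D(fake)).mean()` as the generator's adversarial term. A minimal NumPy sketch of those terms (function names are mine, for illustration only):

```python
import numpy as np

def hinge_d_loss(real_logits, fake_logits):
    """Hinge discriminator loss: penalize real logits below +1 and
    fake logits above -1, mirroring loss_Dreal + loss_Dgen above."""
    loss_real = np.maximum(0.0, 1.0 - real_logits).mean()
    loss_fake = np.maximum(0.0, 1.0 + fake_logits).mean()
    return loss_real + loss_fake

def hinge_g_loss(fake_logits):
    """Generator term: push fake logits up, as in loss_Gmain = (-gen_logits).mean()."""
    return (-fake_logits).mean()

# Perfectly separated logits (reals above +1, fakes below -1) give zero D loss.
print(hinge_d_loss(np.array([2.0, 3.0]), np.array([-2.0, -3.0])))  # 0.0
```

The margin means the discriminator stops receiving gradient from samples it already classifies confidently, which tends to stabilize GAN training.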
15
util/json_config.py
Normal file
@@ -0,0 +1,15 @@
import json


def readConfig(path):
    with open(path, 'r') as cf:
        config_str = cf.read()
    config_info = json.loads(config_str)
    # Tolerate double-encoded JSON: a JSON string that itself contains JSON.
    if isinstance(config_info, str):
        config_info = json.loads(config_info)
    return config_info


def writeConfig(path, info):
    with open(path, 'w') as cf:
        cf.write(json.dumps(info, indent=4))
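A self-contained round-trip sketch of the two helpers above (reimplemented locally so it runs on its own; the snake_case names are mine), including the double-encoded case:

```python
import json
import os
import tempfile

def read_config(path):
    # Same logic as readConfig above: parse again if the first
    # json.loads still yields a string (double-encoded config).
    with open(path, 'r') as cf:
        info = json.loads(cf.read())
    if isinstance(info, str):
        info = json.loads(info)
    return info

def write_config(path, info):
    with open(path, 'w') as cf:
        cf.write(json.dumps(info, indent=4))

fd, path = tempfile.mkstemp(suffix='.json')
os.close(fd)
write_config(path, {'dataset': './people', 'batchSize': 4})
print(read_config(path)['batchSize'])  # 4
os.remove(path)
```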
44
util/logo_class.py
Normal file
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: logo_class.py
# Created Date: Tuesday June 29th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 11th October 2021 12:39:55 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################

class logo_class:

    @staticmethod
    def print_group_logo():
        logo_str = """
            ███╗   ██╗██████╗ ███████╗██╗ ██████╗     ███████╗     ██╗████████╗██╗   ██╗
            ████╗  ██║██╔══██╗██╔════╝██║██╔════╝     ██╔════╝     ██║╚══██╔══╝██║   ██║
            ██╔██╗ ██║██████╔╝███████╗██║██║  ███╗    ███████╗     ██║   ██║   ██║   ██║
            ██║╚██╗██║██╔══██╗╚════██║██║██║   ██║    ╚════██║██   ██║   ██║   ██║   ██║
            ██║ ╚████║██║  ██║███████║██║╚██████╔╝    ███████║╚█████╔╝   ██║   ╚██████╔╝
            ╚═╝  ╚═══╝╚═╝  ╚═╝╚══════╝╚═╝ ╚═════╝     ╚══════╝ ╚════╝    ╚═╝    ╚═════╝
                    Neural Rendering Special Interest Group of SJTU
        """
        print(logo_str)

    @staticmethod
    def print_start_training():
        logo_str = """
           _____ __ __ ______ _ _
          / ___/ / /_ ____ _ _____ / /_ /_ __/_____ ____ _ (_)____ (_)____ ____ _
          \__ \ / __// __ `// ___// __/ / / / ___// __ `// // __ \ / // __ \ / __ `/
         ___/ // /_ / /_/ // / / /_ / / / / / /_/ // // / / // // / / // /_/ /
        /____/ \__/ \__,_//_/ \__/ /_/ /_/ \__,_//_//_/ /_//_//_/ /_/ \__, /
                                                                     /____/
        """
        print(logo_str)

if __name__ == "__main__":
    # logo_class.print_group_logo()
    logo_class.print_start_training()
37
util/plot.py
Normal file
@@ -0,0 +1,37 @@
import numpy as np
import math
import PIL.Image


def postprocess(x):
    """[0,1] to uint8."""
    x = np.clip(255 * x, 0, 255)
    # np.cast was removed in NumPy 2.0; astype is the equivalent.
    x = x.astype(np.uint8)
    return x


def tile(X, rows, cols):
    """Tile images for display."""
    tiling = np.zeros((rows * X.shape[1], cols * X.shape[2], X.shape[3]), dtype=X.dtype)
    for i in range(rows):
        for j in range(cols):
            idx = i * cols + j
            if idx < X.shape[0]:
                img = X[idx, ...]
                tiling[
                    i * X.shape[1]:(i + 1) * X.shape[1],
                    j * X.shape[2]:(j + 1) * X.shape[2],
                    :] = img
    return tiling


def plot_batch(X, out_path):
    """Save a batch of images tiled on one canvas."""
    n_channels = X.shape[3]
    if n_channels > 3:
        X = X[:, :, :, np.random.choice(n_channels, size=3)]
    X = postprocess(X)
    rc = math.sqrt(X.shape[0])
    rows = cols = math.ceil(rc)
    canvas = tile(X, rows, cols)
    canvas = np.squeeze(canvas)
    PIL.Image.fromarray(canvas).save(out_path)
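`plot_batch` lays a batch out on a near-square grid, `rows = cols = ceil(sqrt(N))`, leaving unused cells black. A small self-contained sketch of that shape arithmetic with a copy of the `tile` logic (illustrative only):

```python
import math
import numpy as np

def tile(X, rows, cols):
    """Copy of the tiling logic above: image idx = i*cols + j goes
    to grid cell (i, j); cells past the batch end stay zero."""
    tiling = np.zeros((rows * X.shape[1], cols * X.shape[2], X.shape[3]), dtype=X.dtype)
    for i in range(rows):
        for j in range(cols):
            idx = i * cols + j
            if idx < X.shape[0]:
                tiling[i * X.shape[1]:(i + 1) * X.shape[1],
                       j * X.shape[2]:(j + 1) * X.shape[2], :] = X[idx, ...]
    return tiling

batch = np.ones((10, 8, 8, 3), dtype=np.uint8)      # 10 images of 8x8x3
rows = cols = math.ceil(math.sqrt(batch.shape[0]))  # a 4x4 grid holds 10
canvas = tile(batch, rows, cols)
print(canvas.shape)  # (32, 32, 3)
```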
@@ -110,7 +110,7 @@ def reverse2wholeimage(b_align_crop_tenor_list,swaped_imgs, mats, crop_size, ori
     tgt_mask = encode_segmentation_rgb(vis_parsing_anno)
     if tgt_mask.sum() >= 5000:
         # face_mask_tensor = tgt_mask[...,0] + tgt_mask[...,1]
-        target_mask = cv2.resize(tgt_mask, (224, 224))
+        target_mask = cv2.resize(tgt_mask, (crop_size, crop_size))
         # print(source_img)
         target_image_parsing = postprocess(swaped_img, source_img[0].cpu().detach().numpy().transpose((1, 2, 0)), target_mask, smooth_mask)

57
util/save_heatmap.py
Normal file
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: save_heatmap.py
# Created Date: Friday January 15th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 19th January 2022 1:22:47 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################

import os
import shutil
import seaborn as sns
import matplotlib.pyplot as plt
import cv2
import numpy as np

def SaveHeatmap(heatmaps, path, row=-1, dpi=72):
    """
    The input tensor must be B x 1 x H x W.
    """
    batch_size = heatmaps.shape[0]
    temp_path = ".temp/"
    if not os.path.exists(temp_path):
        os.makedirs(temp_path)
    final_img = None
    if row < 1:
        col = batch_size
        row = 1
    else:
        col = batch_size // row
        if row * col < batch_size:
            col += 1

    row_i = 0
    col_i = 0

    for i in range(batch_size):
        img_path = os.path.join(temp_path, 'temp_batch_{}.png'.format(i))
        sns.heatmap(heatmaps[i, 0, :, :], vmin=0, vmax=heatmaps[i, 0, :, :].max(), cbar=False)
        plt.savefig(img_path, dpi=dpi, bbox_inches='tight', pad_inches=0)
        plt.clf()  # clear the figure so successive heatmaps do not overlay
        img = cv2.imread(img_path)
        if i == 0:
            H, W, C = img.shape
            final_img = np.zeros((H * row, W * col, C))
        final_img[H * row_i:H * (row_i + 1), W * col_i:W * (col_i + 1), :] = img
        col_i += 1
        if col_i >= col:
            col_i = 0
            row_i += 1
    cv2.imwrite(path, final_img)

if __name__ == "__main__":
    random_map = np.random.randn(16, 1, 10, 10)
    SaveHeatmap(random_map, "./wocao.png", 1)
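The grid-sizing rule in `SaveHeatmap` is easy to state on its own: `row < 1` means a single row holding the whole batch; otherwise the batch is ceil-divided across the requested rows. A pure-Python sketch of just that arithmetic (the helper name is mine):

```python
def heatmap_grid(batch_size, row=-1):
    """Mirror SaveHeatmap's layout: row < 1 means one row of
    batch_size columns; otherwise add a column whenever
    row * col falls short of the batch."""
    if row < 1:
        return 1, batch_size
    col = batch_size // row
    if row * col < batch_size:
        col += 1
    return row, col

print(heatmap_grid(16))     # (1, 16)
print(heatmap_grid(16, 3))  # (3, 6)
```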
@@ -1,3 +1,11 @@
+'''
+Author: Naiyuan liu
+Github: https://github.com/NNNNAI
+Date: 2021-11-23 17:03:58
+LastEditors: Naiyuan liu
+LastEditTime: 2021-11-24 19:19:52
+Description:
+'''
 import os
 import cv2
 import glob
@@ -19,7 +27,7 @@ def _totensor(array):
     img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
     return img.float().div(255)
 
-def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False):
+def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False, use_mask = False, skip_existing_frames = False):
     video_forcheck = VideoFileClip(video_path)
     if video_forcheck.audio is None:
         no_audio = True
@@ -43,8 +51,8 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r
     # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
     fps = video.get(cv2.CAP_PROP_FPS)
-    if os.path.exists(temp_results_dir):
-        shutil.rmtree(temp_results_dir)
+    if not skip_existing_frames and os.path.exists(temp_results_dir):
+        shutil.rmtree(temp_results_dir)
 
     spNorm =SpecificNorm()
     if use_mask:
@@ -56,17 +64,22 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r
         net.eval()
     else:
         net =None
 
-    if not os.path.exists(temp_results_dir):
-        os.mkdir(temp_results_dir)
-
     # while ret:
     for frame_index in tqdm(range(frame_count)):
         ret, frame = video.read()
+
+        if skip_existing_frames and os.path.exists(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index))):
+            continue
+
         if ret:
             detect_results = detect_model.get(frame,crop_size)
 
             if detect_results is not None:
                 # print(frame_index)
+                if not os.path.exists(temp_results_dir):
+                    os.mkdir(temp_results_dir)
                 frame_align_crop_list = detect_results[0]
                 frame_mat_list = detect_results[1]
                 swap_result_list = []
@@ -79,6 +92,7 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r
                     frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
 
                     swap_result = swap_model(None, frame_align_crop_tenor, id_vetor, None, True)[0]
+                    cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
                     swap_result_list.append(swap_result)
                     frame_align_crop_tenor_list.append(frame_align_crop_tenor)
 
@@ -88,8 +102,6 @@ def video_swap(video_path, id_vetor, swap_model, detect_model, save_path, temp_r
                     os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm)
 
             else:
-                if not os.path.exists(temp_results_dir):
-                    os.mkdir(temp_results_dir)
                 frame = frame.astype(np.uint8)
                 if not no_simswaplogo:
                     frame = logoclass.apply_frames(frame)
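The `skip_existing_frames` flag added in the diff above resumes an interrupted run by skipping any frame whose output file is already on disk, instead of wiping `temp_results_dir`. A minimal sketch of that naming and skip logic (a set stands in for the filesystem check; helper names are mine):

```python
def frame_name(index):
    # Same zero-padded pattern the swap loop writes, e.g. frame_0000012.jpg
    return 'frame_{:0>7d}.jpg'.format(index)

def frames_to_process(frame_count, existing, skip_existing_frames=True):
    """Return the frame indices still needing work; `existing` stands
    in for the files already present in temp_results_dir."""
    todo = []
    for i in range(frame_count):
        if skip_existing_frames and frame_name(i) in existing:
            continue
        todo.append(i)
    return todo

done = {frame_name(i) for i in range(3)}   # frames 0-2 already written
print(frames_to_process(5, done))  # [3, 4]
```

Note the real loop still calls `video.read()` for skipped frames, so the video cursor stays aligned with `frame_index`.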
@@ -20,7 +20,7 @@ def _totensor(array):
     img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
     return img.float().div(255)
 
-def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False):
+def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False, use_mask =False, skip_existing_frames = False):
     video_forcheck = VideoFileClip(video_path)
     if video_forcheck.audio is None:
         no_audio = True
@@ -44,8 +44,8 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id
     # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
     fps = video.get(cv2.CAP_PROP_FPS)
-    if os.path.exists(temp_results_dir):
-        shutil.rmtree(temp_results_dir)
+    if not skip_existing_frames and os.path.exists(temp_results_dir):
+        shutil.rmtree(temp_results_dir)
 
     spNorm =SpecificNorm()
     mse = torch.nn.MSELoss().cuda()
@@ -60,16 +60,21 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id
     else:
         net =None
 
-    if not os.path.exists(temp_results_dir):
-        os.mkdir(temp_results_dir)
-
     # while ret:
     for frame_index in tqdm(range(frame_count)):
         ret, frame = video.read()
+
+        if skip_existing_frames and os.path.exists(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index))):
+            continue
+
         if ret:
             detect_results = detect_model.get(frame,crop_size)
 
             if detect_results is not None:
                 # print(frame_index)
+                if not os.path.exists(temp_results_dir):
+                    os.mkdir(temp_results_dir)
                 frame_align_crop_list = detect_results[0]
                 frame_mat_list = detect_results[1]
 
@@ -83,7 +88,7 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id
                     frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
 
                     frame_align_crop_tenor_arcnorm = spNorm(frame_align_crop_tenor)
-                    frame_align_crop_tenor_arcnorm_downsample = F.interpolate(frame_align_crop_tenor_arcnorm, scale_factor=0.5)
+                    frame_align_crop_tenor_arcnorm_downsample = F.interpolate(frame_align_crop_tenor_arcnorm, size=(112,112))
                    frame_align_crop_crop_id_nonorm = swap_model.netArc(frame_align_crop_tenor_arcnorm_downsample)
                     id_compare_values.append([])
                     for source_specific_id_nonorm_tmp in source_specific_id_nonorm_list:
@@ -113,16 +118,12 @@ def video_swap(video_path, target_id_norm_list,source_specific_id_nonorm_list,id
                     reverse2wholeimage(swap_result_ori_pic_list,swap_result_list, swap_result_matrix_list, crop_size, frame, logoclass,\
                         os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask=use_mask, norm = spNorm)
                 else:
-                    if not os.path.exists(temp_results_dir):
-                        os.mkdir(temp_results_dir)
                     frame = frame.astype(np.uint8)
                     if not no_simswaplogo:
                         frame = logoclass.apply_frames(frame)
                     cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
 
             else:
-                if not os.path.exists(temp_results_dir):
-                    os.mkdir(temp_results_dir)
                 frame = frame.astype(np.uint8)
                 if not no_simswaplogo:
                     frame = logoclass.apply_frames(frame)
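In the specific-person variant below, each detected face's un-normalized ArcFace embedding is compared to the target person's embedding by MSE, and only the closest face is swapped, and only if that distance stays under `id_thres`. A NumPy sketch of that selection (function name is mine, for illustration):

```python
import numpy as np

def select_specific_face(face_embeddings, person_embedding, id_thres):
    """Pick the detected face closest (by MSE) to the target person's
    embedding; return None if even the best match exceeds id_thres."""
    dists = [np.mean((emb - person_embedding) ** 2) for emb in face_embeddings]
    min_index = int(np.argmin(dists))
    if dists[min_index] > id_thres:
        return None  # no face in this frame matches the specific person
    return min_index

faces = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
target = np.array([0.1, 0.9])
print(select_specific_face(faces, target, id_thres=0.1))  # 1
```

Comparing un-normalized embeddings with MSE (rather than cosine similarity on normalized ones) matches what the loop below does with `swap_model.netArc` outputs.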
@@ -20,7 +20,7 @@ def _totensor(array):
     img = tensor.transpose(0, 1).transpose(0, 2).contiguous()
     return img.float().div(255)
 
-def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False,use_mask =False):
+def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_model, detect_model, save_path, temp_results_dir='./temp_results', crop_size=224, no_simswaplogo = False, use_mask =False, skip_existing_frames = False):
     video_forcheck = VideoFileClip(video_path)
     if video_forcheck.audio is None:
         no_audio = True
@@ -44,8 +44,8 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod
     # video_HEIGHT = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
     fps = video.get(cv2.CAP_PROP_FPS)
-    if os.path.exists(temp_results_dir):
-        shutil.rmtree(temp_results_dir)
+    if not skip_existing_frames and os.path.exists(temp_results_dir):
+        shutil.rmtree(temp_results_dir)
 
     spNorm =SpecificNorm()
     mse = torch.nn.MSELoss().cuda()
@@ -60,16 +60,21 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod
     else:
         net =None
 
-    if not os.path.exists(temp_results_dir):
-        os.mkdir(temp_results_dir)
-
     # while ret:
     for frame_index in tqdm(range(frame_count)):
         ret, frame = video.read()
+
+        if skip_existing_frames and os.path.exists(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index))):
+            continue
+
         if ret:
             detect_results = detect_model.get(frame,crop_size)
 
             if detect_results is not None:
                 # print(frame_index)
+                if not os.path.exists(temp_results_dir):
+                    os.mkdir(temp_results_dir)
                 frame_align_crop_list = detect_results[0]
                 frame_mat_list = detect_results[1]
 
@@ -83,7 +88,7 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod
                     frame_align_crop_tenor = _totensor(cv2.cvtColor(frame_align_crop,cv2.COLOR_BGR2RGB))[None,...].cuda()
 
                     frame_align_crop_tenor_arcnorm = spNorm(frame_align_crop_tenor)
-                    frame_align_crop_tenor_arcnorm_downsample = F.interpolate(frame_align_crop_tenor_arcnorm, scale_factor=0.5)
+                    frame_align_crop_tenor_arcnorm_downsample = F.interpolate(frame_align_crop_tenor_arcnorm, size=(112,112))
                    frame_align_crop_crop_id_nonorm = swap_model.netArc(frame_align_crop_tenor_arcnorm_downsample)
 
                     id_compare_values.append(mse(frame_align_crop_crop_id_nonorm,specific_person_id_nonorm).detach().cpu().numpy())
@@ -97,16 +102,12 @@ def video_swap(video_path, id_vetor,specific_person_id_nonorm,id_thres, swap_mod
                     reverse2wholeimage([frame_align_crop_tenor_list[min_index]], [swap_result], [frame_mat_list[min_index]], crop_size, frame, logoclass,\
                         os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)),no_simswaplogo,pasring_model =net,use_mask= use_mask, norm = spNorm)
                 else:
-                    if not os.path.exists(temp_results_dir):
-                        os.mkdir(temp_results_dir)
                     frame = frame.astype(np.uint8)
                     if not no_simswaplogo:
                         frame = logoclass.apply_frames(frame)
                     cv2.imwrite(os.path.join(temp_results_dir, 'frame_{:0>7d}.jpg'.format(frame_index)), frame)
 
             else:
-                if not os.path.exists(temp_results_dir):
-                    os.mkdir(temp_results_dir)
                 frame = frame.astype(np.uint8)
                 if not no_simswaplogo:
                     frame = logoclass.apply_frames(frame)