update
This commit is contained in:
10
README.md
10
README.md
@@ -65,14 +65,14 @@ Download the dataset from [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ).
|
||||
The training script is slightly different from the original version, e.g., we replace the patch discriminator with the projected discriminator, which saves a lot of hardware overhead and achieves slightly better results.
|
||||
In order to ensure normal training, the batch size must be greater than 1.
|
||||
|
||||
- Train 256 models
|
||||
- Train 224 models with VGGFace2 224*224 [VGGFace2-224](https://github.com/NNNNAI/VGGFace2-HQ)
|
||||
```
|
||||
python train.py --name simswap256_test --gpu_ids 0 --dataset /path/to/VGGFace2HQ --train_simswap True --Gdeep False
|
||||
python train.py --name simswap224_test --batchSize 4 --gpu_ids 0 --dataset /path/to/VGGFace2HQ --Gdeep False
|
||||
```
|
||||
|
||||
- Train 512 models
|
||||
- Train 512 models with VGGFace2-HQ 512*512 [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ).
|
||||
```
|
||||
python train.py --name simswap512_test --gpu_ids 0 --dataset /path/to/VGGFace2HQ --train_simswap False --Gdeep True
|
||||
python train.py --name simswap512_test --gpu_ids 0 --dataset /path/to/VGGFace2HQ --Gdeep True
|
||||
```
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ python train.py --name simswap512_test --gpu_ids 0 --dataset /path/to/VGGFace2H
|
||||
|
||||
<div style="background: yellow; width:140px; font-weight:bold;font-family: sans-serif;">Stronger feature</div>
|
||||
|
||||
[Colab fo switching specific faces in multi-face videos](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/MultiSpecific.ipynb)
|
||||
[Colab for switching specific faces in multi-face videos](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/MultiSpecific.ipynb)
|
||||
|
||||
[Image face swapping demo & Docker image on Replicate](https://replicate.ai/neuralchen/simswap-image)
|
||||
|
||||
|
||||
315
train.ipynb
Normal file
315
train.ipynb
Normal file
@@ -0,0 +1,315 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "train.ipynb",
|
||||
"provenance": [],
|
||||
"collapsed_sections": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"#Training Demo\n",
|
||||
"This is a simple example for training the SimSwap 224*224 with VGGFace2-224.\n",
|
||||
"\n",
|
||||
"Code path: https://github.com/neuralchen/SimSwap\n",
|
||||
"If you like the SimSwap project, please star it!\n",
|
||||
"Paper path: https://arxiv.org/pdf/2106.06340v1.pdf or https://dl.acm.org/doi/10.1145/3394171.3413630"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "fC7QoKePuJWu"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!nvidia-smi"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "J8WrNaQHuUGC",
|
||||
"outputId": "afffa0be-92b5-4133-b6d9-6c3e08c6de64"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Thu Apr 21 16:07:35 2022 \n",
|
||||
"+-----------------------------------------------------------------------------+\n",
|
||||
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
|
||||
"|-------------------------------+----------------------+----------------------+\n",
|
||||
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
||||
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
|
||||
"| | | MIG M. |\n",
|
||||
"|===============================+======================+======================|\n",
|
||||
"| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n",
|
||||
"| N/A 67C P8 32W / 149W | 0MiB / 11441MiB | 0% Default |\n",
|
||||
"| | | N/A |\n",
|
||||
"+-------------------------------+----------------------+----------------------+\n",
|
||||
" \n",
|
||||
"+-----------------------------------------------------------------------------+\n",
|
||||
"| Processes: |\n",
|
||||
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
||||
"| ID ID Usage |\n",
|
||||
"|=============================================================================|\n",
|
||||
"| No running processes found |\n",
|
||||
"+-----------------------------------------------------------------------------+\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Installation\n",
|
||||
"All file changes made by this notebook are temporary. You can try to mount your own google drive to store files if you want."
|
||||
],
|
||||
"metadata": {
|
||||
"id": "Z6BtQIgWuoqt"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"#Get Scripts"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "wdQJ9d8N8Tnf"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!git clone https://github.com/neuralchen/SimSwap\n",
|
||||
"!cd SimSwap && git pull"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "9jZWwt97uvIe",
|
||||
"outputId": "42a1bda8-3ca3-46af-fc82-d1af99ce15e1"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Cloning into 'SimSwap'...\n",
|
||||
"remote: Enumerating objects: 1017, done.\u001b[K\n",
|
||||
"remote: Counting objects: 100% (16/16), done.\u001b[K\n",
|
||||
"remote: Compressing objects: 100% (13/13), done.\u001b[K\n",
|
||||
"remote: Total 1017 (delta 5), reused 10 (delta 3), pack-reused 1001\u001b[K\n",
|
||||
"Receiving objects: 100% (1017/1017), 210.79 MiB | 14.80 MiB/s, done.\n",
|
||||
"Resolving deltas: 100% (510/510), done.\n",
|
||||
"Already up to date.\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# Install Blocks"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "ATLrrbso8Y-Y"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!pip install googledrivedownloader\n",
|
||||
"!pip install timm\n",
|
||||
"!wget -P SimSwap/arcface_model https://github.com/neuralchen/SimSwap/releases/download/1.0/arcface_checkpoint.tar"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "rwvbPhtOvZAL",
|
||||
"outputId": "ffa12208-d388-412d-e83b-c54864c4526e"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Requirement already satisfied: googledrivedownloader in /usr/local/lib/python3.7/dist-packages (0.4)\n",
|
||||
"Requirement already satisfied: imageio==2.4.1 in /usr/local/lib/python3.7/dist-packages (2.4.1)\n",
|
||||
"Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (7.1.2)\n",
|
||||
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (1.21.6)\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"#Download the Training Dataset\n",
|
||||
"We employ the cropped VGGFace2-224 dataset for this toy training demo.\n",
|
||||
"You can download the dataset from our google driver "
|
||||
],
|
||||
"metadata": {
|
||||
"id": "hleVtHIJ_QUK"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from google_drive_downloader import GoogleDriveDownloader as gdd\n",
|
||||
"gdd.download_file_from_google_drive(file_id='1iytA1n2z4go3uVCwE__vIKouTKyIDjEq',dest_path='/content/TrainingData/vggface2_crop_arcfacealign_224.tar',showsize=True)\n",
|
||||
"!tar -xzvf /content/TrainingData/vggface2_crop_arcfacealign_224.tar"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "gMVKEej59LX9",
|
||||
"outputId": "2e508c44-d006-4183-81d9-f9753d08dea7"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Downloading 1iytA1n2z4go3uVCwE__vIKouTKyIDjEq into /content/TrainingData/mnist.zip... \n",
|
||||
"0.0 B Done.\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"#Trainig\n",
|
||||
"Batch size must larger than 1!"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "o5SNDWzA8LjJ"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"%cd /content/SimSwap\n",
|
||||
"!ls\n",
|
||||
"!python train.py --name simswap224_test --gpu_ids 0 --dataset /content/TrainingData/vggface2_crop_arcfacealign_224 --Gdeep False"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "XCxHa4oW507s",
|
||||
"outputId": "c84c52d9-0b36-4932-925d-1ae38a3f7bb0"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"/content/SimSwap\n",
|
||||
" arcface_model\t predict.py\n",
|
||||
" cog.yaml\t README.md\n",
|
||||
" crop_224\t 'SimSwap colab.ipynb'\n",
|
||||
" data\t\t simswaplogo\n",
|
||||
" demo_file\t test_one_image.py\n",
|
||||
" docs\t\t test_video_swapmulti.py\n",
|
||||
" download-weights.sh test_video_swap_multispecific.py\n",
|
||||
" insightface_func test_video_swapsingle.py\n",
|
||||
" LICENSE\t test_video_swapspecific.py\n",
|
||||
" models\t\t test_wholeimage_swapmulti.py\n",
|
||||
" MultiSpecific.ipynb test_wholeimage_swap_multispecific.py\n",
|
||||
" options\t test_wholeimage_swapsingle.py\n",
|
||||
" output\t\t test_wholeimage_swapspecific.py\n",
|
||||
" parsing_model\t train.py\n",
|
||||
" pg_modules\t util\n",
|
||||
"------------ Options -------------\n",
|
||||
"Arc_path: arcface_model/arcface_checkpoint.tar\n",
|
||||
"Gdeep: False\n",
|
||||
"batchSize: 2\n",
|
||||
"beta1: 0.0\n",
|
||||
"checkpoints_dir: ./checkpoints\n",
|
||||
"continue_train: False\n",
|
||||
"dataset: /path/to/VGGFace2\n",
|
||||
"gpu_ids: 0\n",
|
||||
"isTrain: True\n",
|
||||
"lambda_feat: 10.0\n",
|
||||
"lambda_id: 30.0\n",
|
||||
"lambda_rec: 10.0\n",
|
||||
"load_pretrain: checkpoints\n",
|
||||
"log_frep: 200\n",
|
||||
"lr: 0.0004\n",
|
||||
"model_freq: 10000\n",
|
||||
"name: simswap\n",
|
||||
"niter: 10000\n",
|
||||
"niter_decay: 10000\n",
|
||||
"phase: train\n",
|
||||
"sample_freq: 1000\n",
|
||||
"tag: simswap\n",
|
||||
"total_step: 1000000\n",
|
||||
"train_simswap: True\n",
|
||||
"use_tensorboard: False\n",
|
||||
"which_epoch: 800000\n",
|
||||
"-------------- End ----------------\n",
|
||||
"GPU used : 0\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.PReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.AdaptiveAvgPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.Sigmoid' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm1d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
|
||||
" warnings.warn(msg, SourceChangeWarning)\n",
|
||||
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth\" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_lite0-0aa007d2.pth\n",
|
||||
"processing Swapping dataset images...\n",
|
||||
"Finished preprocessing the Swapping dataset, total dirs number: 0...\n",
|
||||
"Traceback (most recent call last):\n",
|
||||
" File \"train.py\", line 163, in <module>\n",
|
||||
" train_loader = GetLoader(opt.dataset,opt.batchSize,8,1234)\n",
|
||||
" File \"/content/SimSwap/data/data_loader_Swapping.py\", line 119, in GetLoader\n",
|
||||
" drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)\n",
|
||||
" File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py\", line 268, in __init__\n",
|
||||
" sampler = RandomSampler(dataset, generator=generator)\n",
|
||||
" File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/sampler.py\", line 103, in __init__\n",
|
||||
" \"value, but got num_samples={}\".format(self.num_samples))\n",
|
||||
"ValueError: num_samples should be a positive integer value, but got num_samples=0\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
10
train.py
10
train.py
@@ -5,7 +5,7 @@
|
||||
# Created Date: Monday December 27th 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 22nd April 2022 12:15:47 am
|
||||
# Last Modified: Friday, 22nd April 2022 12:34:40 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -42,16 +42,14 @@ class TrainOptions:
|
||||
self.parser.add_argument('--isTrain', type=str2bool, default='True')
|
||||
|
||||
# input/output sizes
|
||||
self.parser.add_argument('--batchSize', type=int, default=2, help='input batch size')
|
||||
self.parser.add_argument('--batchSize', type=int, default=4, help='input batch size')
|
||||
|
||||
# for displays
|
||||
self.parser.add_argument('--tag', type=str, default='simswap')
|
||||
self.parser.add_argument('--use_tensorboard', type=str2bool, default='False')
|
||||
|
||||
# for training
|
||||
self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2", help='path to the face swapping dataset')
|
||||
self.parser.add_argument('--continue_train', type=str2bool, default='True', help='continue training: load the latest model')
|
||||
# self.parser.add_argument('--Gdeep', type=str2bool, default='False')
|
||||
self.parser.add_argument('--load_pretrain', type=str, default='checkpoints', help='load the pretrained model from the specified location')
|
||||
self.parser.add_argument('--which_epoch', type=str, default='320', help='which epoch to load? set to latest to use latest cached model')
|
||||
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
|
||||
@@ -60,7 +58,6 @@ class TrainOptions:
|
||||
self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
|
||||
self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam')
|
||||
self.parser.add_argument('--Gdeep', type=str2bool, default='False')
|
||||
self.parser.add_argument('--train_simswap', type=str2bool, default='True')
|
||||
|
||||
# for discriminators
|
||||
self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
|
||||
@@ -182,9 +179,6 @@ if __name__ == '__main__':
|
||||
for interval in range(2):
|
||||
random.shuffle(randindex)
|
||||
src_image1, src_image2 = train_loader.next()
|
||||
if opt.train_simswap:
|
||||
src_image1 = F.interpolate(src_image1,size=(256,256), mode='bicubic')
|
||||
src_image2 = F.interpolate(src_image2,size=(256,256), mode='bicubic')
|
||||
|
||||
if step%2 == 0:
|
||||
img_id = src_image2
|
||||
|
||||
Reference in New Issue
Block a user