From b4ba93300b88691bec82af05cebfde2c0132afeb Mon Sep 17 00:00:00 2001 From: chenxuanhong Date: Fri, 22 Apr 2022 00:48:21 +0800 Subject: [PATCH] update --- README.md | 10 +- train.ipynb | 315 ++++++++++++++++++++++++++++++++++++++++++++++++++++ train.py | 10 +- 3 files changed, 322 insertions(+), 13 deletions(-) create mode 100644 train.ipynb diff --git a/README.md b/README.md index 9ed8c16..3ac227a 100644 --- a/README.md +++ b/README.md @@ -65,14 +65,14 @@ Download the dataset from [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ). The training script is slightly different from the original version, e.g., we replace the patch discriminator with the projected discriminator, which saves a lot of hardware overhead and achieves slightly better results. In order to ensure normal training, the batch size must be greater than 1. -- Train 256 models +- Train 224 models with VGGFace2 224*224 [VGGFace2-224](https://github.com/NNNNAI/VGGFace2-HQ) ``` -python train.py --name simswap256_test --gpu_ids 0 --dataset /path/to/VGGFace2HQ --train_simswap True --Gdeep False +python train.py --name simswap224_test --batchSize 4 --gpu_ids 0 --dataset /path/to/VGGFace2HQ --Gdeep False ``` -- Train 512 models +- Train 512 models with VGGFace2-HQ 512*512 [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ). ``` -python train.py --name simswap512_test --gpu_ids 0 --dataset /path/to/VGGFace2HQ --train_simswap False --Gdeep True +python train.py --name simswap512_test --gpu_ids 0 --dataset /path/to/VGGFace2HQ --Gdeep True ``` @@ -86,7 +86,7 @@ python train.py --name simswap512_test --gpu_ids 0 --dataset /path/to/VGGFace2H
Stronger feature
-[Colab fo switching specific faces in multi-face videos](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/MultiSpecific.ipynb) +[Colab for switching specific faces in multi-face videos](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/MultiSpecific.ipynb) [Image face swapping demo & Docker image on Replicate](https://replicate.ai/neuralchen/simswap-image) diff --git a/train.ipynb b/train.ipynb new file mode 100644 index 0000000..97526b7 --- /dev/null +++ b/train.ipynb @@ -0,0 +1,315 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "train.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "#Training Demo\n", + "This is a simple example for training the SimSwap 224*224 with VGGFace2-224.\n", + "\n", + "Code path: https://github.com/neuralchen/SimSwap\n", + "If you like the SimSwap project, please star it!\n", + "Paper path: https://arxiv.org/pdf/2106.06340v1.pdf or https://dl.acm.org/doi/10.1145/3394171.3413630" + ], + "metadata": { + "id": "fC7QoKePuJWu" + } + }, + { + "cell_type": "code", + "source": [ + "!nvidia-smi" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "J8WrNaQHuUGC", + "outputId": "afffa0be-92b5-4133-b6d9-6c3e08c6de64" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Thu Apr 21 16:07:35 2022 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 67C P8 32W / 149W | 0MiB / 11441MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Installation\n", + "All file changes made by this notebook are temporary. You can try to mount your own google drive to store files if you want." + ], + "metadata": { + "id": "Z6BtQIgWuoqt" + } + }, + { + "cell_type": "markdown", + "source": [ + "#Get Scripts" + ], + "metadata": { + "id": "wdQJ9d8N8Tnf" + } + }, + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/neuralchen/SimSwap\n", + "!cd SimSwap && git pull" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9jZWwt97uvIe", + "outputId": "42a1bda8-3ca3-46af-fc82-d1af99ce15e1" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'SimSwap'...\n", + "remote: Enumerating objects: 1017, done.\u001b[K\n", + "remote: Counting objects: 100% (16/16), done.\u001b[K\n", + "remote: Compressing objects: 100% (13/13), done.\u001b[K\n", + "remote: Total 1017 (delta 5), reused 10 (delta 3), pack-reused 1001\u001b[K\n", + "Receiving objects: 100% (1017/1017), 210.79 MiB | 14.80 MiB/s, done.\n", + "Resolving deltas: 100% (510/510), done.\n", + "Already up to date.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Install Blocks" + ], + "metadata": { + "id": "ATLrrbso8Y-Y" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install googledrivedownloader\n", + "!pip install timm\n", + "!wget -P SimSwap/arcface_model https://github.com/neuralchen/SimSwap/releases/download/1.0/arcface_checkpoint.tar" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rwvbPhtOvZAL", + "outputId": "ffa12208-d388-412d-e83b-c54864c4526e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: googledrivedownloader in /usr/local/lib/python3.7/dist-packages (0.4)\n", + "Requirement already satisfied: imageio==2.4.1 in /usr/local/lib/python3.7/dist-packages (2.4.1)\n", + "Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (7.1.2)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (1.21.6)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "#Download the Training Dataset\n", + "We employ the cropped VGGFace2-224 dataset for this toy training demo.\n", + "You can download the dataset from our google driver " + ], + "metadata": { + "id": "hleVtHIJ_QUK" + } + }, + { + "cell_type": "code", + "source": [ + "from google_drive_downloader import GoogleDriveDownloader as gdd\n", + "gdd.download_file_from_google_drive(file_id='1iytA1n2z4go3uVCwE__vIKouTKyIDjEq',dest_path='/content/TrainingData/vggface2_crop_arcfacealign_224.tar',showsize=True)\n", + "!tar -xzvf /content/TrainingData/vggface2_crop_arcfacealign_224.tar" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gMVKEej59LX9", + "outputId": "2e508c44-d006-4183-81d9-f9753d08dea7" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading 1iytA1n2z4go3uVCwE__vIKouTKyIDjEq into /content/TrainingData/mnist.zip... \n", + "0.0 B Done.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "#Trainig\n", + "Batch size must larger than 1!" + ], + "metadata": { + "id": "o5SNDWzA8LjJ" + } + }, + { + "cell_type": "code", + "source": [ + "%cd /content/SimSwap\n", + "!ls\n", + "!python train.py --name simswap224_test --gpu_ids 0 --dataset /content/TrainingData/vggface2_crop_arcfacealign_224 --Gdeep False" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XCxHa4oW507s", + "outputId": "c84c52d9-0b36-4932-925d-1ae38a3f7bb0" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/SimSwap\n", + " arcface_model\t predict.py\n", + " cog.yaml\t README.md\n", + " crop_224\t 'SimSwap colab.ipynb'\n", + " data\t\t simswaplogo\n", + " demo_file\t test_one_image.py\n", + " docs\t\t test_video_swapmulti.py\n", + " download-weights.sh test_video_swap_multispecific.py\n", + " insightface_func test_video_swapsingle.py\n", + " LICENSE\t test_video_swapspecific.py\n", + " models\t\t test_wholeimage_swapmulti.py\n", + " MultiSpecific.ipynb test_wholeimage_swap_multispecific.py\n", + " options\t test_wholeimage_swapsingle.py\n", + " output\t\t test_wholeimage_swapspecific.py\n", + " parsing_model\t train.py\n", + " pg_modules\t util\n", + "------------ Options -------------\n", + "Arc_path: arcface_model/arcface_checkpoint.tar\n", + "Gdeep: False\n", + "batchSize: 2\n", + "beta1: 0.0\n", + "checkpoints_dir: ./checkpoints\n", + "continue_train: False\n", + "dataset: /path/to/VGGFace2\n", + "gpu_ids: 0\n", + "isTrain: True\n", + "lambda_feat: 10.0\n", + "lambda_id: 30.0\n", + "lambda_rec: 10.0\n", + "load_pretrain: checkpoints\n", + "log_frep: 200\n", + "lr: 0.0004\n", + "model_freq: 10000\n", + "name: simswap\n", + "niter: 10000\n", + "niter_decay: 10000\n", + "phase: train\n", + "sample_freq: 1000\n", + "tag: simswap\n", + "total_step: 1000000\n", + "train_simswap: True\n", + "use_tensorboard: False\n", + "which_epoch: 800000\n", + "-------------- End ----------------\n", + "GPU used : 0\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.PReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.AdaptiveAvgPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.Sigmoid' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm1d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", + " warnings.warn(msg, SourceChangeWarning)\n", + "Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth\" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_lite0-0aa007d2.pth\n", + "processing Swapping dataset images...\n", + "Finished preprocessing the Swapping dataset, total dirs number: 0...\n", + "Traceback (most recent call last):\n", + " File \"train.py\", line 163, in \n", + " train_loader = GetLoader(opt.dataset,opt.batchSize,8,1234)\n", + " File \"/content/SimSwap/data/data_loader_Swapping.py\", line 119, in GetLoader\n", + " drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)\n", + " File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py\", line 268, in __init__\n", + " sampler = RandomSampler(dataset, generator=generator)\n", + " File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/sampler.py\", line 103, in __init__\n", + " \"value, but got num_samples={}\".format(self.num_samples))\n", + "ValueError: num_samples should be a positive integer value, but got num_samples=0\n" + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/train.py b/train.py index 03dfb58..7059d48 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,7 @@ # Created Date: Monday December 27th 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Friday, 22nd April 2022 12:15:47 am +# Last Modified: Friday, 22nd April 2022 12:34:40 am # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -42,16 +42,14 @@ class TrainOptions: self.parser.add_argument('--isTrain', type=str2bool, default='True') # input/output sizes - self.parser.add_argument('--batchSize', type=int, default=2, help='input batch size') + self.parser.add_argument('--batchSize', type=int, default=4, help='input batch size') # for displays - self.parser.add_argument('--tag', type=str, default='simswap') self.parser.add_argument('--use_tensorboard', type=str2bool, default='False') # for training self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2", help='path to the face swapping dataset') self.parser.add_argument('--continue_train', type=str2bool, default='True', help='continue training: load the latest model') - # self.parser.add_argument('--Gdeep', type=str2bool, default='False') self.parser.add_argument('--load_pretrain', type=str, default='checkpoints', help='load the pretrained model from the specified location') self.parser.add_argument('--which_epoch', type=str, default='320', help='which epoch to load? set to latest to use latest cached model') self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') @@ -60,7 +58,6 @@ class TrainOptions: self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam') self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam') self.parser.add_argument('--Gdeep', type=str2bool, default='False') - self.parser.add_argument('--train_simswap', type=str2bool, default='True') # for discriminators self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') @@ -182,9 +179,6 @@ if __name__ == '__main__': for interval in range(2): random.shuffle(randindex) src_image1, src_image2 = train_loader.next() - if opt.train_simswap: - src_image1 = F.interpolate(src_image1,size=(256,256), mode='bicubic') - src_image2 = F.interpolate(src_image2,size=(256,256), mode='bicubic') if step%2 == 0: img_id = src_image2