From 3bd2c7438805848555e3439bf53d83ef498792a3 Mon Sep 17 00:00:00 2001 From: chenxuanhong Date: Fri, 22 Apr 2022 01:53:01 +0800 Subject: [PATCH] update --- README.md | 7 +- train.ipynb | 203 ++++++---------------------------------------------- 2 files changed, 25 insertions(+), 185 deletions(-) diff --git a/README.md b/README.md index 5decf87..7e7e44d 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,8 @@ *Our method can realize **arbitrary face swapping** on images and videos with **one single trained model**.* -Training and test code are now available! [Colab demo](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/train.ipynb) +Training and test code are now available! +[ google colab logo](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/train.ipynb) We are working with our incoming paper SimSwap++, keeping expecting! @@ -26,6 +27,8 @@ If you find this project useful, please star it. It is the greatest appreciation ## Top News +**`2022-04-21`**: For resource limited users, we provide the cropped VGGFace2-224 dataset [VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing). + **`2022-04-20`**: Training scripts are now available. We highly recommend that you guys train the simswap model with our released high quality dataset [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ). **`2021-11-24`**: We have trained a beta version of ***SimSwap-HQ*** on [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ) and open sourced the checkpoint of this model (if you think the Simswap 512 is cool, please star our [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ) repo). Please don’t forget to go to [Preparation](./docs/guidance/preparation.md) and [Inference for image or video face swapping](./docs/guidance/usage.md) to check the latest set up. @@ -65,7 +68,7 @@ Download the dataset from [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ). The training script is slightly different from the original version, e.g., we replace the patch discriminator with the projected discriminator, which saves a lot of hardware overhead and achieves slightly better results. In order to ensure normal training, the batch size must be greater than 1. -- Train 224 models with VGGFace2 224*224 [VGGFace2-224](https://github.com/NNNNAI/VGGFace2-HQ) +- Train 224 models with VGGFace2 224*224 [VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing) ``` python train.py --name simswap224_test --batchSize 4 --gpu_ids 0 --dataset /path/to/VGGFace2HQ --Gdeep False ``` diff --git a/train.ipynb b/train.ipynb index 97526b7..7cd8905 100644 --- a/train.ipynb +++ b/train.ipynb @@ -37,41 +37,10 @@ "!nvidia-smi" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "J8WrNaQHuUGC", - "outputId": "afffa0be-92b5-4133-b6d9-6c3e08c6de64" + "id": "J8WrNaQHuUGC" }, "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Thu Apr 21 16:07:35 2022 \n", - "+-----------------------------------------------------------------------------+\n", - "| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", - "|-------------------------------+----------------------+----------------------+\n", - "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", - "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", - "| | | MIG M. |\n", - "|===============================+======================+======================|\n", - "| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n", - "| N/A 67C P8 32W / 149W | 0MiB / 11441MiB | 0% Default |\n", - "| | | N/A |\n", - "+-------------------------------+----------------------+----------------------+\n", - " \n", - "+-----------------------------------------------------------------------------+\n", - "| Processes: |\n", - "| GPU GI CI PID Type Process name GPU Memory |\n", - "| ID ID Usage |\n", - "|=============================================================================|\n", - "| No running processes found |\n", - "+-----------------------------------------------------------------------------+\n" - ] - } - ] + "outputs": [] }, { "cell_type": "markdown", @@ -99,29 +68,10 @@ "!cd SimSwap && git pull" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9jZWwt97uvIe", - "outputId": "42a1bda8-3ca3-46af-fc82-d1af99ce15e1" + "id": "9jZWwt97uvIe" }, "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Cloning into 'SimSwap'...\n", - "remote: Enumerating objects: 1017, done.\u001b[K\n", - "remote: Counting objects: 100% (16/16), done.\u001b[K\n", - "remote: Compressing objects: 100% (13/13), done.\u001b[K\n", - "remote: Total 1017 (delta 5), reused 10 (delta 3), pack-reused 1001\u001b[K\n", - "Receiving objects: 100% (1017/1017), 210.79 MiB | 14.80 MiB/s, done.\n", - "Resolving deltas: 100% (510/510), done.\n", - "Already up to date.\n" - ] - } - ] + "outputs": [] }, { "cell_type": "markdown", @@ -140,32 +90,22 @@ "!wget -P SimSwap/arcface_model https://github.com/neuralchen/SimSwap/releases/download/1.0/arcface_checkpoint.tar" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "rwvbPhtOvZAL", - "outputId": "ffa12208-d388-412d-e83b-c54864c4526e" + "id": "rwvbPhtOvZAL" }, "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Requirement already satisfied: googledrivedownloader in /usr/local/lib/python3.7/dist-packages (0.4)\n", - "Requirement already satisfied: imageio==2.4.1 in /usr/local/lib/python3.7/dist-packages (2.4.1)\n", - "Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (7.1.2)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (1.21.6)\n" - ] - } - ] + "outputs": [] }, { "cell_type": "markdown", "source": [ "#Download the Training Dataset\n", "We employ the cropped VGGFace2-224 dataset for this toy training demo.\n", - "You can download the dataset from our google driver " + "\n", + "You can download the dataset from our google driver https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing\n", + "\n", + "***Please check the dataset in dir /content/TrainingData***\n", + "\n", + "***If dataset already exists in /content/TrainingData, please do not run blow scripts!***\n" ], "metadata": { "id": "hleVtHIJ_QUK" @@ -174,28 +114,16 @@ { "cell_type": "code", "source": [ - "from google_drive_downloader import GoogleDriveDownloader as gdd\n", - "gdd.download_file_from_google_drive(file_id='1iytA1n2z4go3uVCwE__vIKouTKyIDjEq',dest_path='/content/TrainingData/vggface2_crop_arcfacealign_224.tar',showsize=True)\n", - "!tar -xzvf /content/TrainingData/vggface2_crop_arcfacealign_224.tar" + "!wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc\" -O /content/TrainingData/vggface2_crop_arcfacealign_224.tar && rm -rf /tmp/cookies.txt\n", + "%%cd /content/\n", + "!tar -xzvf /content/TrainingData/vggface2_crop_arcfacealign_224.tar\n", + "!rm /content/TrainingData/vggface2_crop_arcfacealign_224.tar" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "gMVKEej59LX9", - "outputId": "2e508c44-d006-4183-81d9-f9753d08dea7" + "id": "h2tyjBl0Llxp" }, "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading 1iytA1n2z4go3uVCwE__vIKouTKyIDjEq into /content/TrainingData/mnist.zip... \n", - "0.0 B Done.\n" - ] - } - ] + "outputs": [] }, { "cell_type": "markdown", @@ -215,101 +143,10 @@ "!python train.py --name simswap224_test --gpu_ids 0 --dataset /content/TrainingData/vggface2_crop_arcfacealign_224 --Gdeep False" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "XCxHa4oW507s", - "outputId": "c84c52d9-0b36-4932-925d-1ae38a3f7bb0" + "id": "XCxHa4oW507s" }, "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "/content/SimSwap\n", - " arcface_model\t predict.py\n", - " cog.yaml\t README.md\n", - " crop_224\t 'SimSwap colab.ipynb'\n", - " data\t\t simswaplogo\n", - " demo_file\t test_one_image.py\n", - " docs\t\t test_video_swapmulti.py\n", - " download-weights.sh test_video_swap_multispecific.py\n", - " insightface_func test_video_swapsingle.py\n", - " LICENSE\t test_video_swapspecific.py\n", - " models\t\t test_wholeimage_swapmulti.py\n", - " MultiSpecific.ipynb test_wholeimage_swap_multispecific.py\n", - " options\t test_wholeimage_swapsingle.py\n", - " output\t\t test_wholeimage_swapspecific.py\n", - " parsing_model\t train.py\n", - " pg_modules\t util\n", - "------------ Options -------------\n", - "Arc_path: arcface_model/arcface_checkpoint.tar\n", - "Gdeep: False\n", - "batchSize: 2\n", - "beta1: 0.0\n", - "checkpoints_dir: ./checkpoints\n", - "continue_train: False\n", - "dataset: /path/to/VGGFace2\n", - "gpu_ids: 0\n", - "isTrain: True\n", - "lambda_feat: 10.0\n", - "lambda_id: 30.0\n", - "lambda_rec: 10.0\n", - "load_pretrain: checkpoints\n", - "log_frep: 200\n", - "lr: 0.0004\n", - "model_freq: 10000\n", - "name: simswap\n", - "niter: 10000\n", - "niter_decay: 10000\n", - "phase: train\n", - "sample_freq: 1000\n", - "tag: simswap\n", - "total_step: 1000000\n", - "train_simswap: True\n", - "use_tensorboard: False\n", - "which_epoch: 800000\n", - "-------------- End ----------------\n", - "GPU used : 0\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.PReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.AdaptiveAvgPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.Sigmoid' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm1d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth\" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_lite0-0aa007d2.pth\n", - "processing Swapping dataset images...\n", - "Finished preprocessing the Swapping dataset, total dirs number: 0...\n", - "Traceback (most recent call last):\n", - " File \"train.py\", line 163, in \n", - " train_loader = GetLoader(opt.dataset,opt.batchSize,8,1234)\n", - " File \"/content/SimSwap/data/data_loader_Swapping.py\", line 119, in GetLoader\n", - " drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)\n", - " File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py\", line 268, in __init__\n", - " sampler = RandomSampler(dataset, generator=generator)\n", - " File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/sampler.py\", line 103, in __init__\n", - " \"value, but got num_samples={}\".format(self.num_samples))\n", - "ValueError: num_samples should be a positive integer value, but got num_samples=0\n" - ] - } - ] + "outputs": [] } ] } \ No newline at end of file