mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Compare commits
456 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6547fcfe7b | |||
| 2e6394565a | |||
| 9f9f9dbad7 | |||
| dbe79aa3b9 | |||
| 2809a59704 | |||
| a86497177d | |||
| 338f49c3dc | |||
| 56b71048e3 | |||
| 7490ead302 | |||
| ecc37873bf | |||
| bd762c4c38 | |||
| f5c49a02cb | |||
| f4d4914f5c | |||
| 2f28fb664b | |||
| 35c250b0c9 | |||
| 580a179f44 | |||
| e846d88145 | |||
| e894e4172a | |||
| fc766b8327 | |||
| a06f5fd9e8 | |||
| ce7aaa57dc | |||
| fce54eb7db | |||
| 3e9c8a37e7 | |||
| 143b594ee6 | |||
| 47bebb02d7 | |||
| 94cbcb68f0 | |||
| a602bbd474 | |||
| 24f45877f5 | |||
| 0722db91f1 | |||
| 475b8b1538 | |||
| d68b77bd4d | |||
| 8806accbb3 | |||
| d4a8719870 | |||
| 1a41a941e2 | |||
| 5c855aae4e | |||
| 810df0f540 | |||
| 03011200e4 | |||
| 837ee1e18c | |||
| ef62a2ee9e | |||
| af455f5236 | |||
| 0b7db0cc27 | |||
| 140cad492a | |||
| d44ac98e38 | |||
| d990ce4575 | |||
| 982a94b535 | |||
| 5b41d8e91f | |||
| bcf5b4e5a8 | |||
| 128726701b | |||
| 4a319ec9bd | |||
| 39ce14b590 | |||
| 1477850a23 | |||
| f4c4066e8c | |||
| d9fe667ced | |||
| b7a6f00e8b | |||
| b215db68c3 | |||
| dc2b2dc982 | |||
| 76fe5c351c | |||
| 056bacb7de | |||
| dafada11bc | |||
| 4f5ac00a7b | |||
| 2e3c3517cb | |||
| 7845dd8522 | |||
| 4b851a173d | |||
| f99c73495c | |||
| cc6a99f305 | |||
| 9df29f8a22 | |||
| 80e600cbb5 | |||
| 117a9d0fc9 | |||
| d2be8a386a | |||
| 0743b99347 | |||
| 4f4057fc54 | |||
| 4ebdeee634 | |||
| 00d5c1f200 | |||
| 99a8527e24 | |||
| 602e890af2 | |||
| 9ede8a2a7d | |||
| c85c755e00 | |||
| c1f39a73dd | |||
| 583d09e666 | |||
| 9153b4ce8f | |||
| 35afb426b7 | |||
| b6a2734622 | |||
| 10b6f801d1 | |||
| 798ff87736 | |||
| df3e22fdf9 | |||
| 24f2e14a95 | |||
| c45fcbba84 | |||
| ad675ae633 | |||
| 803902c8bb | |||
| bc1b04a107 | |||
| 345a225c94 | |||
| eefc69a820 | |||
| 94571c5676 | |||
| 904a447e06 | |||
| 33d00ac941 | |||
| 5234874bc7 | |||
| b5efcbe44a | |||
| bdd7fd0d86 | |||
| 7c75b0d898 | |||
| 0d73bcf918 | |||
| 8f0ee4935b | |||
| 431df7cff8 | |||
| c0aaae9358 | |||
| e4d2d244a0 | |||
| 72591fbed1 | |||
| cf0bd93814 | |||
| 0732924f1e | |||
| bf7bbc2550 | |||
| f989df39e9 | |||
| ecc5b1a1d5 | |||
| 5550c89f43 | |||
| 113f9fa6e5 | |||
| 20aa5114e1 | |||
| 3c0a65c7a0 | |||
| 2af4f2c8ab | |||
| 8b465fce03 | |||
| d212e2fe12 | |||
| 569df9e96d | |||
| 738e00d59e | |||
| 564cc7b127 | |||
| 944096befc | |||
| 0991745753 | |||
| 70ac772a34 | |||
| 31303c1c6c | |||
| e85aa20602 | |||
| 5567b49a6d | |||
| 99931def84 | |||
| 4f67e045a0 | |||
| af09ee7ff3 | |||
| 09432d9214 | |||
| afab997ffc | |||
| e758eb3e19 | |||
| f90fd73b54 | |||
| 8f76d96bb2 | |||
| 66d1573f4b | |||
| a7a21cd684 | |||
| 52b98b5be5 | |||
| 0304b5dd91 | |||
| 2322b6539f | |||
| b6e131e4c1 | |||
| d809f66216 | |||
| 90cb6afe10 | |||
| 7f16d0a10e | |||
| 54923abf7f | |||
| e267f9ffd5 | |||
| 27440ef023 | |||
| c2a639229f | |||
| f9d105ea2b | |||
| d2efb2fd08 | |||
| 64464b6f1c | |||
| 606bf42089 | |||
| a3ac4d5ddd | |||
| 1659805b08 | |||
| 8f1f002c64 | |||
| 4af22832db | |||
| 9ff30a0268 | |||
| dbe931e950 | |||
| 862dce7bc6 | |||
| 8101b15e1c | |||
| 1dfd230fc5 | |||
| e5f983b2bf | |||
| 7e938c2ec9 | |||
| c7d55d0d17 | |||
| 6bc44ad3d8 | |||
| b829d5e42c | |||
| ab3b699124 | |||
| a5d99c139e | |||
| bfa9924b40 | |||
| a0c42bedbe | |||
| d215b6f98b | |||
| 01278d679f | |||
| a2a9b78dac | |||
| cf26f66e36 | |||
| f4a1e18ca9 | |||
| 7ab7efbbf4 | |||
| 6a5f81e5fe | |||
| 6a11603e7e | |||
| ff9b777b28 | |||
| 5bacb048dd | |||
| 847579f925 | |||
| 57aad5204e | |||
| d3b0051912 | |||
| d944d95bfd | |||
| 8f1f63f2ef | |||
| b59e172fa3 | |||
| 368da824aa | |||
| e61e470432 | |||
| c8953ce8a1 | |||
| e428ae04e3 | |||
| a89e51c2f8 | |||
| 61f48d9246 | |||
| dfd018a897 | |||
| b63562abad | |||
| e9ea9cd9e5 | |||
| abdc770892 | |||
| 7cf5609c1f | |||
| 6388727262 | |||
| aedaa20d78 | |||
| f9de4ce78a | |||
| 6d805438ad | |||
| 7cc893c32e | |||
| 3e69d5a9a9 | |||
| f791178ded | |||
| 866019d44f | |||
| 94ad33cb1e | |||
| de72e50233 | |||
| 5f053f9f69 | |||
| f678aa8f7e | |||
| 0e8207ccc8 | |||
| 4a2559c866 | |||
| 3c5554c1c5 | |||
| dbf5687bcd | |||
| 2148e9b701 | |||
| 64ebfa7b84 | |||
| daeec46e36 | |||
| 13d15029b7 | |||
| d5c51a90e8 | |||
| 6fa8d6b6eb | |||
| 72371b9f11 | |||
| 786adf73a2 | |||
| dcc5ccccd7 | |||
| 5056b8df75 | |||
| 430c71d031 | |||
| 176dced1f6 | |||
| 18a605e1a3 | |||
| fea75ff949 | |||
| df5895e266 | |||
| f2d3f8a19f | |||
| ceb3c0cfdf | |||
| dd0a2fe649 | |||
| e5a4a54e61 | |||
| 9dc1031fa5 | |||
| d93daf9e5f | |||
| 3a61da8bab | |||
| 83c20f8331 | |||
| dfd9e99aed | |||
| 2fb0b4289d | |||
| e6ea454360 | |||
| 80d1694e23 | |||
| 5e68de9170 | |||
| 6ca68f1408 | |||
| 43f99db5d7 | |||
| c9e70ebc18 | |||
| 1ab57b1d3f | |||
| cc9a0ba83e | |||
| 589568bfb5 | |||
| 34a7f3ef55 | |||
| c35b0a6f4c | |||
| 23ab7dc89d | |||
| 8443f3512a | |||
| d3a2035d7a | |||
| 16c8b32269 | |||
| 65ab796835 | |||
| 56be3f0b9b | |||
| a22adaf51f | |||
| 0055c0c97f | |||
| 35b779b1ed | |||
| ea1b0205f0 | |||
| a5eb7d6aa1 | |||
| b27b8663e5 | |||
| d87f6c0b15 | |||
| 5d1b90ff19 | |||
| 2ddcf52b66 | |||
| c8801ececd | |||
| 84b4451366 | |||
| cadbe9cf76 | |||
| 58a85a80bb | |||
| ab0a59fb74 | |||
| 578b07a7f4 | |||
| 7ce9d27097 | |||
| 8c24c9ec27 | |||
| 0ad2556c4c | |||
| 484a49c27d | |||
| e8cc2bfff1 | |||
| a951d700fc | |||
| 6eff69a41a | |||
| 5b3b2abdd7 | |||
| bfcbd6bf95 | |||
| 607c55ff1f | |||
| 93cbbf52d0 | |||
| f5cd6b6336 | |||
| 257e5e56a4 | |||
| f19908ccd6 | |||
| 18a2531b54 | |||
| 84be7d1ffb | |||
| 8b53c76a0a | |||
| 7d8cb146a4 | |||
| 303cbfa024 | |||
| bc174186eb | |||
| d7158749c2 | |||
| 8b2b6892aa | |||
| 5bba2a1c69 | |||
| 94480e16eb | |||
| bbcb1c35f0 | |||
| ee3fc40e83 | |||
| 335d597e53 | |||
| 14bbece850 | |||
| fad38da864 | |||
| e75a3c58f9 | |||
| 6fed877d33 | |||
| de6cfbc35b | |||
| ed0f6ae897 | |||
| 63e4bea3cd | |||
| 14b9bccafe | |||
| 579d3ef51c | |||
| a797548329 | |||
| 6eabcad1d0 | |||
| 7848d28b02 | |||
| 30e787129a | |||
| 38211f0340 | |||
| f2833a32c3 | |||
| 3b7d3b6688 | |||
| 086d9eed87 | |||
| 085c493e18 | |||
| 2220f5ef08 | |||
| f482d46798 | |||
| a6e1405c70 | |||
| ac41bab3a2 | |||
| 206a1411d1 | |||
| dc0abff0ce | |||
| 10ce04ed58 | |||
| 094d5cea9e | |||
| f6c59257d9 | |||
| 83ef075b1d | |||
| 575f215408 | |||
| d153c68813 | |||
| 00dccf07b9 | |||
| 5a6e3393e2 | |||
| c17378f3c7 | |||
| 84503761b9 | |||
| 723e9fde78 | |||
| 09e913233b | |||
| b4bbd862e2 | |||
| 4d8433f54a | |||
| db44c91dd8 | |||
| b33281425a | |||
| 3d9ff4add0 | |||
| 5934b47961 | |||
| 04eaa831ea | |||
| c1bed34c27 | |||
| 4078681031 | |||
| 15ee6fa763 | |||
| 251e610f0e | |||
| da51c5336d | |||
| 9bd68c3d14 | |||
| 0a50e2d706 | |||
| 40dcef7fc7 | |||
| c041073953 | |||
| 354315502b | |||
| 254bc17c98 | |||
| 39c0313202 | |||
| 3cf9711df0 | |||
| d1bf54276d | |||
| b47c6b72ee | |||
| dcf19634d1 | |||
| 857365770f | |||
| d25f2865a9 | |||
| 0d45568bd1 | |||
| ccf6fa7f43 | |||
| 28977d37d6 | |||
| c161da2f25 | |||
| 39818a16df | |||
| e1ba81f220 | |||
| bf696be097 | |||
| f63bc788ac | |||
| 11bb9065ba | |||
| 7b2b8f0f85 | |||
| 999f2c9cbe | |||
| 254f3efe68 | |||
| 5bb41ecbb2 | |||
| 777a8384c2 | |||
| 0949c1358b | |||
| d9e10a9f7c | |||
| cd4b10c832 | |||
| 0b7d25a36e | |||
| 026bcf0c97 | |||
| aa5094e576 | |||
| 8f8dfecdbc | |||
| 030d912c1b | |||
| 0e148845af | |||
| e1e0c11bb5 | |||
| 88c4e53192 | |||
| 650551c19b | |||
| a971506271 | |||
| b69f69d015 | |||
| 953525e6b0 | |||
| 58ad6af619 | |||
| 7264884ff9 | |||
| 1872f99584 | |||
| 29e82f909a | |||
| 257ab668ee | |||
| 62d897f9d8 | |||
| f05ff6cdb1 | |||
| 3fe32d7832 | |||
| 3be8368eaa | |||
| b6b4f9f65b | |||
| e33bc0d52a | |||
| f3409c5ade | |||
| c6d16c0cf6 | |||
| dfe7ab3b6f | |||
| 3c6dfa4efe | |||
| 34d0bc10ed | |||
| b785525f3b | |||
| dd320ea5be | |||
| 71e0ae34c0 | |||
| 32dfdcf1b3 | |||
| 494b84aecb | |||
| 860771e482 | |||
| b42d2b06e7 | |||
| 11c038cb81 | |||
| 746cc86d52 | |||
| 58c81cd646 | |||
| 4d2038d4ce | |||
| 67ad9badac | |||
| 1f4405be44 | |||
| 1b6e7a6ca5 | |||
| 62a69cddd2 | |||
| 611618e413 | |||
| 6381e755d7 | |||
| 2ed558a873 | |||
| b7e2d3ccd7 | |||
| 66af7d7957 | |||
| dc0ef53668 | |||
| a17d050648 | |||
| 989a81c751 | |||
| a1cd025f81 | |||
| 9c15f584aa | |||
| 5892460c3d | |||
| bb0e3b4a8a | |||
| fcb3390796 | |||
| 3b8b6442fc | |||
| cfcf0ee2bd | |||
| 4260bd6c28 | |||
| d018dd0633 | |||
| 798ba48a52 | |||
| e45f46d355 | |||
| 008a221f55 | |||
| 10826558b4 | |||
| 23ac63d55b | |||
| 2bbba3563b | |||
| e3273221c9 | |||
| 9e1c71b498 | |||
| ef313042c6 | |||
| a1fd382659 | |||
| 650268c06b | |||
| fe28b6fffe | |||
| e6c2a64256 | |||
| 7bef17b551 | |||
| 8e53c6bc9f | |||
| a461d9c389 | |||
| f3e1c3feaa | |||
| abce63007d | |||
| 5fa53dabf2 | |||
| bd4725b52f | |||
| ef38a0e23f | |||
| e1348dd596 |
@@ -1,7 +1,5 @@
|
||||
[flake8]
|
||||
select = E3, E4, F, I1, I2
|
||||
select = E22, E23, E24, E27, E3, E4, E7, F, I1, I2
|
||||
plugins = flake8-import-order
|
||||
application_import_names = arcface_converter
|
||||
application_import_names = crossface, hyperswap
|
||||
import-order-style = pycharm
|
||||
per-file-ignores = preparing.py:E402
|
||||
|
||||
|
||||
+1
-2
@@ -1,2 +1 @@
|
||||
github: henryruhs
|
||||
custom: [ buymeacoffee.com/henryruhs, paypal.me/henryruhs ]
|
||||
custom: [ buymeacoffee.com/facefusion, ko-fi.com/facefusion ]
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.3 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 1.0 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 1.3 MiB |
@@ -8,12 +8,24 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Python 3.10
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
- run: pip install flake8
|
||||
- run: pip install flake8-import-order
|
||||
- run: pip install mypy
|
||||
- run: flake8 arcface_converter
|
||||
- run: mypy arcface_converter
|
||||
- run: flake8 crossface hyperswap
|
||||
- run: mypy crossface hyperswap
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- run: pip install torch torchvision
|
||||
- run: pip install pytest
|
||||
- run: PYTHONPATH=/home/runner/work/facefusion-labs/facefusion-labs pytest
|
||||
|
||||
@@ -1,2 +1,11 @@
|
||||
__pycache__
|
||||
.assets
|
||||
.claude
|
||||
.datasets
|
||||
.idea
|
||||
.inputs
|
||||
.exports
|
||||
.logs
|
||||
.models
|
||||
.outputs
|
||||
.vscode
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
MIT license
|
||||
|
||||
Copyright (c) 2024 Henry Ruhs
|
||||
@@ -4,4 +4,3 @@ FaceFusion Labs
|
||||
> Industry leading face manipulation platform.
|
||||
|
||||
[Build Status](https://github.com/facefusion/facefusion-labs/actions?query=workflow:ci)
|
||||

|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
ArcFace Converter
|
||||
=================
|
||||
|
||||
> Convert face embeddings between various ArcFace models.
|
||||
|
||||
|
||||
Preview
|
||||
-------
|
||||
|
||||

|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
This example utilizes the MegaFace dataset to train an ArcFace Converter for SimSwap.
|
||||
|
||||
```
|
||||
[preparing.dataset]
|
||||
dataset_path = datasets/megaface/train.rec
|
||||
crop_size = 112
|
||||
process_limit = 650000
|
||||
|
||||
[preparing.model]
|
||||
source_path = models/arcface_w600k_r50.onnx
|
||||
target_path = models/arcface_simswap.onnx
|
||||
|
||||
[preparing.input]
|
||||
directory_path = inputs
|
||||
source_path = inputs/arcface_w600k_r50.npy
|
||||
target_path = inputs/arcface_simswap.npy
|
||||
|
||||
[training.loader]
|
||||
split_ratio = 0.8
|
||||
batch_size = 51200
|
||||
num_workers = 8
|
||||
|
||||
[training.trainer]
|
||||
max_epochs = 4096
|
||||
|
||||
[training.output]
|
||||
directory_path = outputs
|
||||
file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f}
|
||||
|
||||
[exporting]
|
||||
directory_path = exports
|
||||
source_path = outputs/last.ckpt
|
||||
target_path = exports/arcface_converter_simswap.onnx
|
||||
opset_version = 15
|
||||
|
||||
[execution]
|
||||
providers = CUDAExecutionProvider
|
||||
```
|
||||
|
||||
|
||||
Preparing
|
||||
---------
|
||||
|
||||
Prepare the face embedding pairs.
|
||||
|
||||
```
|
||||
python prepare.py
|
||||
```
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
|
||||
Train the ArcFace converter model.
|
||||
|
||||
```
|
||||
python train.py
|
||||
```
|
||||
|
||||
|
||||
Exporting
|
||||
---------
|
||||
|
||||
Export the model to ONNX.
|
||||
|
||||
```
|
||||
python export.py
|
||||
```
|
||||
@@ -1,22 +0,0 @@
|
||||
import configparser
|
||||
from os import makedirs
|
||||
|
||||
import torch
|
||||
|
||||
from .training import ArcFaceConverterTrainer
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def export() -> None:
    """Export the trained ArcFace converter checkpoint to an ONNX file.

    All paths and the opset version come from the ``exporting`` section of
    the module-level ``CONFIG``. The output directory is created when missing
    and the model is traced on the CPU with a random ``(1, 512)`` input.
    """
    section = 'exporting'
    output_directory_path = CONFIG.get(section, 'directory_path')
    checkpoint_path = CONFIG.get(section, 'source_path')
    onnx_path = CONFIG.get(section, 'target_path')
    onnx_opset_version = CONFIG.getint(section, 'opset_version')

    makedirs(output_directory_path, exist_ok = True)
    trainer_module = ArcFaceConverterTrainer.load_from_checkpoint(checkpoint_path, map_location = 'cpu')
    trainer_module.eval()
    # a single random embedding is enough to trace the graph
    dummy_embedding = torch.randn(1, 512)
    torch.onnx.export(
        trainer_module,
        dummy_embedding,
        onnx_path,
        input_names = [ 'input' ],
        output_names = [ 'output' ],
        opset_version = onnx_opset_version
    )
|
||||
@@ -1,21 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ArcFaceConverter(nn.Module):
    """Fully connected network that maps a 512-dimensional embedding to another 512-dimensional embedding.

    The input runs through four linear layers (512 -> 1024 -> 2048 -> 1024 -> 512)
    with a LeakyReLU after each of the first three layers.
    """

    def __init__(self) -> None:
        super(ArcFaceConverter, self).__init__()
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 1024)
        self.fc4 = nn.Linear(1024, 512)
        self.activation = nn.LeakyReLU()

    def forward(self, inputs : Tensor) -> Tensor:
        """Convert the given embeddings.

        :param inputs: tensor whose last axis holds the 512 embedding values
        :return: tensor of the same leading shape with 512 converted values
        """
        # Normalise each embedding on its own. A norm taken over the whole
        # tensor would make every row depend on the other rows of the batch.
        norm_inputs = inputs / torch.norm(inputs, dim = -1, keepdim = True)
        outputs = self.activation(self.fc1(norm_inputs))
        outputs = self.activation(self.fc2(outputs))
        outputs = self.activation(self.fc3(outputs))
        outputs = self.fc4(outputs)
        return outputs
|
||||
@@ -1,81 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import configparser
|
||||
from os import makedirs
|
||||
from os.path import isfile
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
numpy.bool = numpy.bool_
|
||||
from mxnet.io import ImageRecordIter
|
||||
from onnxruntime import InferenceSession
|
||||
from tqdm import tqdm
|
||||
|
||||
from .typing import Embedding, EmbeddingPairs, VisionFrame
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def prepare_crop_vision_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
    """Convert the frame to float32 and rescale its values from [0, 255] to [-1, 1].

    :param crop_vision_frame: array of pixel values
    :return: new float32 array, the input array is not modified
    """
    unit_frame = crop_vision_frame.astype(numpy.float32) / 255
    return (unit_frame - 0.5) * 2
|
||||
|
||||
|
||||
def create_inference_session(model_path : str, execution_providers : List[str]) -> InferenceSession:
    """Create an inference session for the model at ``model_path``.

    :param model_path: path handed to ``InferenceSession``
    :param execution_providers: provider names passed as ``providers``
    :return: the created session
    """
    return InferenceSession(model_path, providers = execution_providers)
|
||||
|
||||
|
||||
def forward(inference_session : InferenceSession, crop_vision_frame : VisionFrame) -> Embedding:
    """Run the session on the frame and return its first output.

    :param inference_session: session whose model takes an input named ``input``
    :param crop_vision_frame: value fed to the ``input`` input
    :return: first element of the outputs returned by ``run``
    """
    session_inputs =\
    {
        'input': crop_vision_frame
    }
    session_outputs = inference_session.run(None, session_inputs)
    return session_outputs[0]
|
||||
|
||||
|
||||
def process_embeddings(dataset_reader : ImageRecordIter, source_inference_session : InferenceSession, target_inference_session : InferenceSession) -> EmbeddingPairs:
    """Collect source and target embeddings for the frames of the dataset.

    Each batch of the reader is preprocessed once and sent through both
    sessions. Reading stops when the reader is exhausted or when the
    ``process_limit`` of the ``preparing.dataset`` config section is reached.

    :return: the collected pairs concatenated along axis 1 and transposed
    """
    process_limit = CONFIG.getint('preparing.dataset', 'process_limit')
    collected_pairs = []

    with tqdm(total = process_limit) as progress:
        for batch in dataset_reader:
            frame = prepare_crop_vision_frame(batch.data[0].asnumpy())
            source_embedding = forward(source_inference_session, frame)
            target_embedding = forward(target_inference_session, frame)
            collected_pairs.append([ source_embedding, target_embedding ])
            progress.update()

            # stop early once the configured number of batches is processed
            if progress.n == process_limit:
                break

    return numpy.concatenate(collected_pairs, axis = 1).T
|
||||
|
||||
|
||||
def prepare() -> None:
    """Create the source and target embedding files used for training.

    Paths, crop size and execution providers are read from the module-level
    ``CONFIG``. The input directory is always created. The embeddings are
    only computed and saved when the dataset file and both model files
    exist, otherwise the function returns without writing anything.
    """
    dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
    dataset_crop_size = CONFIG.getint('preparing.dataset', 'crop_size')
    model_source_path = CONFIG.get('preparing.model', 'source_path')
    model_target_path = CONFIG.get('preparing.model', 'target_path')
    input_directory_path = CONFIG.get('preparing.input', 'directory_path')
    input_source_path = CONFIG.get('preparing.input', 'source_path')
    input_target_path = CONFIG.get('preparing.input', 'target_path')
    # providers are given as one space separated string
    execution_providers = CONFIG.get('execution', 'providers').split(' ')

    makedirs(input_directory_path, exist_ok = True)

    required_paths = (dataset_path, model_source_path, model_target_path)
    if not all(isfile(required_path) for required_path in required_paths):
        return

    dataset_reader = ImageRecordIter(
        path_imgrec = dataset_path,
        data_shape = (3, dataset_crop_size, dataset_crop_size),
        batch_size = 1,
        shuffle = False
    )
    source_session = create_inference_session(model_source_path, execution_providers)
    target_session = create_inference_session(model_target_path, execution_providers)
    embedding_pairs = process_embeddings(dataset_reader, source_session, target_session)
    numpy.save(input_source_path, embedding_pairs[..., 0].T)
    numpy.save(input_target_path, embedding_pairs[..., 1].T)
|
||||
@@ -1,118 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import configparser
|
||||
from typing import Any, Tuple
|
||||
|
||||
import numpy
|
||||
import pytorch_lightning
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.tuner.tuning import Tuner
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
|
||||
|
||||
from .model import ArcFaceConverter
|
||||
from .typing import Batch, Loader
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class ArcFaceConverterTrainer(pytorch_lightning.LightningModule):
    """Lightning module that trains an ``ArcFaceConverter`` with a mean squared error loss."""

    def __init__(self) -> None:
        super(ArcFaceConverterTrainer, self).__init__()
        self.model = ArcFaceConverter()
        self.loss_fn = torch.nn.MSELoss()
        self.lr = 0.001

    def forward(self, source_embedding : Tensor) -> Tensor:
        """Convert the source embedding with the wrapped model."""
        return self.model(source_embedding)

    def _compute_and_log_loss(self, batch : Batch, log_name : str) -> Tensor:
        """Compute the loss of a (source, target) batch and log it under ``log_name``."""
        source, target = batch
        loss = self.loss_fn(self(source), target)
        self.log(log_name, loss, prog_bar = True, logger = True)
        return loss

    def training_step(self, batch : Batch, batch_index : int) -> Tensor:
        """Return the loss of a training batch, logged as ``train_loss``."""
        return self._compute_and_log_loss(batch, 'train_loss')

    def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
        """Return the loss of a validation batch, logged as ``val_loss``."""
        return self._compute_and_log_loss(batch, 'val_loss')

    def configure_optimizers(self) -> Any:
        """Return Adam together with a ReduceLROnPlateau scheduler monitoring ``train_loss`` once per epoch."""
        optimizer = torch.optim.Adam(self.parameters(), lr = self.lr)
        lr_scheduler =\
        {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
            'monitor': 'train_loss',
            'interval': 'epoch',
            'frequency': 1
        }

        return\
        {
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler
        }
|
||||
|
||||
|
||||
def create_loaders() -> Tuple[Loader, Loader]:
    """Create the training and validation loaders.

    Batch size and worker count come from the ``training.loader`` section of
    ``CONFIG``. Only the training loader shuffles its data.

    :return: tuple of (training loader, validation loader)
    """
    batch_size = CONFIG.getint('training.loader', 'batch_size')
    num_workers = CONFIG.getint('training.loader', 'num_workers')

    training_dataset, validate_dataset = split_dataset()
    training_loader = DataLoader(
        training_dataset,
        batch_size = batch_size,
        num_workers = num_workers,
        shuffle = True,
        pin_memory = True
    )
    validation_loader = DataLoader(
        validate_dataset,
        batch_size = batch_size,
        num_workers = num_workers,
        shuffle = False,
        pin_memory = True
    )
    return training_loader, validation_loader
|
||||
|
||||
|
||||
def split_dataset() -> Tuple[Dataset[Any], Dataset[Any]]:
    """Load the prepared embeddings and split them randomly into two datasets.

    The source and target arrays are loaded from the paths of the
    ``preparing.input`` config section and converted to float tensors. The
    ``split_ratio`` of ``training.loader`` sets the training share, the
    remaining samples go to the validation dataset.

    :return: tuple of (training dataset, validation dataset)
    """
    source_path = CONFIG.get('preparing.input', 'source_path')
    target_path = CONFIG.get('preparing.input', 'target_path')
    split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')

    source_tensor = torch.from_numpy(numpy.load(source_path)).float()
    target_tensor = torch.from_numpy(numpy.load(target_path)).float()
    dataset = TensorDataset(source_tensor, target_tensor)

    total_size = len(dataset)
    training_size = int(split_ratio * total_size)
    # the validation split takes whatever the truncated training size leaves over
    validation_size = int(total_size - training_size)
    training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
    return training_dataset, validate_dataset
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
    """Create the trainer with a checkpoint callback monitoring ``train_loss``.

    The epoch limit comes from ``training.trainer``, the checkpoint directory
    and file name pattern from ``training.output``.
    """
    max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
    checkpoint_directory_path = CONFIG.get('training.output', 'directory_path')
    checkpoint_file_pattern = CONFIG.get('training.output', 'file_pattern')

    model_checkpoint = ModelCheckpoint(
        monitor = 'train_loss',
        dirpath = checkpoint_directory_path,
        filename = checkpoint_file_pattern,
        every_n_epochs = 10,
        save_top_k = 3,
        save_last = True
    )
    return Trainer(
        max_epochs = max_epochs,
        callbacks = [ model_checkpoint ],
        enable_progress_bar = True,
        log_every_n_steps = 2
    )
|
||||
|
||||
|
||||
def train() -> None:
    """Run the learning rate finder and then fit the ArcFace converter."""
    trainer = create_trainer()
    training_loader, validation_loader = create_loaders()
    trainer_module = ArcFaceConverterTrainer()
    # the learning rate search runs on the same loaders before the actual fit
    Tuner(trainer).lr_find(trainer_module, training_loader, validation_loader)
    trainer.fit(trainer_module, training_loader, validation_loader)
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Any, Tuple
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
Batch = Tuple[Tensor, Tensor]
|
||||
Loader = DataLoader[Tuple[Tensor, ...]]
|
||||
|
||||
Embedding = NDArray[Any]
|
||||
EmbeddingPairs = NDArray[Any]
|
||||
FaceLandmark5 = NDArray[Any]
|
||||
VisionFrame = NDArray[Any]
|
||||
@@ -0,0 +1,3 @@
|
||||
OpenRAIL-MS license
|
||||
|
||||
Copyright (c) 2025 Henry Ruhs
|
||||
@@ -0,0 +1,104 @@
|
||||
CrossFace
|
||||
=========
|
||||
|
||||
> Seamless face embedding across various models.
|
||||
|
||||

|
||||
|
||||
|
||||
Preview
|
||||
-------
|
||||
|
||||

|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
Setup
|
||||
-----
|
||||
|
||||
This `config.ini` utilizes the MegaFace dataset to train the CrossFace model for SimSwap.
|
||||
|
||||
```
|
||||
[training.dataset]
|
||||
file_pattern = .datasets/megaface/**/*.jpg
|
||||
```
|
||||
|
||||
```
|
||||
[training.loader]
|
||||
batch_size = 128
|
||||
num_workers = 8
|
||||
split_ratio = 0.95
|
||||
```
|
||||
|
||||
```
|
||||
[training.model]
|
||||
source_path = .models/arcface_w600k_r50.pt
|
||||
target_path = .models/arcface_simswap.pt
|
||||
```
|
||||
|
||||
```
|
||||
[training.trainer]
|
||||
max_epochs = 4096
|
||||
strategy = auto
|
||||
precision = 16-mixed
|
||||
```
|
||||
|
||||
```
|
||||
[training.optimizer]
|
||||
learning_rate = 0.001
|
||||
```
|
||||
|
||||
```
|
||||
[training.logger]
|
||||
logger_path = .logs
|
||||
logger_name = crossface_simswap
|
||||
```
|
||||
|
||||
```
|
||||
[training.output]
|
||||
directory_path = .outputs
|
||||
file_pattern = crossface_simswap_{epoch}_{step}
|
||||
resume_path = .outputs/last.ckpt
|
||||
```
|
||||
|
||||
```
|
||||
[exporting]
|
||||
directory_path = .exports
|
||||
source_path = .outputs/last.ckpt
|
||||
target_path = .exports/crossface_simswap.onnx
|
||||
ir_version = 10
|
||||
opset_version = 15
|
||||
```
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
|
||||
Train the model.
|
||||
|
||||
```
|
||||
python train.py
|
||||
```
|
||||
|
||||
Launch TensorBoard to monitor the training.
|
||||
|
||||
```
|
||||
tensorboard --logdir .logs
|
||||
```
|
||||
|
||||
|
||||
Exporting
|
||||
---------
|
||||
|
||||
Export the model to ONNX.
|
||||
|
||||
```
|
||||
python export.py
|
||||
```
|
||||
@@ -1,34 +1,35 @@
|
||||
[preparing.dataset]
|
||||
dataset_path =
|
||||
crop_size =
|
||||
process_limit =
|
||||
|
||||
[preparing.model]
|
||||
source_path =
|
||||
target_path =
|
||||
|
||||
[preparing.input]
|
||||
directory_path =
|
||||
source_path =
|
||||
target_path =
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
|
||||
[training.loader]
|
||||
split_ratio =
|
||||
batch_size =
|
||||
num_workers =
|
||||
split_ratio =
|
||||
|
||||
[training.model]
|
||||
source_path =
|
||||
target_path =
|
||||
|
||||
[training.trainer]
|
||||
max_epochs =
|
||||
strategy =
|
||||
precision =
|
||||
|
||||
[training.optimizer]
|
||||
learning_rate =
|
||||
|
||||
[training.logger]
|
||||
logger_path =
|
||||
logger_name =
|
||||
|
||||
[training.output]
|
||||
directory_path =
|
||||
file_pattern =
|
||||
resume_path =
|
||||
|
||||
[exporting]
|
||||
directory_path =
|
||||
source_path =
|
||||
target_path =
|
||||
ir_version =
|
||||
opset_version =
|
||||
|
||||
[execution]
|
||||
providers =
|
||||
@@ -0,0 +1,34 @@
|
||||
import glob
|
||||
from configparser import ConfigParser
|
||||
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .types import Batch
|
||||
|
||||
|
||||
class StaticDataset(Dataset[Tensor]):
    """Dataset of image files matched by the ``file_pattern`` of the ``training.dataset`` config section."""

    def __init__(self, config_parser : ConfigParser) -> None:
        self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
        # recursive is required for '**' to span nested directories,
        # without it '**' only matches a single directory level
        self.file_paths = glob.glob(self.config_file_pattern, recursive = True)
        self.transforms = self.compose_transforms()

    def __getitem__(self, index : int) -> Batch:
        """Read the image at ``index`` and return it after applying the transforms."""
        file_path = self.file_paths[index]
        temp_tensor = io.read_image(file_path)
        return self.transforms(temp_tensor)

    def __len__(self) -> int:
        """Return the number of matched files."""
        return len(self.file_paths)

    @staticmethod
    def compose_transforms() -> transforms.Compose:
        """Compose the resize to 112x112, the color jitter and the normalization with mean and std 0.5."""
        return transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
|
||||
import torch
|
||||
|
||||
from .training import CrossFaceTrainer
|
||||
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
def export() -> None:
    """Export the trained CrossFace checkpoint to an ONNX file.

    Paths, IR version and opset version come from the ``exporting`` section
    of the module-level ``CONFIG_PARSER``. The output directory is created
    when missing and the model is traced on the CPU with a random
    ``(1, 512)`` input.
    """
    section = 'exporting'
    output_directory_path = CONFIG_PARSER.get(section, 'directory_path')
    checkpoint_path = CONFIG_PARSER.get(section, 'source_path')
    onnx_path = CONFIG_PARSER.get(section, 'target_path')
    ir_version = CONFIG_PARSER.getint(section, 'ir_version')
    opset_version = CONFIG_PARSER.getint(section, 'opset_version')

    os.makedirs(output_directory_path, exist_ok = True)
    model = CrossFaceTrainer.load_from_checkpoint(checkpoint_path, config_parser = CONFIG_PARSER, map_location = 'cpu')
    model.eval()
    model.ir_version = torch.tensor(ir_version)
    dummy_embedding = torch.randn(1, 512)
    torch.onnx.export(
        model,
        dummy_embedding,
        onnx_path,
        input_names = [ 'input' ],
        output_names = [ 'output' ],
        opset_version = opset_version
    )
|
||||
@@ -0,0 +1,37 @@
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class CrossFace(nn.Module):
    """Network that maps a 512-dimensional embedding to another 512-dimensional embedding.

    The output is the sum of a deep path (``sequence``) and a single linear
    layer (``linear``) scaled by 0.2, both fed with the L2 normalized input.
    """

    def __init__(self) -> None:
        super().__init__()
        self.sequence = self.create_sequence()
        self.linear = nn.Linear(512, 512)
        self.apply(init_weight)

    @staticmethod
    def create_sequence() -> nn.Sequential:
        """Build the deep path 512 -> 1024 -> 2048 -> 1024 -> 512.

        Every linear layer but the last one is followed by LayerNorm, GELU
        and Dropout(0.1).
        """
        layers = []

        for input_size, output_size in [ (512, 1024), (1024, 2048), (2048, 1024) ]:
            layers.append(nn.Linear(input_size, output_size))
            layers.append(nn.LayerNorm(output_size))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(0.1))

        layers.append(nn.Linear(1024, 512))
        return nn.Sequential(*layers)

    def forward(self, input_tensor : Tensor) -> Tensor:
        """Normalize the input along its last axis and return deep path plus 0.2 times the linear path."""
        normed_tensor = nn.functional.normalize(input_tensor, p = 2, dim = -1)
        deep_tensor = self.sequence(normed_tensor)
        linear_tensor = self.linear(normed_tensor)
        return deep_tensor + 0.2 * linear_tensor
|
||||
|
||||
|
||||
def init_weight(module : nn.Module) -> None:
    """Initialize a linear layer in place, other modules are left untouched.

    The weight is drawn with Xavier normal initialization and the bias is
    set to the constant 0.01.

    :param module: module to initialize, suitable for ``nn.Module.apply``
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_normal_(module.weight)

        # a layer created with bias = False has no bias tensor to fill
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.01)
|
||||
@@ -0,0 +1,145 @@
|
||||
import os
|
||||
import shutil
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
from typing import Tuple, cast
|
||||
|
||||
import torch
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import Dataset, random_split
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import StaticDataset
|
||||
from .models.crossface import CrossFace
|
||||
from .types import Batch, Embedding, OptimizerSet, TrainerPrecision, TrainerStrategy
|
||||
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
class CrossFaceTrainer(LightningModule):
    """Lightning module that trains ``CrossFace`` to map source embeddings onto target embeddings.

    Both embedders are TorchScript models loaded from the paths of the
    ``training.model`` config section. They are only called under
    ``torch.no_grad()``.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        super().__init__()
        self.config_source_path = config_parser.get('training.model', 'source_path')
        self.config_target_path = config_parser.get('training.model', 'target_path')
        self.config_learning_rate = config_parser.getfloat('training.optimizer', 'learning_rate')
        self.crossface = CrossFace()
        self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval()
        self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval()
        self.mse_loss = nn.MSELoss()

    def forward(self, source_embedding : Embedding) -> Embedding:
        """Convert the source embedding with the CrossFace model."""
        return self.crossface(source_embedding)

    def training_step(self, batch : Batch, batch_index : int) -> Tensor:
        """Return the mean squared error between converted and target embeddings, logged as ``training_loss``."""
        with torch.no_grad():
            source_embedding = self.source_embedder(batch)
            target_embedding = self.target_embedder(batch)
        output_embedding = self(source_embedding)
        training_loss = self.mse_loss(output_embedding, target_embedding)
        self.log('training_loss', training_loss, prog_bar = True)
        return training_loss

    def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
        """Return the mean cosine similarity between converted and target embeddings, rescaled to [0, 1].

        The value is logged as ``validation_score``.
        """
        with torch.no_grad():
            source_embedding = self.source_embedder(batch)
            target_embedding = self.target_embedder(batch)
        output_embedding = self(source_embedding)
        # Score against the target embedding, which is what the training loss
        # optimizes towards. Scoring against the source embedding would only
        # measure how little the converter changes its input.
        validation_score = (nn.functional.cosine_similarity(target_embedding, output_embedding).mean() + 1) * 0.5
        self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
        return validation_score

    def configure_optimizers(self) -> OptimizerSet:
        """Return AdamW together with a ReduceLROnPlateau scheduler monitoring ``training_loss`` once per epoch."""
        optimizer = torch.optim.AdamW(self.parameters(), lr = self.config_learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        optimizer_set =\
        {
            'optimizer': optimizer,
            'lr_scheduler':
            {
                'scheduler': scheduler,
                'monitor': 'training_loss',
                'interval': 'epoch',
                'frequency': 1
            }
        }

        return optimizer_set
|
||||
|
||||
|
||||
class ModelWithConfigCheckpoint(ModelCheckpoint):
    """
    Checkpoint callback that stores a copy of config.ini next to every checkpoint it saves.
    """

    def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None:
        """
        Save the checkpoint through the parent class, then copy config.ini to the checkpoint path with an .ini suffix.
        """
        super()._save_checkpoint(trainer, checkpoint_path)
        shutil.copy('config.ini', Path(checkpoint_path).with_suffix('.ini'))
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
    """
    Split the dataset and wrap both parts into stateful data loaders.

    Batch size and worker count are read from [training.loader]. The training loader
    shuffles and drops the last incomplete batch, the validation loader does neither.

    :param dataset: dataset to split into a training and a validation part
    :return: the training loader followed by the validation loader
    """
    config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
    config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
    # workers can only be kept alive when there are any, the loader rejects persistent_workers with num_workers = 0
    persistent_workers = config_num_workers > 0

    training_dataset, validate_dataset = split_dataset(dataset)
    training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = persistent_workers)
    validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = persistent_workers)
    return training_loader, validation_loader
|
||||
|
||||
|
||||
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
    """
    Randomly split the dataset into a training and a validation part.

    The training part receives int(size * split_ratio) entries, with split_ratio taken
    from [training.loader]; the validation part receives the remainder.
    """
    config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')

    total_size = len(dataset) # type:ignore[arg-type]
    training_size = int(total_size * config_split_ratio)
    split_sizes = [ training_size, int(total_size - training_size) ]
    training_dataset, validate_dataset = random_split(dataset, split_sizes)
    return training_dataset, validate_dataset
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
    """
    Build the Lightning trainer from the [training.trainer], [training.logger] and [training.output] settings.

    The trainer logs to TensorBoard and carries two callbacks: a checkpoint callback that
    also copies config.ini, and stochastic weight averaging.
    """
    config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
    config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy'))
    config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision'))
    config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
    config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
    config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
    config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')

    tensorboard_logger = TensorBoardLogger(config_logger_path, config_logger_name)
    checkpoint_callback = ModelWithConfigCheckpoint(
        monitor = 'training_loss',
        dirpath = config_directory_path,
        filename = config_file_pattern,
        every_n_epochs = 1000,
        save_top_k = 5,
        save_last = True
    )
    averaging_callback = StochasticWeightAveraging(swa_lrs = 1e-2)

    return Trainer(
        logger = tensorboard_logger,
        log_every_n_steps = 10,
        max_epochs = config_max_epochs,
        strategy = config_strategy,
        precision = config_precision,
        callbacks =
        [
            checkpoint_callback,
            averaging_callback
        ],
        val_check_interval = 1000
    )
|
||||
|
||||
|
||||
def train() -> None:
    """
    Run the CrossFace training.

    Creates the dataset, the loaders, the Lightning module and the trainer, then fits the
    model. Training resumes from [training.output] resume_path when that file exists.
    """
    config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')

    if torch.cuda.is_available():
        torch.set_float32_matmul_precision('high')

    dataset = StaticDataset(CONFIG_PARSER)
    training_loader, validation_loader = create_loaders(dataset)
    crossface_trainer = CrossFaceTrainer(CONFIG_PARSER)
    trainer = create_trainer()

    # a missing checkpoint means a fresh run, ckpt_path = None is what fit() uses by default
    resume_path = config_resume_path if os.path.exists(config_resume_path) else None
    trainer.fit(crossface_trainer, training_loader, validation_loader, ckpt_path = resume_path)
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
Batch : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true']
|
||||
TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16']
|
||||
@@ -0,0 +1,3 @@
|
||||
ResearchRAIL-MS license
|
||||
|
||||
Copyright (c) 2025 Henry Ruhs
|
||||
@@ -0,0 +1,189 @@
|
||||
HyperSwap
|
||||
=========
|
||||
|
||||
> Hyper accurate face swapping for everyone.
|
||||
|
||||

|
||||
|
||||
|
||||
Preview
|
||||
-------
|
||||
|
||||

|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
Setup
|
||||
-----
|
||||
|
||||
This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model.
|
||||
|
||||
```
|
||||
[training.dataset]
|
||||
file_pattern = .datasets/vggface2/**/*.jpg
|
||||
convert_template = vggfacehq_512_to_arcface_128
|
||||
multiplier = 1
|
||||
transform_size = 256
|
||||
usage_mode = both
|
||||
batch_mode = same
|
||||
batch_ratio = 0.2
|
||||
```
|
||||
|
||||
```
|
||||
[training.loader]
|
||||
batch_size = 8
|
||||
num_workers = 8
|
||||
split_ratio = 0.9995
|
||||
```
|
||||
|
||||
```
|
||||
[training.model]
|
||||
generator_embedder_path = .models/blendface.pt
|
||||
loss_embedder_path = .models/arcface.pt
|
||||
gazer_path = .models/gazer.pt
|
||||
face_masker_path = .models/face_masker.pt
|
||||
```
|
||||
|
||||
```
|
||||
[training.model.generator]
|
||||
source_channels = 512
|
||||
output_size = 256
|
||||
num_blocks = 2
|
||||
```
|
||||
|
||||
```
|
||||
[training.model.discriminator]
|
||||
input_channels = 3
|
||||
num_filters = 64
|
||||
num_layers = 5
|
||||
num_discriminators = 3
|
||||
kernel_size = 4
|
||||
```
|
||||
|
||||
```
|
||||
[training.model.masker]
|
||||
input_channels = 67
|
||||
output_channels = 1
|
||||
num_filters = 16
|
||||
```
|
||||
|
||||
```
|
||||
[training.losses]
|
||||
adversarial_weight = 1.0
|
||||
cycle_weight = 1.0
|
||||
feature_weight = 10.0
|
||||
reconstruction_weight = 10.0
|
||||
identity_weight = 20.0
|
||||
gaze_weight = 0.05
|
||||
mask_weight = 5.0
|
||||
```
|
||||
|
||||
```
|
||||
[training.trainer]
|
||||
accumulate_size = 4
|
||||
discriminator_ratio = 0.4
|
||||
gradient_clip = 20.0
|
||||
max_epochs = 50
|
||||
strategy = auto
|
||||
precision = 16-mixed
|
||||
sync_batchnorm = false
|
||||
preview_frequency = 100
|
||||
```
|
||||
|
||||
```
|
||||
[training.modifier]
|
||||
mask_factor = 0.01
|
||||
noise_factor = 0.05
|
||||
```
|
||||
|
||||
```
|
||||
[training.optimizer.generator]
|
||||
learning_rate = 0.0004
|
||||
momentum = 0.5
|
||||
scheduler_factor = 0.7
|
||||
scheduler_patience = 2000
|
||||
```
|
||||
|
||||
```
|
||||
[training.optimizer.discriminator]
|
||||
learning_rate = 0.0002
|
||||
momentum = 0.5
|
||||
scheduler_factor = 0.7
|
||||
scheduler_patience = 2000
|
||||
```
|
||||
|
||||
```
|
||||
[training.logger]
|
||||
logger_path = .logs
|
||||
logger_name = hyperswap
|
||||
```
|
||||
|
||||
```
|
||||
[training.output]
|
||||
directory_path = .outputs
|
||||
file_pattern = hyperswap_{epoch}_{step}
|
||||
resume_path = .outputs/last.ckpt
|
||||
```
|
||||
|
||||
```
|
||||
[exporting]
|
||||
directory_path = .exports
|
||||
source_path = .outputs/last.ckpt
|
||||
target_path = .exports/hyperswap_256.onnx
|
||||
target_size = 256
|
||||
ir_version = 10
|
||||
opset_version = 15
|
||||
precision = full
|
||||
```
|
||||
|
||||
```
|
||||
[inferencing]
|
||||
generator_path = .outputs/last.ckpt
|
||||
embedder_path = .models/arcface.pt
|
||||
source_path = .assets/source.jpg
|
||||
target_path = .assets/target.jpg
|
||||
output_path = .outputs/output.jpg
|
||||
```
|
||||
|
||||
|
||||
Training
|
||||
--------
|
||||
|
||||
Train the model.
|
||||
|
||||
```
|
||||
python train.py
|
||||
```
|
||||
|
||||
Launch TensorBoard to monitor the training.
|
||||
|
||||
```
|
||||
tensorboard --logdir .logs
|
||||
```
|
||||
|
||||
|
||||
Exporting
|
||||
---------
|
||||
|
||||
Export the model to ONNX.
|
||||
|
||||
```
|
||||
python export.py
|
||||
```
|
||||
|
||||
|
||||
Inferencing
|
||||
-----------
|
||||
|
||||
Run inference with the model.
|
||||
|
||||
```
|
||||
python infer.py
|
||||
```
|
||||
@@ -0,0 +1,96 @@
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
convert_template =
|
||||
multiplier =
|
||||
transform_size =
|
||||
usage_mode =
|
||||
batch_mode =
|
||||
batch_ratio =
|
||||
|
||||
[training.loader]
|
||||
batch_size =
|
||||
num_workers =
|
||||
split_ratio =
|
||||
|
||||
[training.model]
|
||||
generator_embedder_path =
|
||||
loss_embedder_path =
|
||||
gazer_path =
|
||||
face_masker_path =
|
||||
|
||||
[training.model.generator]
|
||||
source_channels =
|
||||
output_size =
|
||||
num_blocks =
|
||||
|
||||
[training.model.discriminator]
|
||||
input_channels =
|
||||
num_filters =
|
||||
num_layers =
|
||||
num_discriminators =
|
||||
kernel_size =
|
||||
|
||||
[training.model.masker]
|
||||
input_channels =
|
||||
output_channels =
|
||||
num_filters =
|
||||
|
||||
[training.losses]
|
||||
adversarial_weight =
|
||||
cycle_weight =
|
||||
feature_weight =
|
||||
reconstruction_weight =
|
||||
identity_weight =
|
||||
gaze_weight =
|
||||
mask_weight =
|
||||
|
||||
[training.trainer]
|
||||
accumulate_size =
|
||||
discriminator_ratio =
|
||||
gradient_clip =
|
||||
max_epochs =
|
||||
strategy =
|
||||
precision =
|
||||
sync_batchnorm =
|
||||
preview_frequency =
|
||||
|
||||
[training.modifier]
|
||||
mask_factor =
|
||||
noise_factor =
|
||||
|
||||
[training.optimizer.generator]
|
||||
learning_rate =
|
||||
momentum =
|
||||
scheduler_factor =
|
||||
scheduler_patience =
|
||||
|
||||
[training.optimizer.discriminator]
|
||||
learning_rate =
|
||||
momentum =
|
||||
scheduler_factor =
|
||||
scheduler_patience =
|
||||
|
||||
[training.logger]
|
||||
logger_path =
|
||||
logger_name =
|
||||
|
||||
[training.output]
|
||||
directory_path =
|
||||
file_pattern =
|
||||
resume_path =
|
||||
|
||||
[exporting]
|
||||
directory_path =
|
||||
source_path =
|
||||
target_path =
|
||||
target_size =
|
||||
ir_version =
|
||||
opset_version =
|
||||
precision =
|
||||
|
||||
[inferencing]
|
||||
generator_path =
|
||||
embedder_path =
|
||||
source_path =
|
||||
target_path =
|
||||
output_path =
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from src.preparing import prepare
|
||||
from src.exporting import export
|
||||
|
||||
if __name__ == '__main__':
|
||||
prepare()
|
||||
export()
|
||||
@@ -0,0 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from src.inferencing import infer
|
||||
|
||||
if __name__ == '__main__':
|
||||
infer()
|
||||
@@ -0,0 +1,154 @@
|
||||
import os
|
||||
import random
|
||||
from configparser import ConfigParser
|
||||
from typing import cast
|
||||
|
||||
import albumentations
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import convert_tensor, resolve_static_file_pattern
|
||||
from .types import Batch, BatchMode, ConvertTemplate, UsageMode
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
    """
    Dataset that yields a source and a target tensor for every file matched by the file pattern of [training.dataset.current].

    How the second image of a pair is chosen depends on the configured batch ratio, batch mode and usage mode.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        """
        Read the settings of [training.dataset.current] and compose the image transforms.
        """
        self.config_file_pattern = config_parser.get('training.dataset.current', 'file_pattern')
        self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset.current', 'convert_template'))
        self.config_transform_size = config_parser.getint('training.dataset.current', 'transform_size')
        self.config_usage_mode = cast(UsageMode, config_parser.get('training.dataset.current', 'usage_mode'))
        self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset.current', 'batch_mode'))
        self.config_batch_ratio = config_parser.getfloat('training.dataset.current', 'batch_ratio')
        self.config_parser = config_parser
        self.transforms = self.compose_transforms()

    def __getitem__(self, index : int) -> Batch:
        """
        Build the pair for the file at the given index of the resolved file list.

        With a probability of batch_ratio the batch mode ('equal' or 'same') decides the
        pairing, otherwise the usage mode ('source', 'target' or anything else) does.
        """
        file_path = resolve_static_file_pattern(self.config_file_pattern)[index]

        if random.random() < self.config_batch_ratio:
            if self.config_batch_mode == 'equal':
                return self.prepare_equal_batch(file_path)
            if self.config_batch_mode == 'same':
                return self.prepare_same_batch(file_path)

        if self.config_usage_mode == 'source':
            return self.prepare_source_batch(file_path)

        if self.config_usage_mode == 'target':
            return self.prepare_target_batch(file_path)

        return self.prepare_different_batch(file_path)

    def __len__(self) -> int:
        """
        Return the number of files matched by the file pattern.
        """
        return len(resolve_static_file_pattern(self.config_file_pattern))

    def prepare_equal_batch(self, source_path : str) -> Batch:
        """
        Pair the file with itself.
        """
        return self.create_batch(source_path, source_path, self.config_convert_template, self.config_convert_template)

    def prepare_same_batch(self, source_path : str) -> Batch:
        """
        Pair the file with a random entry of the directory it is stored in.
        """
        target_directory_path = os.path.dirname(source_path)
        target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
        target_path = os.path.join(target_directory_path, target_file_name_and_extension)
        return self.create_batch(source_path, target_path, self.config_convert_template, self.config_convert_template)

    def prepare_source_batch(self, source_path : str) -> Batch:
        """
        Use the file as source and draw the target from a dataset whose usage mode is 'both'.
        """
        target_path, target_convert_template = self._choose_partner()
        return self.create_batch(source_path, target_path, self.config_convert_template, target_convert_template)

    def prepare_target_batch(self, target_path : str) -> Batch:
        """
        Use the file as target and draw the source from a dataset whose usage mode is 'both'.
        """
        source_path, source_convert_template = self._choose_partner()
        return self.create_batch(source_path, target_path, source_convert_template, self.config_convert_template)

    def prepare_different_batch(self, source_path : str) -> Batch:
        """
        Pair the file with a random file of the current dataset.
        """
        target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern))
        return self.create_batch(source_path, target_path, self.config_convert_template, self.config_convert_template)

    def compose_transforms(self) -> transforms.Compose:
        """
        Compose the transforms applied to every image: augmentation, bicubic resize to transform_size, tensor conversion and normalisation with mean 0.5 and std 0.5 per channel.
        """
        return transforms.Compose(
        [
            AugmentTransform(),
            transforms.ToPILImage(),
            transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def filter_config_by_usage_mode(self, usage_mode : UsageMode) -> ConfigParser:
        """
        Return a new parser holding only the 'training.dataset*' sections whose usage_mode equals the given one.
        """
        config_parser = ConfigParser()

        for config_section in self.config_parser.sections():

            if config_section.startswith('training.dataset'):
                current_usage_mode = cast(UsageMode, self.config_parser.get(config_section, 'usage_mode'))

                if current_usage_mode == usage_mode:
                    config_parser.add_section(config_section)

                    for key, value in self.config_parser.items(config_section):
                        config_parser.set(config_section, key, value)

        return config_parser

    def create_batch(self, source_path : str, target_path : str, source_convert_template : ConvertTemplate, target_convert_template : ConvertTemplate) -> Batch:
        """
        Read, transform and optionally convert both images and return them as (source_tensor, target_tensor).
        """
        source_tensor = io.read_image(source_path)
        source_tensor = self.transforms(source_tensor)
        source_tensor = self.conditional_convert_tensor(source_tensor, source_convert_template)
        target_tensor = io.read_image(target_path)
        target_tensor = self.transforms(target_tensor)
        target_tensor = self.conditional_convert_tensor(target_tensor, target_convert_template)
        return source_tensor, target_tensor

    @staticmethod
    def conditional_convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
        """
        Apply the convert template when one is set, otherwise return the tensor unchanged.
        """
        if convert_template:
            # convert_tensor is called with a leading batch dimension, which is removed again afterwards
            temp_tensor = input_tensor.unsqueeze(0)
            return convert_tensor(temp_tensor, convert_template).squeeze(0)
        return input_tensor

    def _choose_partner(self) -> tuple[str, ConvertTemplate]:
        """
        Pick a random file, together with the convert template of its section, from a random dataset section whose usage mode is 'both'.
        """
        config_parser = self.filter_config_by_usage_mode('both')
        config_section = random.choice(config_parser.sections())
        config_file_pattern = config_parser.get(config_section, 'file_pattern')
        config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
        partner_path = random.choice(resolve_static_file_pattern(config_file_pattern))
        return partner_path, config_convert_template
|
||||
|
||||
|
||||
class AugmentTransform:
    """
    Callable that runs a fixed albumentations pipeline on an image tensor.
    """

    def __init__(self) -> None:
        self.transforms = self.compose_transforms()

    def __call__(self, input_tensor : Tensor) -> Tensor:
        """
        Augment the tensor and return the 'image' entry of the albumentations result.

        NOTE(review): the returned value is presumably a numpy array rather than a Tensor - confirm.
        """
        # albumentations receives the image as numpy array with the first axis moved to the end
        image = input_tensor.numpy().transpose(1, 2, 0)
        augmented = self.transforms(image = image)
        return augmented.get('image')

    @staticmethod
    def compose_transforms() -> albumentations.Compose:
        """
        Compose the pipeline: horizontal flip, one optional blur, one optional lighting change, one optional color change and a slight affine transform.
        """
        horizontal_flip = albumentations.HorizontalFlip(p = 0.5)
        blur = albumentations.OneOf(
        [
            albumentations.MotionBlur(),
            albumentations.ZoomBlur(max_factor = (1.0, 1.2))
        ], p = 0.1)
        lighting = albumentations.OneOf(
        [
            albumentations.RandomGamma(),
            albumentations.RandomBrightnessContrast(),
            albumentations.Illumination()
        ], p = 0.2)
        color = albumentations.OneOf(
        [
            albumentations.ColorJitter(),
            albumentations.RGBShift(),
            albumentations.HueSaturationValue()
        ], p = 0.2)
        # border_mode 1 is presumably OpenCV's BORDER_REPLICATE - TODO confirm
        affine = albumentations.Affine(
            translate_percent = (-0.05, 0.05),
            scale = (0.95, 1.05),
            rotate = (-2, 2),
            border_mode = 1,
            p = 0.2
        )

        return albumentations.Compose(
        [
            horizontal_flip,
            blur,
            lighting,
            color,
            affine
        ])
|
||||
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .training import HyperSwapTrainer
|
||||
from .types import Embedding, Mask, Module
|
||||
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
class HalfPrecision(nn.Module):
    """
    Wrapper that runs a model in half precision while returning float outputs.
    """

    def __init__(self, model : Module) -> None:
        """
        Convert the wrapped model to half precision.
        """
        super().__init__()
        self.model = model.half()

    def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
        """
        Cast both inputs to half, run the model and cast the output tensor and mask back to float.
        """
        output_tensor, output_mask = self.model(source_embedding.half(), target_tensor.half())
        return output_tensor.float(), output_mask.float()
|
||||
|
||||
|
||||
def export() -> None:
    """
    Export the checkpoint configured in [exporting] to ONNX.

    The model is loaded on the CPU, optionally wrapped for half precision and traced with
    a random (1, 512) source and a random (1, 3, target_size, target_size) target.
    """
    config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
    config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
    config_target_path = CONFIG_PARSER.get('exporting', 'target_path')
    config_target_size = CONFIG_PARSER.getint('exporting', 'target_size')
    config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
    config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
    config_precision = CONFIG_PARSER.get('exporting', 'precision')

    os.makedirs(config_directory_path, exist_ok = True)
    model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()

    if config_precision == 'half':
        model = HalfPrecision(model).eval()

    model.ir_version = torch.tensor(config_ir_version)
    source_tensor = torch.randn(1, 512)
    target_tensor = torch.randn(1, 3, config_target_size, config_target_size)
    export_inputs = (source_tensor, target_tensor)
    torch.onnx.export(
        model,
        export_inputs,
        config_target_path,
        input_names = [ 'source', 'target' ],
        output_names = [ 'output', 'mask' ],
        opset_version = config_opset_version
    )
|
||||
@@ -0,0 +1,82 @@
|
||||
import glob
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import ConvertTemplate, ConvertTemplateSet, EmbedderModule, Embedding, Mask, Padding
|
||||
|
||||
CONVERT_TEMPLATE_SET : ConvertTemplateSet =\
|
||||
{
|
||||
'arcface_128_to_arcface_112_v2': torch.tensor(
|
||||
[
|
||||
[ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ],
|
||||
[ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ]
|
||||
]),
|
||||
'ffhq_512_to_arcface_128': torch.tensor(
|
||||
[
|
||||
[ 8.50048894e-01, -1.29486822e-04, 1.90956388e-03 ],
|
||||
[ 1.29486822e-04, 8.50048894e-01, 9.56254653e-02 ]
|
||||
]),
|
||||
'vggfacehq_512_to_arcface_128': torch.tensor(
|
||||
[
|
||||
[ 1.01305414, -0.00140513, -0.00585911 ],
|
||||
[ 0.00140513, 1.01305414, 0.11169602 ]
|
||||
])
|
||||
}
|
||||
|
||||
|
||||
def convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
    """
    Warp a batch with the affine matrix registered for the convert template.

    :param input_tensor: batch to warp, its first dimension is used as batch size
    :param convert_template: key of CONVERT_TEMPLATE_SET
    :return: the warped batch, sampled with reflection padding
    :raises ValueError: when the convert template is not part of CONVERT_TEMPLATE_SET
    """
    if convert_template not in CONVERT_TEMPLATE_SET:
        raise ValueError('unknown convert template: ' + str(convert_template))

    convert_matrix = CONVERT_TEMPLATE_SET[convert_template].repeat(input_tensor.shape[0], 1, 1)
    # align_corners = False is what both functions fall back to, spelling it out avoids the warning they emit otherwise
    affine_grid = nn.functional.affine_grid(convert_matrix.to(input_tensor.device), list(input_tensor.shape), align_corners = False)
    output_tensor = nn.functional.grid_sample(input_tensor, affine_grid, padding_mode = 'reflection', align_corners = False)
    return output_tensor
|
||||
|
||||
|
||||
def calculate_face_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
    """
    Crop the batch for the embedder, blank the padded borders and return the L2 normalised embedding.

    The four padding values blank, in order, the leading and trailing entries of
    dimension 2 and the leading and trailing entries of dimension 3 of the 112 pixel crop.
    """
    crop_size = 112
    crop_tensor = convert_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
    crop_tensor = nn.functional.interpolate(crop_tensor, size = crop_size, mode = 'area')

    crop_tensor[:, :, :padding[0], :] = 0
    crop_tensor[:, :, crop_size - padding[1]:, :] = 0
    crop_tensor[:, :, :, :padding[2]] = 0
    crop_tensor[:, :, :, crop_size - padding[3]:] = 0

    return nn.functional.normalize(embedder(crop_tensor), p = 2)
|
||||
|
||||
|
||||
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
    """
    Blend a solid color over the tensor wherever the mask is set.

    The overlay is 1 in channel 2 and 0 in every other channel. The mask is repeated
    to three channels and limited to [0, 0.8] before blending.
    """
    overlay_tensor = input_tensor.new_zeros(input_tensor.shape)
    overlay_tensor[:, 2] = 1
    blend_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
    return input_tensor * (1 - blend_mask) + overlay_tensor * blend_mask
|
||||
|
||||
|
||||
def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
    """
    Grow the mask with a max pooling window.

    The window radius is the size of dimension 2 multiplied by the factor, rounded to the
    nearest integer. Replicate padding keeps the output size equal to the input size.
    """
    radius = int(input_tensor.shape[2] * factor + 0.5)
    padded_tensor = nn.functional.pad(input_tensor, (radius,) * 4, mode = 'replicate')
    return nn.functional.max_pool2d(padded_tensor, kernel_size = 2 * radius + 1, stride = 1, padding = 0)
|
||||
|
||||
|
||||
def erode_mask(input_tensor : Tensor, factor : float) -> Tensor:
    """
    Shrink the mask by max pooling its inverse and inverting the result again.

    The window radius is the size of dimension 2 multiplied by the factor, rounded to the
    nearest integer. Replicate padding keeps the output size equal to the input size.
    """
    radius = int(input_tensor.shape[2] * factor + 0.5)
    inverted_tensor = 1 - nn.functional.pad(input_tensor, (radius,) * 4, mode = 'replicate')
    pooled_tensor = nn.functional.max_pool2d(inverted_tensor, kernel_size = 2 * radius + 1, stride = 1, padding = 0)
    return 1 - pooled_tensor
|
||||
|
||||
|
||||
def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
    """
    Add standard normal noise, scaled by the factor, to the tensor.
    """
    return input_tensor + torch.randn_like(input_tensor) * factor
|
||||
|
||||
|
||||
@lru_cache(maxsize = None)
def resolve_static_file_pattern(file_pattern : str) -> List[str]:
    """
    Return the sorted paths matching the glob pattern.

    The result is cached per pattern, so every caller receives the same list object and
    files created later are not picked up.
    """
    file_paths = glob.glob(file_pattern)
    file_paths.sort()
    return file_paths
|
||||
@@ -0,0 +1,27 @@
|
||||
import configparser
|
||||
|
||||
import torch
|
||||
from torchvision import io
|
||||
|
||||
from .helper import calculate_face_embedding
|
||||
from .training import HyperSwapTrainer
|
||||
|
||||
CONFIG_PARSER = configparser.ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
def infer() -> None:
    """
    Swap the face of the configured source image into the configured target image and write the result as JPEG.

    All paths are read from the [inferencing] section. Both images are turned into float
    batches of one before they reach the embedder and the generator, and the generated
    batch is turned back into a uint8 image before it is written.

    NOTE(review): the images are neither aligned nor resized here, they are presumably expected
    to already match the crop and the size the generator was trained with - confirm.
    """
    config_generator_path = CONFIG_PARSER.get('inferencing', 'generator_path')
    config_embedder_path = CONFIG_PARSER.get('inferencing', 'embedder_path')
    config_source_path = CONFIG_PARSER.get('inferencing', 'source_path')
    config_target_path = CONFIG_PARSER.get('inferencing', 'target_path')
    config_output_path = CONFIG_PARSER.get('inferencing', 'output_path')

    generator = HyperSwapTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
    embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval()

    source_tensor = _read_image_batch(config_source_path)
    target_tensor = _read_image_batch(config_target_path)

    # inference only, no gradients are needed
    with torch.no_grad():
        source_embedding = calculate_face_embedding(embedder, source_tensor, (0, 0, 0, 0))
        output_tensor, _ = generator(source_embedding, target_tensor)

    io.write_jpeg(_create_image(output_tensor), config_output_path)


def _read_image_batch(file_path : str) -> torch.Tensor:
    """
    Read an image as float batch of one with values scaled from [0, 255] to [-1, 1].
    """
    # read_image returns a uint8 tensor without batch dimension;
    # mean 0.5 and std 0.5 mirror the normalisation of the training transforms
    image_tensor = io.read_image(file_path).float().div(255)
    return image_tensor.sub(0.5).div(0.5).unsqueeze(0)


def _create_image(batch_tensor : torch.Tensor) -> torch.Tensor:
    """
    Turn the first entry of a batch with values in [-1, 1] into a uint8 image.
    """
    image_tensor = batch_tensor[0].mul(0.5).add(0.5).clamp(0, 1)
    return image_tensor.mul(255).round().to(torch.uint8)
|
||||
@@ -0,0 +1,35 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.nld import NLD
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
    """
    Set of NLD discriminators, each one applied to a further downsampled copy of the input.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        """
        Read the number of discriminators from [training.model.discriminator] and create them together with the pooling layer.
        """
        super().__init__()
        self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
        self.config_parser = config_parser
        self.discriminators = self.create_discriminators()
        self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)

    def create_discriminators(self) -> nn.ModuleList:
        """
        Create one NLD network per configured discriminator and keep its sequences.
        """
        return nn.ModuleList([ NLD(self.config_parser).sequences for _ in range(self.config_num_discriminators) ])

    def forward(self, input_tensor : Tensor) -> List[Tensor]:
        """
        Return the output of every discriminator, the input is average pooled after each one.
        """
        output_tensors = []
        scaled_tensor = input_tensor

        for discriminator in self.discriminators:
            output_tensors.append(discriminator(scaled_tensor))
            scaled_tensor = self.avg_pool(scaled_tensor)

        return output_tensors
|
||||
@@ -0,0 +1,42 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.aad import AAD
|
||||
from ..networks.masknet import MaskNet
|
||||
from ..networks.unet import UNet
|
||||
from ..types import Embedding, Feature, Mask
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """
    Generator made of a UNet encoder, an AAD generator and a MaskNet masker.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        """
        Create the three networks and initialise their weights with init_weight.
        """
        super().__init__()
        self.encoder = UNet(config_parser)
        self.generator = AAD(config_parser)
        self.masker = MaskNet(config_parser)

        for network in (self.encoder, self.generator, self.masker):
            network.apply(init_weight)

    def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask]:
        """
        Generate from the source embedding and the target features, then blend the result into the target with the predicted mask.

        The mask is predicted from the target tensor and the last target feature.
        """
        generated_tensor = self.generator(source_embedding, target_features)
        output_mask = self.masker(target_tensor, target_features[-1])
        output_tensor = generated_tensor * output_mask + target_tensor * (1 - output_mask)
        return output_tensor, output_mask

    def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]:
        """
        Run the encoder on the tensor.
        """
        return self.encoder(input_tensor)
|
||||
|
||||
|
||||
def init_weight(module : nn.Module) -> None:
    """
    Initialise the weights of a single module, meant to be passed to nn.Module.apply().

    Linear layers get normally distributed weights with a std of 0.001 and a zeroed bias,
    convolutions and transposed convolutions get Xavier normal weights. Every other module
    is left untouched.
    """
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(std = 0.001)

        # layers created with bias = False carry no bias tensor
        if module.bias is not None:
            module.bias.data.zero_()

    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_normal_(module.weight.data)
|
||||
@@ -0,0 +1,192 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calculate_face_embedding, dilate_mask
|
||||
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
    """
    Hinge loss for the discriminator, averaged over all discriminator outputs.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, discriminator_real_tensors : List[Tensor], discriminator_fake_tensors : List[Tensor]) -> Loss:
        """
        Return the mean of the real term relu(1 - output) and the fake term relu(output + 1).

        Every output is first averaged over dimensions 1 to 3.
        """
        positive_loss = torch.stack([ torch.relu(1 - real_tensor).mean(dim = [ 1, 2, 3 ]) for real_tensor in discriminator_real_tensors ]).mean()
        negative_loss = torch.stack([ torch.relu(fake_tensor + 1).mean(dim = [ 1, 2, 3 ]) for fake_tensor in discriminator_fake_tensors ]).mean()
        return (positive_loss + negative_loss) * 0.5
|
||||
|
||||
|
||||
class AdversarialLoss(nn.Module):
    """
    Hinge based adversarial loss for the generator.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        """
        Read the adversarial weight from [training.losses].
        """
        super().__init__()
        self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight')

    def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Loss, Loss]:
        """
        Return the mean of relu(1 - output) over all discriminator outputs, plain and multiplied by the adversarial weight.
        """
        hinge_tensors = [ torch.relu(1 - output_tensor).mean(dim = [ 1, 2, 3 ]).mean() for output_tensor in discriminator_output_tensors ]
        adversarial_loss = torch.stack(hinge_tensors).mean()
        return adversarial_loss, adversarial_loss * self.config_adversarial_weight
|
||||
|
||||
|
||||
class CycleLoss(nn.Module):
    """
    Cycle loss that combines a feature term and a reconstruction term.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        """
        Read the batch size from [training.loader] and the cycle weight from [training.losses].
        """
        super().__init__()
        # kept as attribute for backward compatibility, forward() relies on the actual batch size
        self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
        self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
        self.l1_loss = nn.L1Loss()

    def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
        """
        Return the cycle loss, plain and multiplied by the cycle weight.

        The feature term is the mean squared difference of the paired features, the
        reconstruction term is the L1 distance between the target and the cycle tensor.
        Both terms are averaged.
        """
        temp_tensors = []

        for target_feature, cycle_feature in zip(target_features, cycle_features):
            # reshape by the size of the incoming batch, the configured size does not fit a smaller batch
            squared_difference = torch.pow(cycle_feature - target_feature, 2).reshape(cycle_feature.shape[0], -1)
            temp_tensors.append(torch.mean(squared_difference, dim = 1).mean())

        feature_loss = torch.stack(temp_tensors).mean()
        reconstruction_loss = self.l1_loss(target_tensor, cycle_tensor)
        cycle_loss = (feature_loss + reconstruction_loss) * 0.5
        weighted_cycle_loss = cycle_loss * self.config_cycle_weight
        return cycle_loss, weighted_cycle_loss
|
||||
|
||||
|
||||
class FeatureLoss(nn.Module):
    """
    Mean squared difference between paired target and output features.

    The weight is read from the ``feature_weight`` option of the
    ``training.losses`` config section.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        super().__init__()
        # Still read so the attribute stays available; forward() no longer depends on it.
        self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
        self.config_feature_weight = config_parser.getfloat('training.losses', 'feature_weight')

    def forward(self, target_features : Tuple[Feature, ...], output_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
        """
        Average the per-sample squared error over every feature pair.

        :param target_features: features of the target, paired with output_features by position
        :param output_features: features of the generated output
        :return: the raw loss (halved mean) and the loss multiplied by the configured weight
        """
        temp_tensors = []

        for target_feature, output_feature in zip(target_features, output_features):
            squared_error = torch.pow(output_feature - target_feature, 2)
            # Use the real leading dimension instead of the configured batch size, so a
            # smaller (for example final) batch neither fails nor mixes samples.
            temp_tensor = torch.mean(squared_error.reshape(squared_error.shape[0], -1), dim = 1).mean()
            temp_tensors.append(temp_tensor)

        feature_loss = torch.stack(temp_tensors).mean() * 0.5
        weighted_feature_loss = feature_loss * self.config_feature_weight
        return feature_loss, weighted_feature_loss
|
||||
|
||||
|
||||
class ReconstructionLoss(nn.Module):
    """
    Reconstruction loss that is only active for samples whose source and
    target share an identity.

    The weight is read from the ``reconstruction_weight`` option of the
    ``training.losses`` config section.
    """

    def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
        super().__init__()
        self.config_reconstruction_weight = config_parser.getfloat('training.losses', 'reconstruction_weight')
        self.embedder = embedder
        self.mse_loss = nn.MSELoss()

    def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
        """
        Penalise the difference between output and target for same-identity pairs.

        :param source_tensor: source face tensor, only used to decide whether identities match
        :param target_tensor: target face tensor the output is compared against
        :param output_tensor: generated tensor
        :return: the raw loss and the loss multiplied by the configured weight
        """
        # The embeddings only drive the gating mask, so no gradient is needed for them.
        with torch.no_grad():
            source_embedding = calculate_face_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
            target_embedding = calculate_face_embedding(self.embedder, target_tensor, (0, 0, 0, 0))

        # Per-sample gate: true where source and target embeddings have cosine similarity above 0.8.
        has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8

        # Compare the generated output with the target. The previous source-vs-target
        # difference did not depend on the output, so it carried no training signal.
        pixel_loss = torch.mean((output_tensor - target_tensor) ** 2, dim = (1, 2, 3))
        pixel_loss = (pixel_loss * has_similar_identity).mean() * 0.5

        # NOTE(review): data_range = 2.0 presumably reflects tensors in [-1, 1] - confirm.
        visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = 2.0)
        visual_loss = (visual_loss * has_similar_identity).mean()

        reconstruction_loss = (pixel_loss + visual_loss) * 0.5
        weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight
        return reconstruction_loss, weighted_reconstruction_loss
|
||||
|
||||
|
||||
class IdentityLoss(nn.Module):
    """
    Identity loss based on the cosine similarity of face embeddings.

    The weight is read from the ``identity_weight`` option of the
    ``training.losses`` config section.
    """

    def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
        super().__init__()
        self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
        self.embedder = embedder

    def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
        """
        Compute one minus the cosine similarity between output and source embeddings.

        :param source_tensor: source face tensor
        :param output_tensor: generated tensor
        :return: the raw loss and the loss multiplied by the configured weight
        """
        # Both embeddings use the same (30, 0, 10, 10) argument; no torch.no_grad() is applied here.
        padding = (30, 0, 10, 10)
        output_embedding = calculate_face_embedding(self.embedder, output_tensor, padding)
        source_embedding = calculate_face_embedding(self.embedder, source_tensor, padding)

        similarity = torch.cosine_similarity(source_embedding, output_embedding)
        identity_loss = (1 - similarity).mean()
        return identity_loss, identity_loss * self.config_identity_weight
|
||||
|
||||
|
||||
class GazeLoss(nn.Module):
    """
    Gaze loss comparing pitch and yaw predicted for the output and the target.

    Reads ``gaze_weight`` from ``training.losses`` and ``output_size`` from
    ``training.model.generator``.
    """

    def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None:
        super().__init__()
        self.config_gaze_weight = config_parser.getfloat('training.losses', 'gaze_weight')
        self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
        self.gazer = gazer
        self.l1_loss = nn.L1Loss()

    def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
        """
        Average the L1 distances of pitch and yaw between output and target.

        :param target_tensor: target face tensor
        :param output_tensor: generated tensor
        :return: the raw loss and the loss multiplied by the configured weight
        """
        output_pitch, output_yaw = self.detect_gaze(output_tensor)
        target_pitch, target_yaw = self.detect_gaze(target_tensor)

        pitch_loss = self.l1_loss(output_pitch, target_pitch)
        yaw_loss = self.l1_loss(output_yaw, target_yaw)
        gaze_loss = (pitch_loss + yaw_loss) * 0.5
        return gaze_loss, gaze_loss * self.config_gaze_weight

    def detect_gaze(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor]:
        """
        Crop, normalise and resize the input, then run the gazer on it.

        :param input_tensor: 4-D tensor; the (x + 1) * 0.5 step presumably expects values in [-1, 1]
        :return: pitch and yaw as returned by the gazer
        """
        # Crop bounds are fractions of the configured output size, truncated to integers.
        crop_ratios = torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ])
        top, bottom, left, right = (crop_ratios * self.config_output_size).int().tolist()
        crop_tensor = input_tensor[:, :, top:bottom, left:right]

        crop_tensor = (crop_tensor + 1) * 0.5
        normalize = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])
        crop_tensor = nn.functional.interpolate(normalize(crop_tensor), size = 448, mode = 'bicubic')

        # NOTE(review): the gazer runs under no_grad for every input, including the generated
        # output, so pitch and yaw carry no gradient back to the generator - confirm intended.
        with torch.no_grad():
            pitch, yaw = self.gazer(crop_tensor)

        return pitch, yaw
|
||||
|
||||
|
||||
class MaskLoss(nn.Module):
    """
    Mean squared error between a predicted mask and a mask derived from the
    target through the face masker.

    Reads ``mask_weight`` from ``training.losses``, ``mask_factor`` from
    ``training.modifier`` and ``output_size`` from ``training.model.generator``.
    """

    def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
        super().__init__()
        self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
        self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor')
        self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
        self.face_masker = face_masker
        self.mse_loss = nn.MSELoss()

    def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
        """
        Compare the predicted mask against the (optionally dilated) target mask.

        :param target_tensor: target face tensor the reference mask is computed from
        :param output_mask: predicted mask
        :return: the raw loss and the loss multiplied by the configured weight
        """
        mask_size = self.config_output_size
        target_mask = self.calculate_mask(target_tensor)

        # Dilation is skipped when the factor is zero or negative.
        if self.config_mask_factor > 0:
            target_mask = dilate_mask(target_mask, self.config_mask_factor)

        # Both masks are flattened to (-1, size, size) before the comparison.
        flat_target_mask = target_mask.view(-1, mask_size, mask_size)
        flat_output_mask = output_mask.view(-1, mask_size, mask_size)
        mask_loss = self.mse_loss(flat_target_mask, flat_output_mask)
        return mask_loss, mask_loss * self.config_mask_weight

    def calculate_mask(self, target_tensor : Tensor) -> Tensor:
        """
        Run the face masker on a 256x256 version of the target.

        :param target_tensor: 4-D target tensor
        :return: masker output clamped to [0, 1] and resized to the configured output size
        """
        masker_input = nn.functional.interpolate(target_tensor, (256, 256), mode = 'bilinear')
        # Clip to [-1, 1] and shift into [0, 1] before the masker sees it.
        masker_input = (masker_input.clip(-1, 1) + 1) * 0.5

        with torch.no_grad():
            output_tensor = self.face_masker(masker_input).clamp(0, 1)
            output_tensor = nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')

        return output_tensor
|
||||
@@ -0,0 +1,190 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Embedding, Feature
|
||||
|
||||
|
||||
class AAD(nn.Module):
    """
    Decoder that turns a source embedding into an image-like tensor by passing
    it through a stack of AdaptiveFeatureModulation layers, each conditioned
    on one target feature.

    Reads ``source_channels``, ``output_size`` and ``num_blocks`` from the
    ``training.model.generator`` config section.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        super().__init__()
        self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels')
        self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
        self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
        self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, 4096)
        self.layers = self.create_layers()

    def create_layers(self) -> nn.ModuleList:
        """
        Build the modulation layers for the configured output size.

        Every entry is (input_channels, output_channels, target_channels).
        Sizes other than 128, 256, 512 and 1024 only receive the shared tail.
        """
        head_channel_sets =\
        {
            128:
            [
                (1024, 1024, 512),
                (1024, 1024, 1024),
                (1024, 512, 512)
            ],
            256:
            [
                (1024, 1024, 1024),
                (1024, 1024, 2048),
                (1024, 1024, 1024),
                (1024, 512, 512)
            ],
            512:
            [
                (1024, 1024, 1024),
                (1024, 1024, 2048),
                (1024, 1024, 1536),
                (1024, 1024, 768),
                (1024, 512, 512)
            ],
            1024:
            [
                (1024, 1024, 2048),
                (1024, 1024, 4096),
                (1024, 1024, 3072),
                (1024, 1024, 1536),
                (1024, 1024, 768),
                (1024, 512, 512)
            ]
        }
        tail_channel_sets =\
        [
            (512, 256, 256),
            (256, 128, 128),
            (128, 64, 64),
            (64, 3, 64)
        ]
        channel_sets = head_channel_sets.get(self.config_output_size, []) + tail_channel_sets

        return nn.ModuleList(
        [
            AdaptiveFeatureModulation(input_channels, output_channels, target_channels, self.config_source_channels, self.config_num_blocks)
            for input_channels, output_channels, target_channels in channel_sets
        ])

    def forward(self, source_embedding : Embedding, target_features : Tuple[Feature, ...]) -> Tensor:
        """
        Decode the source embedding, conditioned on the target features.

        :param source_embedding: embedding fed to the pixel shuffle stem and to every layer
        :param target_features: one feature per layer; the last layer uses target_features[-1]
        :return: tanh of the final layer output
        """
        hidden_tensor = self.pixel_shuffle_up_sample(source_embedding)

        # Every layer but the last is followed by a 2x bilinear upscale.
        for index, layer in enumerate(self.layers[:-1]):
            modulated_tensor = layer(hidden_tensor, source_embedding, target_features[index])
            hidden_tensor = nn.functional.interpolate(modulated_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)

        final_layer = self.layers[-1]
        hidden_tensor = final_layer(hidden_tensor, source_embedding, target_features[-1])
        return torch.tanh(hidden_tensor)
|
||||
|
||||
|
||||
class AdaptiveFeatureModulation(nn.Module):
    """
    Residual block made of FeatureModulation, ReLU and 3x3 convolution triples.

    When the block reduces the channel count, the residual path is itself a
    FeatureModulation, ReLU and convolution triple; otherwise the input is
    added unchanged.
    """

    def __init__(self, input_channels : int, output_channels : int, target_channels : int, source_channels : int, num_blocks : int) -> None:
        super().__init__()
        self.context_input_channels = input_channels
        self.context_output_channels = output_channels
        self.context_target_channels = target_channels
        self.context_source_channels = source_channels
        self.context_num_blocks = num_blocks
        self.primary_layers = self.create_primary_layers()
        self.shortcut_layers = self.create_shortcut_layers()

    def create_primary_layers(self) -> nn.ModuleList:
        """
        Build num_blocks triples; only the last convolution changes the channel count.
        """
        primary_layers = nn.ModuleList()
        last_index = self.context_num_blocks - 1

        for index in range(self.context_num_blocks):
            conv_output_channels = self.context_input_channels if index < last_index else self.context_output_channels
            primary_layers.append(FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels))
            primary_layers.append(nn.ReLU())
            primary_layers.append(nn.Conv2d(self.context_input_channels, conv_output_channels, kernel_size = 3, padding = 1, bias = False))

        return primary_layers

    def create_shortcut_layers(self) -> nn.ModuleList:
        """
        Build the residual path; it stays empty unless channels are reduced.
        """
        if self.context_input_channels <= self.context_output_channels:
            return nn.ModuleList()

        return nn.ModuleList(
        [
            FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
            nn.ReLU(),
            nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
        ])

    @staticmethod
    def _apply_layers(layers : nn.ModuleList, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
        # FeatureModulation layers take the conditioning inputs, all other layers only the tensor.
        output_tensor = input_tensor

        for layer in layers:
            if isinstance(layer, FeatureModulation):
                output_tensor = layer(output_tensor, source_embedding, target_feature)
            else:
                output_tensor = layer(output_tensor)

        return output_tensor

    def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
        """
        Run the primary path and add the residual path.

        :param input_tensor: tensor to transform
        :param source_embedding: embedding passed to every FeatureModulation
        :param target_feature: target feature passed to every FeatureModulation
        :return: sum of the primary and the residual path
        """
        primary_tensor = self._apply_layers(self.primary_layers, input_tensor, source_embedding, target_feature)
        shortcut_tensor = input_tensor

        if self.context_input_channels > self.context_output_channels:
            shortcut_tensor = self._apply_layers(self.shortcut_layers, input_tensor, source_embedding, target_feature)

        return primary_tensor + shortcut_tensor
|
||||
|
||||
|
||||
class FeatureModulation(nn.Module):
    """
    Blend a source-driven and a target-driven modulation of an
    instance-normalised tensor through a learned single-channel mask.
    """

    def __init__(self, input_channels : int, target_channels : int, source_channels : int) -> None:
        super().__init__()
        self.context_input_channels = input_channels
        # conv1 / conv2 give the target scale / shift, conv3 the blending mask.
        self.conv1 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
        self.conv2 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
        self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
        # linear2 / linear1 give the source scale / shift.
        self.linear1 = nn.Linear(source_channels, input_channels)
        self.linear2 = nn.Linear(source_channels, input_channels)
        self.instance_norm = nn.InstanceNorm2d(input_channels)

    def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
        """
        Modulate the normalised input by source and target, then blend both.

        :param input_tensor: 4-D tensor with input_channels channels
        :param source_embedding: embedding projected to one scale and one shift per channel
        :param target_feature: feature projected by 1x1 convolutions to scale and shift maps
        :return: mask * source_modulation + (1 - mask) * target_modulation
        """
        normed_tensor = self.instance_norm(input_tensor)
        channel_shape = (normed_tensor.shape[0], self.context_input_channels, 1, 1)

        # Source branch: one scale and shift per channel, expanded over the spatial dims.
        source_scale = self.linear2(source_embedding).reshape(channel_shape).expand_as(normed_tensor)
        source_shift = self.linear1(source_embedding).reshape(channel_shape).expand_as(normed_tensor)
        source_modulation = source_scale * normed_tensor + source_shift

        # Target branch: scale and shift maps computed from the target feature.
        target_modulation = self.conv1(target_feature) * normed_tensor + self.conv2(target_feature)

        blend_mask = torch.sigmoid(self.conv3(normed_tensor))
        return (1 - blend_mask) * target_modulation + blend_mask * source_modulation
|
||||
|
||||
|
||||
class PixelShuffleUpSample(nn.Module):
    """
    Turn a flat embedding into a small spatial tensor through a 3x3
    convolution followed by a 2x pixel shuffle.
    """

    def __init__(self, input_channels : int, output_channels : int) -> None:
        super().__init__()
        self.sequences = self.create_sequences(input_channels, output_channels)

    @staticmethod
    def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
        """
        Build the convolution and pixel shuffle pair.
        """
        conv = nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1)
        pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
        return nn.Sequential(conv, pixel_shuffle)

    def forward(self, input_tensor : Tensor) -> Tensor:
        """
        :param input_tensor: tensor whose non-batch dims are folded into channels of a 1x1 map
        :return: output of the convolution and pixel shuffle
        """
        batch_size = input_tensor.shape[0]
        # View the embedding as (batch, channels, 1, 1) so the convolution can consume it.
        return self.sequences(input_tensor.view(batch_size, -1, 1, 1))
|
||||
@@ -0,0 +1,111 @@
|
||||
from configparser import ConfigParser
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Feature, Mask
|
||||
|
||||
|
||||
class MaskNet(nn.Module):
    """
    Mask predictor: three down samples, a bottleneck, three up samples, a 1x1
    convolution and a sigmoid.

    Reads ``input_channels``, ``output_channels`` and ``num_filters`` from the
    ``training.model.masker`` config section.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        super().__init__()
        self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels')
        self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels')
        self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
        self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
        self.up_samples = self.create_up_samples(self.config_num_filters)
        self.bottleneck = BottleNeck(self.config_num_filters * 4)
        self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def create_down_samples(input_channels : int, num_filters : int) -> nn.ModuleList:
        """
        Build the encoder: channels grow to num_filters, then double twice.
        """
        channel_steps = [ input_channels, num_filters, num_filters * 2, num_filters * 4 ]
        return nn.ModuleList(
        [
            DownSample(step_input_channels, step_output_channels)
            for step_input_channels, step_output_channels in zip(channel_steps[:-1], channel_steps[1:])
        ])

    @staticmethod
    def create_up_samples(num_filters : int) -> nn.ModuleList:
        """
        Build the decoder: channels halve twice, then stay at num_filters.
        """
        channel_pairs =\
        [
            (num_filters * 4, num_filters * 2),
            (num_filters * 2, num_filters),
            (num_filters, num_filters)
        ]
        return nn.ModuleList([ UpSample(pair_input_channels, pair_output_channels) for pair_input_channels, pair_output_channels in channel_pairs ])

    def forward(self, input_tensor : Tensor, input_feature : Feature) -> Mask:
        """
        Predict a mask from a tensor and a feature concatenated along the channels.

        :param input_tensor: tensor sharing batch and spatial dims with input_feature
        :param input_feature: feature concatenated to input_tensor on dim 1
        :return: sigmoid activated mask
        """
        output_mask = torch.cat([ input_tensor, input_feature ], dim = 1)
        stages = [ *self.down_samples, self.bottleneck, *self.up_samples, self.conv, self.sigmoid ]

        for stage in stages:
            output_mask = stage(output_mask)

        return output_mask
|
||||
|
||||
|
||||
class BottleNeck(nn.Module):
    """
    Residual block of two convolution, batch norm and ReLU triples that keeps
    the channel count, with a ReLU applied after the residual sum.
    """

    def __init__(self, num_filters : int) -> None:
        super().__init__()
        self.sequences = self.create_sequences(num_filters)
        self.relu = nn.ReLU()

    @staticmethod
    def create_sequences(num_filters : int) -> nn.Sequential:
        """
        Build the two convolution, batch norm and ReLU triples.
        """
        layers = []

        for _ in range(2):
            layers.append(nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False))
            layers.append(nn.BatchNorm2d(num_filters))
            layers.append(nn.ReLU())

        return nn.Sequential(*layers)

    def forward(self, input_tensor : Tensor) -> Tensor:
        """
        :param input_tensor: 4-D tensor with num_filters channels
        :return: relu(sequences(input) + input)
        """
        residual_tensor = self.sequences(input_tensor)
        return self.relu(residual_tensor + input_tensor)
|
||||
|
||||
|
||||
class UpSample(nn.Module):
    """
    Double the spatial size with a stride 2 transposed convolution followed by a ReLU.
    """

    def __init__(self, input_channels : int, output_channels : int) -> None:
        super().__init__()
        self.sequences = self.create_sequences(input_channels, output_channels)

    @staticmethod
    def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
        """
        Build the transposed convolution and activation pair.
        """
        conv_transpose = nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2)
        return nn.Sequential(conv_transpose, nn.ReLU())

    def forward(self, input_tensor : Tensor) -> Tensor:
        """
        :param input_tensor: 4-D tensor with input_channels channels
        :return: upsampled tensor with output_channels channels
        """
        return self.sequences(input_tensor)
|
||||
|
||||
|
||||
class DownSample(nn.Module):
    """
    3x3 convolution, batch norm and ReLU followed by a 2x2 max pool that
    halves the spatial size.
    """

    def __init__(self, input_channels : int, output_channels : int) -> None:
        super().__init__()
        self.sequences = self.create_sequences(input_channels, output_channels)

    @staticmethod
    def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
        """
        Build the convolution, normalisation, activation and pooling stack.
        """
        layers =\
        [
            nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(),
            nn.MaxPool2d(2)
        ]
        return nn.Sequential(*layers)

    def forward(self, input_tensor : Tensor) -> Tensor:
        """
        :param input_tensor: 4-D tensor with input_channels channels
        :return: downsampled tensor with output_channels channels
        """
        return self.sequences(input_tensor)
|
||||
@@ -0,0 +1,48 @@
|
||||
import math
|
||||
from configparser import ConfigParser
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class NLD(nn.Module):
    """
    Convolutional discriminator producing a single-channel output map.

    Reads ``input_channels``, ``num_filters``, ``kernel_size`` and
    ``num_layers`` from the ``training.model.discriminator`` config section.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        super().__init__()
        self.config_input_channels = config_parser.getint('training.model.discriminator', 'input_channels')
        self.config_num_filters = config_parser.getint('training.model.discriminator', 'num_filters')
        self.config_kernel_size = config_parser.getint('training.model.discriminator', 'kernel_size')
        self.config_num_layers = config_parser.getint('training.model.discriminator', 'num_layers')
        # The same module objects are registered under both attributes.
        self.layers = self.create_layers()
        self.sequences = nn.Sequential(*self.layers)

    def create_layers(self) -> nn.ModuleList:
        """
        Build the layer stack.

        One stride 2 convolution, then num_layers - 1 stride 2 convolutions
        with instance norm, one stride 1 convolution with instance norm and a
        final convolution to a single channel. Filters double per step, capped
        at 512.
        """
        kernel_size = self.config_kernel_size
        padding = math.ceil((kernel_size - 1) / 2)
        output_filters = self.config_num_filters
        layers =\
        [
            nn.Conv2d(self.config_input_channels, output_filters, kernel_size = kernel_size, stride = 2, padding = padding),
            nn.LeakyReLU(0.2)
        ]

        for _ in range(self.config_num_layers - 1):
            input_filters, output_filters = output_filters, min(output_filters * 2, 512)
            layers.extend(
            [
                nn.Conv2d(input_filters, output_filters, kernel_size = kernel_size, stride = 2, padding = padding),
                nn.InstanceNorm2d(output_filters),
                nn.LeakyReLU(0.2)
            ])

        input_filters, output_filters = output_filters, min(output_filters * 2, 512)
        layers.extend(
        [
            nn.Conv2d(input_filters, output_filters, kernel_size = kernel_size, padding = padding),
            nn.InstanceNorm2d(output_filters),
            nn.LeakyReLU(0.2),
            nn.Conv2d(output_filters, 1, kernel_size = kernel_size, padding = padding)
        ])
        return nn.ModuleList(layers)

    def forward(self, input_tensor : Tensor) -> Tensor:
        """
        :param input_tensor: 4-D tensor with input_channels channels
        :return: single-channel output map
        """
        output_tensor = self.sequences(input_tensor)
        return output_tensor
|
||||
@@ -0,0 +1,160 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Feature
|
||||
|
||||
|
||||
class UNet(nn.Module):
    """
    Encoder-decoder that extracts a tuple of features from a target tensor.

    Reads ``output_size`` from the ``training.model.generator`` config
    section; the number and width of the stages depend on it.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        super().__init__()
        self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
        self.down_samples = self.create_down_samples()
        self.up_samples = self.create_up_samples()

    def create_down_samples(self) -> nn.ModuleList:
        """
        Build the encoder: five shared stages plus size specific deeper stages.

        Sizes other than 128, 256, 512 and 1024 only receive the shared stages.
        """
        base_channel_pairs =\
        [
            (3, 32),
            (32, 64),
            (64, 128),
            (128, 256),
            (256, 512)
        ]
        extra_channel_pairs =\
        {
            128: [ (512, 512) ],
            256: [ (512, 1024), (1024, 1024) ],
            512: [ (512, 1024), (1024, 1024), (1024, 1024) ],
            1024: [ (512, 1024), (1024, 2048), (2048, 2048), (2048, 2048) ]
        }
        channel_pairs = base_channel_pairs + extra_channel_pairs.get(self.config_output_size, [])
        return nn.ModuleList([ DownSample(input_channels, output_channels) for input_channels, output_channels in channel_pairs ])

    def create_up_samples(self) -> nn.ModuleList:
        """
        Build the decoder: size specific stages followed by three shared stages.

        The input channels of a stage account for the skip tensor concatenated
        by the previous stage.
        """
        head_channel_pairs =\
        {
            128: [ (512, 512), (1024, 256) ],
            256: [ (1024, 1024), (2048, 512), (1024, 256) ],
            512: [ (1024, 1024), (2048, 512), (1536, 256), (768, 256) ],
            1024: [ (2048, 2048), (4096, 1024), (3072, 512), (1536, 256), (768, 256) ]
        }
        tail_channel_pairs =\
        [
            (512, 128),
            (256, 64),
            (128, 32)
        ]
        channel_pairs = head_channel_pairs.get(self.config_output_size, []) + tail_channel_pairs
        return nn.ModuleList([ UpSample(input_channels, output_channels) for input_channels, output_channels in channel_pairs ])

    def forward(self, target_tensor : Tensor) -> Tuple[Feature, ...]:
        """
        Encode the target and collect the decoder features.

        :param target_tensor: tensor fed to the first down sample (3 channels)
        :return: the bottleneck feature, every up sample output and a 2x
            bilinear upscale of the last up sample output
        """
        down_features = []
        temp_feature = target_tensor

        for down_sample in self.down_samples:
            temp_feature = down_sample(temp_feature)
            down_features.append(temp_feature)

        bottleneck_feature = down_features[-1]
        temp_feature = bottleneck_feature
        up_features = []

        # The first up sample pairs with the second to last encoder output, the next with the one before, and so on.
        for skip_offset, up_sample in enumerate(self.up_samples, start = 2):
            temp_feature = up_sample(temp_feature, down_features[-skip_offset])
            up_features.append(temp_feature)

        final_feature = nn.functional.interpolate(temp_feature, scale_factor = 2, mode = 'bilinear', align_corners = False)
        return bottleneck_feature, *up_features, final_feature
|
||||
|
||||
|
||||
class UpSample(nn.Module):
    """
    Double the spatial size with a transposed convolution, batch norm and
    LeakyReLU, then concatenate a skip tensor along the channels.
    """

    def __init__(self, input_channels : int, output_channels : int) -> None:
        super().__init__()
        self.sequences = self.create_sequences(input_channels, output_channels)

    @staticmethod
    def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
        """
        Build the transposed convolution, normalisation and activation stack.
        """
        layers =\
        [
            nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.1)
        ]
        return nn.Sequential(*layers)

    def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
        """
        :param input_tensor: 4-D tensor with input_channels channels
        :param skip_tensor: tensor matching the upsampled result in batch and spatial dims
        :return: upsampled tensor followed by the skip tensor on dim 1
        """
        upsampled_tensor = self.sequences(input_tensor)
        return torch.cat((upsampled_tensor, skip_tensor), dim = 1)
|
||||
|
||||
|
||||
class DownSample(nn.Module):
    """
    Halve the spatial size with a stride 2 convolution, batch norm and LeakyReLU.
    """

    def __init__(self, input_channels : int, output_channels : int) -> None:
        super().__init__()
        self.sequences = self.create_sequences(input_channels, output_channels)

    @staticmethod
    def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
        """
        Build the convolution, normalisation and activation stack.
        """
        layers =\
        [
            nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.1)
        ]
        return nn.Sequential(*layers)

    def forward(self, input_tensor : Tensor) -> Tensor:
        """
        :param input_tensor: 4-D tensor with input_channels channels
        :return: downsampled tensor with output_channels channels
        """
        return self.sequences(input_tensor)
|
||||
@@ -0,0 +1,299 @@
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from configparser import ConfigParser
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import ConcatDataset, Dataset, random_split
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import DynamicDataset
|
||||
from .helper import apply_noise, calculate_face_embedding, erode_mask, overlay_mask
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, Mask, OptimizerSet, TrainerPrecision, TrainerStrategy
|
||||
|
||||
# Hide UserWarning messages raised from torch modules.
warnings.filterwarnings(action = 'ignore', category = UserWarning, module = 'torch')

# Module-level configuration loaded from config.ini, resolved against the working directory.
# ConfigParser.read() skips files it cannot open, so a missing file leaves the parser empty.
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
class HyperSwapTrainer(LightningModule):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
    """
    Read the training configuration, load the frozen helper models and
    create the generator, the discriminator and every loss module.

    :param config_parser: parser providing the training.* sections read below
    """
    super().__init__()

    def load_frozen_model(model_path : str) -> nn.Module:
        # TorchScript helpers are loaded on the CPU and switched to eval mode.
        return torch.jit.load(model_path, map_location = 'cpu').eval()

    # paths of the helper models
    self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path')
    self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path')
    self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
    self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')

    # trainer and modifier options
    self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
    self.config_discriminator_ratio = config_parser.getfloat('training.trainer', 'discriminator_ratio')
    self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
    self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
    self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor')
    self.config_noise_factor = config_parser.getfloat('training.modifier', 'noise_factor')

    # optimizer options
    self.config_generator_learning_rate = config_parser.getfloat('training.optimizer.generator', 'learning_rate')
    self.config_generator_momentum = config_parser.getfloat('training.optimizer.generator', 'momentum')
    self.config_generator_scheduler_factor = config_parser.getfloat('training.optimizer.generator', 'scheduler_factor')
    self.config_generator_scheduler_patience = config_parser.getint('training.optimizer.generator', 'scheduler_patience')
    self.config_discriminator_learning_rate = config_parser.getfloat('training.optimizer.discriminator', 'learning_rate')
    self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum')
    self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor')
    self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience')

    # frozen helper models
    self.generator_embedder = load_frozen_model(self.config_generator_embedder_path)
    self.loss_embedder = load_frozen_model(self.config_loss_embedder_path)
    self.gazer = load_frozen_model(self.config_gazer_path)
    self.face_masker = load_frozen_model(self.config_face_masker_path)

    # trainable models
    self.generator = Generator(config_parser)
    self.discriminator = Discriminator(config_parser)

    # loss modules
    self.discriminator_loss = DiscriminatorLoss()
    self.adversarial_loss = AdversarialLoss(config_parser)
    self.cycle_loss = CycleLoss(config_parser)
    self.feature_loss = FeatureLoss(config_parser)
    self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
    self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
    self.gaze_loss = GazeLoss(config_parser, self.gazer)
    self.mask_loss = MaskLoss(config_parser, self.face_masker)

    # Two optimizers are stepped by hand, so automatic optimization is turned off.
    self.automatic_optimization = False
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
    """
    Run the generator without gradient tracking and return the output with its mask.

    :param source_embedding: embedding handed to the generator as the source
    :param target_tensor: tensor handed to the generator as the target
    :return: tuple of (output tensor, output mask); the mask is passed through
        erode_mask() when config_mask_factor is greater than zero
    """
    with torch.no_grad():
        target_features = self.generator.encode_features(target_tensor)
        swap_tensor, swap_mask = self.generator(source_embedding, target_tensor, target_features)

    # Erosion is optional and driven by the configured mask factor.
    if self.config_mask_factor > 0:
        return swap_tensor, erode_mask(swap_mask, self.config_mask_factor)

    return swap_tensor, swap_mask
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
    """
    Build one AdamW optimizer and one ReduceLROnPlateau scheduler per network.

    :return: tuple of (generator config, discriminator config), each a dict with
        the keys 'optimizer' and 'lr_scheduler'
    """
    # Generator: the configured momentum is used as the first AdamW beta.
    generator_optimizer = torch.optim.AdamW(
        self.generator.parameters(),
        lr = self.config_generator_learning_rate,
        betas = (self.config_generator_momentum, 0.999),
        weight_decay = 1e-4,
        eps = 1e-8
    )
    generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        generator_optimizer,
        mode = 'min',
        factor = self.config_generator_scheduler_factor,
        patience = self.config_generator_scheduler_patience,
        min_lr = 1e-8
    )

    # Discriminator: same setup, driven by its own configuration values.
    discriminator_optimizer = torch.optim.AdamW(
        self.discriminator.parameters(),
        lr = self.config_discriminator_learning_rate,
        betas = (self.config_discriminator_momentum, 0.999),
        weight_decay = 1e-4,
        eps = 1e-8
    )
    discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        discriminator_optimizer,
        mode = 'min',
        factor = self.config_discriminator_scheduler_factor,
        patience = self.config_discriminator_scheduler_patience,
        min_lr = 1e-8
    )

    return (
        {
            'optimizer': generator_optimizer,
            'lr_scheduler':
            {
                'scheduler': generator_scheduler
            }
        },
        {
            'optimizer': discriminator_optimizer,
            'lr_scheduler':
            {
                'scheduler': discriminator_scheduler
            }
        }
    )
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
    """
    Run one manually optimized step for the generator and the discriminator.

    The generator output is produced from the source embedding and the target,
    then fed back through the generator with the target embedding to obtain the
    cycle output. The weighted generator losses are summed, the discriminator
    loss is computed on a real batch versus the detached generator output, and
    both losses are backpropagated on every call. Optimizers and schedulers
    only step when (batch_index + 1) is a multiple of config_accumulate_size.

    :param batch: pair of (source_tensor, target_tensor)
    :param batch_index: index of the batch, used to decide when the accumulated
        gradients are applied
    :return: the summed weighted generator loss
    """
    source_tensor, target_tensor = batch
    do_update = (batch_index + 1) % self.config_accumulate_size == 0
    generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
    generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined]
    source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
    target_embedding = calculate_face_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))

    if self.config_noise_factor > 0:
        # Re-normalize because the added noise changes the embedding length.
        source_embedding = apply_noise(source_embedding, self.config_noise_factor)
        source_embedding = nn.functional.normalize(source_embedding, p = 2)

    # Forward pass: swap, then cycle the swap result back with the target embedding.
    generator_target_features = self.generator.encode_features(target_tensor)
    generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
    generator_output_features = self.generator.encode_features(generator_output_tensor)
    cycle_output_tensor, cycle_output_mask = self.generator(target_embedding, generator_output_tensor, generator_output_features)
    cycle_output_features = self.generator.encode_features(cycle_output_tensor)
    discriminator_output_tensors = self.discriminator(generator_output_tensor)

    # Each loss module returns the raw value (logged) and the weighted value (optimized).
    adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
    cycle_loss, weighted_cycle_loss = self.cycle_loss(target_tensor, cycle_output_tensor, generator_target_features, cycle_output_features)
    feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
    reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
    identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
    gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
    mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
    generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss

    # torch.rand() draws uniformly from [0, 1), so the configured ratio is the
    # probability of using the source batch as the real samples. A standard
    # normal draw (torch.randn) would skew that probability and could never
    # honor a ratio of 0 or 1.
    if torch.rand(1).item() < self.config_discriminator_ratio:
        discriminator_real_tensors = self.discriminator(source_tensor)
    else:
        discriminator_real_tensors = self.discriminator(target_tensor)
    # Detach so the discriminator loss does not backpropagate into the generator.
    discriminator_fake_tensors = self.discriminator(generator_output_tensor.detach())
    discriminator_loss = self.discriminator_loss(discriminator_real_tensors, discriminator_fake_tensors)

    # Generator update.
    self.toggle_optimizer(generator_optimizer)
    self.manual_backward(generator_loss)

    if do_update:
        if self.config_gradient_clip:
            self.clip_gradients(
                generator_optimizer,
                gradient_clip_val = self.config_gradient_clip,
                gradient_clip_algorithm = 'norm'
            )
        generator_optimizer.step()
        generator_optimizer.zero_grad()
    self.untoggle_optimizer(generator_optimizer)

    # Discriminator update.
    self.toggle_optimizer(discriminator_optimizer)
    self.manual_backward(discriminator_loss)

    if do_update:
        if self.config_gradient_clip:
            self.clip_gradients(
                discriminator_optimizer,
                gradient_clip_val = self.config_gradient_clip,
                gradient_clip_algorithm = 'norm'
            )
        discriminator_optimizer.step()
        discriminator_optimizer.zero_grad()
    self.untoggle_optimizer(discriminator_optimizer)

    if self.global_step % self.config_preview_frequency == 0:
        self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)

    self.log('generator_loss', generator_loss, prog_bar = True)
    self.log('discriminator_loss', discriminator_loss, prog_bar = True)
    self.log('adversarial_loss', adversarial_loss)
    self.log('cycle_loss', cycle_loss)
    self.log('feature_loss', feature_loss)
    self.log('reconstruction_loss', reconstruction_loss)
    self.log('identity_loss', identity_loss)
    self.log('gaze_loss', gaze_loss)
    self.log('mask_loss', mask_loss)

    if do_update:
        # The plateau schedulers are fed the loss of the current batch.
        generator_scheduler.step(generator_loss)
        discriminator_scheduler.step(discriminator_loss)

    return generator_loss
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
    """
    Score how close the embedding of the generated output is to the source embedding.

    :param batch: pair of (source_tensor, target_tensor)
    :param batch_index: index of the batch (unused)
    :return: mean cosine similarity between source and output embeddings,
        shifted and halved via (similarity + 1) * 0.5
    """
    source_tensor, target_tensor = batch
    source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
    output_tensor, _ = self.forward(source_embedding, target_tensor)
    output_embedding = calculate_face_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0))

    mean_similarity = nn.functional.cosine_similarity(source_embedding, output_embedding).mean()
    validation_score = (mean_similarity + 1) * 0.5
    self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
    return validation_score
|
||||
|
||||
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, output_mask : Mask) -> None:
    """
    Log a preview image built from the first samples of the batch.

    For each of up to eight samples, the source, target, output and mask overlay
    are concatenated along dim 2; the resulting cells are then concatenated
    along dim 1 and sent to the experiment logger under the tag 'preview'.

    :param source_tensor: batch of source images
    :param target_tensor: batch of target images
    :param output_tensor: batch of generator outputs
    :param output_mask: batch of generator masks, blended via overlay_mask()
    """
    preview_limit = 8
    overlay_tensor = overlay_mask(output_tensor, output_mask)
    # presumably each sample is (channels, height, width), making dim 2 the width — TODO confirm
    preview_samples = zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], overlay_tensor[:preview_limit])
    preview_cells = [
        torch.cat([ source_cell, target_cell, output_cell, overlay_cell ], dim = 2)
        for source_cell, target_cell, output_cell, overlay_cell in preview_samples
    ]

    preview_stack = torch.cat(preview_cells, dim = 1).unsqueeze(0)
    preview_grid = torchvision.utils.make_grid(preview_stack, normalize = True, scale_each = True)
    self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
|
||||
|
||||
|
||||
class ModelWithConfigCheckpoint(ModelCheckpoint):
    """Checkpoint callback that stores a copy of 'config.ini' next to every saved checkpoint."""

    def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None:
        """
        Save the checkpoint, then copy 'config.ini' to the same path with an '.ini' suffix.

        :param trainer: trainer forwarded to the parent implementation
        :param checkpoint_path: destination path of the checkpoint file
        """
        super()._save_checkpoint(trainer, checkpoint_path)
        # 'config.ini' is resolved relative to the current working directory.
        shutil.copy('config.ini', Path(checkpoint_path).with_suffix('.ini'))
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
    """
    Split the dataset and wrap both parts in stateful data loaders.

    Batch size and worker count are read from the [training.loader] section.
    The training loader shuffles and drops the last incomplete batch, the
    validation loader keeps the order and every sample.

    :param dataset: dataset to split into a training and a validation part
    :return: tuple of (training loader, validation loader)
    """
    config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
    config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
    # Persistent workers are rejected by the loader when no worker processes
    # exist, so only request them when num_workers is greater than zero.
    persistent_workers = config_num_workers > 0

    training_dataset, validate_dataset = split_dataset(dataset)
    training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = persistent_workers)
    validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = persistent_workers)
    return training_loader, validation_loader
|
||||
|
||||
|
||||
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
    """
    Randomly split the dataset into a training and a validation part.

    The training share is 'split_ratio' from the [training.loader] section,
    truncated to a whole number of samples; the remainder becomes validation.

    :param dataset: dataset to split
    :return: tuple of (training dataset, validation dataset)
    """
    config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')

    total_size = len(dataset) # type:ignore[arg-type]
    training_size = int(total_size * config_split_ratio)
    split_sizes = [ training_size, int(total_size - training_size) ]
    training_subset, validation_subset = random_split(dataset, split_sizes)
    return training_subset, validation_subset
|
||||
|
||||
|
||||
def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
    """
    Create one DynamicDataset per 'training.dataset*' section of the configuration.

    Each dataset receives a deep copy of the parser in which its own section is
    replaced by a 'training.dataset.current' section holding the same items.
    The dataset instance is repeated 'multiplier' times in the returned list.

    :param config_parser: parser holding the dataset sections
    :return: list of datasets, including the repeated entries
    """
    current_section = 'training.dataset.current'
    prepared_datasets = []

    for section_name in config_parser.sections():

        if not section_name.startswith('training.dataset'):
            continue

        multiplier = config_parser.getint(section_name, 'multiplier')
        # Work on a copy so the caller's parser is left untouched.
        dataset_config_parser = deepcopy(config_parser)
        dataset_config_parser.remove_section(section_name)
        dataset_config_parser.add_section(current_section)

        for option_name, option_value in config_parser.items(section_name):
            dataset_config_parser.set(current_section, option_name, option_value)

        # The same instance is referenced multiple times, it is not copied.
        prepared_datasets.extend([ DynamicDataset(dataset_config_parser) ] * multiplier)

    return prepared_datasets
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
    """
    Assemble the trainer from the [training.trainer], [training.logger] and [training.output] sections.

    The trainer logs to TensorBoard every 10 steps, validates every 1000 steps
    and uses a ModelWithConfigCheckpoint callback that monitors 'generator_loss',
    saves every 1000 training steps and keeps the top 5 plus the last checkpoint.

    :return: the configured trainer
    """
    trainer_section = 'training.trainer'
    logger_section = 'training.logger'
    output_section = 'training.output'

    config_max_epochs = CONFIG_PARSER.getint(trainer_section, 'max_epochs')
    config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get(trainer_section, 'strategy'))
    config_precision = cast(TrainerPrecision, CONFIG_PARSER.get(trainer_section, 'precision'))
    config_sync_batchnorm = CONFIG_PARSER.getboolean(trainer_section, 'sync_batchnorm')
    config_logger_path = CONFIG_PARSER.get(logger_section, 'logger_path')
    config_logger_name = CONFIG_PARSER.get(logger_section, 'logger_name')
    config_directory_path = CONFIG_PARSER.get(output_section, 'directory_path')
    config_file_pattern = CONFIG_PARSER.get(output_section, 'file_pattern')

    tensorboard_logger = TensorBoardLogger(config_logger_path, config_logger_name)
    checkpoint_callback = ModelWithConfigCheckpoint(
        monitor = 'generator_loss',
        dirpath = config_directory_path,
        filename = config_file_pattern,
        every_n_train_steps = 1000,
        save_top_k = 5,
        save_last = True
    )

    return Trainer(
        logger = tensorboard_logger,
        log_every_n_steps = 10,
        max_epochs = config_max_epochs,
        strategy = config_strategy,
        precision = config_precision,
        sync_batchnorm = config_sync_batchnorm,
        callbacks = [ checkpoint_callback ],
        val_check_interval = 1000
    )
|
||||
|
||||
|
||||
def train() -> None:
    """
    Entry point of the training: build datasets, loaders, model and trainer, then fit.

    Training resumes from 'resume_path' of the [training.output] section when
    that path points to an existing file, otherwise it starts without a checkpoint.
    """
    config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')

    if torch.cuda.is_available():
        torch.set_float32_matmul_precision('high')

    dataset = ConcatDataset(prepare_datasets(CONFIG_PARSER))
    training_loader, validation_loader = create_loaders(dataset)
    hyperswap_trainer = HyperSwapTrainer(CONFIG_PARSER)
    trainer = create_trainer()

    # Only pass ckpt_path when there is a checkpoint file to resume from.
    fit_options = {}

    if os.path.isfile(config_resume_path):
        fit_options['ckpt_path'] = config_resume_path

    trainer.fit(hyperswap_trainer, training_loader, validation_loader, **fit_options)
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import Any, Dict, Literal, Tuple, TypeAlias
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
Batch : TypeAlias = Tuple[Tensor, Tensor]
|
||||
BatchMode = Literal['equal', 'same', 'different']
|
||||
UsageMode = Literal['source', 'target', 'both']
|
||||
|
||||
ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
|
||||
ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]
|
||||
|
||||
Feature : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
Mask : TypeAlias = Tensor
|
||||
Loss : TypeAlias = Tensor
|
||||
|
||||
Padding : TypeAlias = Tuple[int, int, int, int]
|
||||
|
||||
GeneratorModule : TypeAlias = Module
|
||||
EmbedderModule : TypeAlias = Module
|
||||
GazerModule : TypeAlias = Module
|
||||
FaceMaskerModule : TypeAlias = Module
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true']
|
||||
TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16']
|
||||
@@ -0,0 +1,56 @@
|
||||
from configparser import ConfigParser
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from hyperswap.src.networks.aad import AAD
|
||||
from hyperswap.src.networks.masknet import MaskNet
|
||||
from hyperswap.src.networks.unet import UNet
|
||||
|
||||
|
||||
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
def test_aad_with_unet(output_size : int) -> None:
    """Feed UNet features into AAD and expect a (1, 3, output_size, output_size) output."""
    generator_section =\
    {
        'source_channels': '512',
        'output_size': str(output_size),
        'num_blocks': '2'
    }
    config_parser = ConfigParser()
    config_parser.read_dict(
    {
        'training.model.generator': generator_section
    })

    encoder = UNet(config_parser).eval()
    generator = AAD(config_parser).eval()

    # The source input matches the configured 'source_channels' of 512.
    source_tensor = torch.randn(1, 512)
    target_tensor = torch.randn(1, 3, output_size, output_size)
    expected_shape = (1, 3, output_size, output_size)

    output_tensor = generator(source_tensor, encoder(target_tensor))

    assert output_tensor.shape == expected_shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
def test_mask_net(output_size : int) -> None:
    """Run MaskNet on a random target and feature map and expect a single channel mask of the same size."""
    masker_section =\
    {
        'input_channels': '67',
        'output_channels': '1',
        'num_filters': '16'
    }
    config_parser = ConfigParser()
    config_parser.read_dict(
    {
        'training.model.masker': masker_section
    })

    masker = MaskNet(config_parser).eval()

    # presumably 3 image channels + 64 feature channels add up to the 67 input channels — confirm against MaskNet
    target_tensor = torch.randn(1, 3, output_size, output_size)
    target_feature = torch.randn(1, 64, output_size, output_size)
    expected_shape = (1, 1, output_size, output_size)

    output_mask = masker(target_tensor, target_feature)

    assert output_mask.shape == expected_shape
|
||||
@@ -0,0 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from src.training import train
|
||||
|
||||
# Start the training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    train()
|
||||
@@ -5,3 +5,4 @@ disallow_untyped_calls = True
|
||||
disallow_untyped_defs = True
|
||||
ignore_missing_imports = True
|
||||
strict_optional = False
|
||||
explicit_package_bases = True
|
||||
|
||||
+10
-6
@@ -1,6 +1,10 @@
|
||||
lightning==2.4.0
|
||||
numpy==1.26.4
|
||||
onnx==1.16.2
|
||||
onnxruntime==1.19.0
|
||||
opencv-python==4.10.0.84
|
||||
mxnet==1.9.1
|
||||
--extra-index-url https://download.pytorch.org/whl/cu128
|
||||
albumentations==2.0.8
|
||||
lightning==2.5.5
|
||||
onnx==1.18.0
|
||||
onnxruntime==1.22.0
|
||||
pytorch-msssim==1.0.0
|
||||
torch==2.8.0
|
||||
torchdata==0.11.0
|
||||
torchvision==0.23.0
|
||||
tensorboard==2.20.0
|
||||
|
||||
Reference in New Issue
Block a user