456 Commits

Author SHA1 Message Date
henryruhs 6547fcfe7b Add .claude to gitignore 2025-10-10 18:32:38 +02:00
Henry Ruhs 2e6394565a Next (#93)
* add sync batchnorm

* replace random.choice with hash

* fifty percent reduction

* fix discriminator input

* restore dataset.py

* Remove duplicates

* add discriminator_ratio to config

* fix onnx export bug: replace round() with int()

* Fix embedding naming

* Introduce ModelWithConfigCheckpoint callback (#86)

* Fix dist ini

* Style: Refactor typing and improve code clarity in training.py (#88)

* Add type casting for trainer params

* Add type casting for trainer params

* Add type casting for trainer params

* Remove inplace activations for torch.compile compatibility (#89)

* Fix README

* improvise with norm layers & weighted average

* add skip layer

* use gelu instead of leaky_relu

* cleanup

* cleanup

* Update dependencies

* Different defaults and enable validation

* Different defaults and enable validation

* Revert to higher batch size

* Just use copy over copy2

---------

Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com>
Co-authored-by: NeuroDonu <112660822+NeuroDonu@users.noreply.github.com>
Co-authored-by: Harisreedhar <46858047+harisreedhar@users.noreply.github.com>
2025-09-06 19:12:29 +02:00
henryruhs 9f9f9dbad7 Rename calcXXX to calculateXXX 2025-06-25 11:36:31 +02:00
henryruhs dbe79aa3b9 Fix config for mask factor 2025-06-24 21:47:53 +02:00
henryruhs 2809a59704 Move mask_factor and noise_factor to modifier block 2025-06-20 15:29:37 +02:00
henryruhs a86497177d Move mask factor to trainer, Refactor helper 2025-06-20 15:15:02 +02:00
Harisreedhar 338f49c3dc Merge pull request #85 from facefusion/feat/dilate-mask
Feat/dilate mask
2025-06-20 18:28:08 +05:30
harisreedhar 56b71048e3 add erode for export and make it conditional 2025-06-19 16:29:30 +05:30
henryruhs 7490ead302 Update dependencies 2025-06-17 22:03:22 +02:00
henryruhs ecc37873bf Update README 2025-06-17 21:49:45 +02:00
harisreedhar bd762c4c38 dilate 2025-06-17 15:15:50 +05:30
Harisreedhar f5c49a02cb Merge pull request #84 from facefusion/stabilize-finetuning
Change optimizer and expose more parameters to config
2025-06-17 14:49:34 +05:30
harisreedhar f4d4914f5c rearrange 2025-06-17 14:46:56 +05:30
harisreedhar 2f28fb664b apply crossface 2025-06-17 14:43:21 +05:30
harisreedhar 35c250b0c9 rearrange logger 2025-06-17 13:44:37 +05:30
harisreedhar 580a179f44 rename 2025-06-17 13:39:24 +05:30
harisreedhar e846d88145 split config section 2025-06-17 13:26:06 +05:30
harisreedhar e894e4172a stabilize training 2025-06-17 13:06:47 +05:30
harisreedhar fc766b8327 fix lr scheduler 2025-06-11 15:38:28 +05:30
harisreedhar a06f5fd9e8 stabilize finetune 2025-06-11 14:15:11 +05:30
Harisreedhar ce7aaa57dc Merge pull request #83 from facefusion/ssim-fix
fix ssim
2025-06-09 19:07:30 +05:30
harisreedhar fce54eb7db fix ssim 2025-06-09 19:00:56 +05:30
Henry Ruhs 3e9c8a37e7 Some polishing on augmentation (#82) 2025-06-09 07:49:00 +02:00
henryruhs 143b594ee6 Rename to dataset multiplier 2025-06-05 16:34:03 +02:00
Henry Ruhs 47bebb02d7 Introduce usage ration to boost datasets (#81)
* Introduce usage rate to boost datasets

* Rename to usage ratio
2025-06-05 15:22:25 +02:00
Henry Ruhs 94cbcb68f0 Cache the usage of glob.glob (#80) 2025-06-05 09:30:04 +02:00
Henry Ruhs a602bbd474 Dataset Usage Mode (#79)
* Introduce FilePool to support usage modes

* Fix lint

* Add to README and config

* Enforce equal and same swaps

* Different approach to forward convert tempalte

* Changes

* Changes

* Changes

* Changes

* Introduce V3 of the usage mode feature

* Fix lint

* Proper use of config parser

* fix filter to filter config
2025-06-04 09:25:31 +02:00
Henry Ruhs 24f45877f5 feat/noise injection (#78)
* changes

* add to config.ini

* changes

* changes

---------

Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com>
2025-06-02 11:04:53 +02:00
Henry Ruhs 0722db91f1 refactor/convert tensor (#76)
* change to convert template

* changes

* changes

* Conditional convert input tensor

* Conditional convert input tensor

---------

Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com>
2025-05-26 09:44:09 +02:00
Henry Ruhs 475b8b1538 Next (#75)
* Add gradient value clip

* Add gradient clip to config

* Fix HifiFace in preview

* Fix HifiFace in preview

* Fix HifiFace in preview

* Adjust save top
2025-05-05 10:03:34 +02:00
henryruhs d68b77bd4d Final touches on the README 2025-04-26 12:20:09 +02:00
henryruhs 8806accbb3 Update previews 2025-04-26 12:09:35 +02:00
henryruhs d4a8719870 Fix preview 2025-04-24 15:51:28 +02:00
Henry Ruhs 1a41a941e2 Merge pull request #74 from facefusion/rename-everything
Rename everything
2025-04-24 15:22:47 +02:00
henryruhs 5c855aae4e Sort out the warp template naming 2025-04-24 13:07:35 +02:00
henryruhs 810df0f540 Final rename for everything 2025-04-24 12:42:53 +02:00
henryruhs 03011200e4 Update dependencies 2025-04-23 21:35:45 +02:00
henryruhs 837ee1e18c Cosmetic changes 2025-04-23 21:11:51 +02:00
Henry Ruhs ef62a2ee9e Fix preview 2025-04-23 20:52:04 +02:00
Henry Ruhs af455f5236 Fix preview 2025-04-23 20:51:39 +02:00
Henry Ruhs 0b7db0cc27 Merge pull request #72 from facefusion/next
Next
2025-04-23 20:50:08 +02:00
Harisreedhar 140cad492a Merge pull request #71 from facefusion/remove-pose-expression-loss
Remove pose and expression loss
2025-04-23 22:13:27 +05:30
harisreedhar d44ac98e38 changes 2025-04-23 22:00:42 +05:30
harisreedhar d990ce4575 remove 1024 from test 2025-04-21 15:41:22 +05:30
harisreedhar 982a94b535 add 1024 2025-04-21 15:37:13 +05:30
Harisreedhar 5b41d8e91f Merge pull request #70 from facefusion/cycle-loss
Cycle loss
2025-04-15 14:17:28 +05:30
harisreedhar bcf5b4e5a8 changes 2025-04-15 14:14:09 +05:30
harisreedhar 128726701b changes 2025-04-15 14:10:58 +05:30
harisreedhar 4a319ec9bd changes 2025-04-15 14:05:27 +05:30
harisreedhar 39ce14b590 remove discriminator frequency 2025-04-15 13:54:17 +05:30
harisreedhar 1477850a23 limit discriminator training every 10 steps 2025-04-15 13:54:17 +05:30
harisreedhar f4c4066e8c changes 2025-04-15 13:54:17 +05:30
henryruhs d9fe667ced Fix typo 2025-04-15 13:54:17 +05:30
harisreedhar b7a6f00e8b remove discriminator frequency 2025-04-15 13:54:17 +05:30
harisreedhar b215db68c3 limit discriminator training every 10 steps 2025-04-15 13:54:17 +05:30
harisreedhar dc2b2dc982 changes 2025-04-15 13:54:15 +05:30
henryruhs 76fe5c351c Fix typo 2025-04-13 10:17:52 +02:00
harisreedhar 056bacb7de more augmentation 2025-04-10 11:33:21 +05:30
Henry Ruhs dafada11bc Merge pull request #69 from facefusion/individual-embedder
Individual embedder for generator and loss
2025-04-08 22:01:07 +02:00
harisreedhar 4f5ac00a7b changes 2025-04-08 12:29:07 +05:30
Henry Ruhs 2e3c3517cb Merge pull request #68 from facefusion/mixed-precision-export
Half precision export
2025-03-29 09:57:47 +01:00
henryruhs 7845dd8522 Adjust like reviewed 2025-03-29 09:56:23 +01:00
harisreedhar 4b851a173d changes 2025-03-28 16:32:44 +05:30
Harisreedhar f99c73495c Merge pull request #67 from facefusion/multi-dataset
Multi dataset
2025-03-26 17:43:35 +05:30
harisreedhar cc6a99f305 changes 2025-03-26 17:13:44 +05:30
harisreedhar 9df29f8a22 changes 2025-03-26 16:54:05 +05:30
harisreedhar 80e600cbb5 changes 2025-03-26 16:51:38 +05:30
henryruhs 117a9d0fc9 Remove outdated file 2025-03-24 15:09:02 +01:00
Henry Ruhs d2be8a386a Update dependencies 2025-03-24 09:09:29 +01:00
henryruhs 0743b99347 Add HorizontalFlip to AugmentTransform 2025-03-23 21:36:03 +01:00
henryruhs 4f4057fc54 Adjust suggested defaults 2025-03-23 16:33:25 +01:00
henryruhs 4ebdeee634 Alphabetical order for warp templates 2025-03-23 16:21:45 +01:00
Harisreedhar 00d5c1f200 Merge pull request #66 from facefusion/remove-redundant-encoder-calc
Remove redundant encoder calculation
2025-03-23 18:53:29 +05:30
harisreedhar 99a8527e24 changes 2025-03-23 18:48:45 +05:30
harisreedhar 602e890af2 changes 2025-03-23 18:45:35 +05:30
harisreedhar 9ede8a2a7d changes 2025-03-23 18:38:48 +05:30
harisreedhar c85c755e00 changes 2025-03-23 18:34:23 +05:30
harisreedhar c1f39a73dd change face-parser to face-masker 2025-03-23 18:18:10 +05:30
harisreedhar 583d09e666 change face-parser 2025-03-23 18:01:38 +05:30
harisreedhar 9153b4ce8f add 'ffhq_to_arcface_128_v2' template 2025-03-22 14:14:37 +05:30
harisreedhar 35afb426b7 add expression loss to log 2025-03-20 11:01:54 +05:30
harisreedhar b6a2734622 add masknet layer 2025-03-19 19:27:09 +05:30
harisreedhar 10b6f801d1 add mask blend 2025-03-18 21:43:49 +05:30
henryruhs 798ff87736 Adjust naming 2025-03-16 15:18:28 +01:00
Henry Ruhs df3e22fdf9 Merge pull request #63 from facefusion/mask-guided-generator
Mask guided generator
2025-03-16 14:25:39 +01:00
henryruhs 24f2e14a95 Join MaskNet to guide generator 2025-03-16 12:29:17 +01:00
henryruhs c45fcbba84 Join MaskNet to guide generator 2025-03-16 12:24:08 +01:00
henryruhs ad675ae633 Join MaskNet to guide generator 2025-03-16 12:23:08 +01:00
henryruhs 803902c8bb Add sync_dist to validation score logging 2025-03-16 09:01:40 +01:00
henryruhs bc1b04a107 Hotfix imprt 2025-03-16 08:59:08 +01:00
Henry Ruhs 345a225c94 Merge pull request #62 from facefusion/rename-attribute-to-feature
Rename attribute to feature
2025-03-16 08:44:00 +01:00
henryruhs eefc69a820 Fix CI 2025-03-16 08:41:31 +01:00
henryruhs 94571c5676 Rename attribute to feature 2025-03-16 08:39:35 +01:00
henryruhs 904a447e06 Mask typing and naming related updates 2025-03-14 08:30:55 +01:00
henryruhs 33d00ac941 Mask typing and naming related updates 2025-03-14 08:13:21 +01:00
henryruhs 5234874bc7 Mask typing and naming related updates 2025-03-14 08:09:47 +01:00
henryruhs b5efcbe44a Adjust the weights in README 2025-03-12 23:02:26 +01:00
henryruhs bdd7fd0d86 Fix attribute loss 2025-03-12 22:54:20 +01:00
henryruhs 7c75b0d898 Fix attribute loss 2025-03-12 22:41:28 +01:00
henryruhs 0d73bcf918 Fix attribute loss 2025-03-12 22:37:06 +01:00
henryruhs 8f0ee4935b Remove flag 2025-03-12 22:07:39 +01:00
henryruhs 431df7cff8 Hotfix 2025-03-12 21:58:29 +01:00
henryruhs c0aaae9358 Hotfix 2025-03-12 21:55:26 +01:00
Henry Ruhs e4d2d244a0 Merge pull request #60 from facefusion/output-based-masknet
Train MaskNet based on the output
2025-03-12 21:42:51 +01:00
henryruhs 72591fbed1 Train MaskNet based on the output 2025-03-12 21:42:08 +01:00
henryruhs cf0bd93814 Train MaskNet based on the output 2025-03-12 21:35:07 +01:00
henryruhs 0732924f1e Minor change 2025-03-12 20:18:33 +01:00
Henry Ruhs bf7bbc2550 Merge pull request #57 from facefusion/fix-aad-naming
Fix AAD naming, Attribute vs. Embedding
2025-03-12 18:07:46 +01:00
henryruhs f989df39e9 Fix test 2025-03-12 16:30:59 +01:00
henryruhs ecc5b1a1d5 Fix test 2025-03-12 16:30:59 +01:00
henryruhs 5550c89f43 Fix AAD naming, Attribute and Embedding 2025-03-12 16:30:59 +01:00
Henry Ruhs 113f9fa6e5 Merge pull request #59 from facefusion/extend-config-with-logger-and-strategy
Extend config
2025-03-12 16:29:36 +01:00
henryruhs 20aa5114e1 Extend config 2025-03-12 16:24:03 +01:00
Henry Ruhs 3c0a65c7a0 Merge pull request #58 from facefusion/refactor/generator-return-attributes
Let the generator return target attributes
2025-03-12 16:22:24 +01:00
henryruhs 2af4f2c8ab Extend config 2025-03-12 16:21:47 +01:00
henryruhs 8b465fce03 Fix generator call 2025-03-12 15:52:56 +01:00
henryruhs d212e2fe12 Fix import 2025-03-12 15:48:53 +01:00
henryruhs 569df9e96d Let the generator return target attributes 2025-03-12 15:33:31 +01:00
henryruhs 738e00d59e Add more Loss types 2025-03-12 08:42:38 +01:00
henryruhs 564cc7b127 Introduce Loss type, Remove Gaze type 2025-03-12 08:38:49 +01:00
Henry Ruhs 944096befc Merge pull request #56 from facefusion/name-polishing
Add Attributes, polish names
2025-03-11 19:27:49 +01:00
henryruhs 0991745753 Remove Attributes 2025-03-11 19:17:26 +01:00
henryruhs 70ac772a34 Remove Attributes 2025-03-11 19:16:17 +01:00
henryruhs 31303c1c6c Add Attributes, polish names 2025-03-11 19:12:26 +01:00
Henry Ruhs e85aa20602 Merge pull request #55 from facefusion/fix-export
Fix export
2025-03-11 18:33:16 +01:00
harisreedhar 5567b49a6d fix 2025-03-11 23:02:46 +05:30
harisreedhar 99931def84 fix 2025-03-11 22:57:08 +05:30
harisreedhar 4f67e045a0 move generator optimizer toggle 2025-03-11 19:37:05 +05:30
henryruhs af09ee7ff3 Fix config 2025-03-11 14:43:10 +01:00
henryruhs 09432d9214 Rename parser to face parser 2025-03-11 14:43:10 +01:00
henryruhs afab997ffc Rename parser to face parser 2025-03-11 14:43:10 +01:00
henryruhs e758eb3e19 Use sequential whenever possible 2025-03-11 14:43:10 +01:00
harisreedhar f90fd73b54 add strategy to config 2025-03-11 14:43:10 +01:00
henryruhs 8f76d96bb2 Name color tensor -> overlay tensor 2025-03-11 14:43:10 +01:00
harisreedhar 66d1573f4b changes 2025-03-11 14:43:10 +01:00
harisreedhar a7a21cd684 changes 2025-03-11 14:43:10 +01:00
harisreedhar 52b98b5be5 change mask preview 2025-03-11 14:43:10 +01:00
henryruhs 0304b5dd91 Fix spacing 2025-03-11 14:43:10 +01:00
harisreedhar 2322b6539f changes 2025-03-11 14:43:10 +01:00
harisreedhar b6e131e4c1 changes 2025-03-11 14:43:10 +01:00
harisreedhar d809f66216 changes 2025-03-11 14:43:10 +01:00
henryruhs 90cb6afe10 Minor adjustment 2025-03-11 14:43:10 +01:00
henryruhs 7f16d0a10e Follow sequences pattern 2025-03-11 14:43:10 +01:00
henryruhs 54923abf7f Follow layers and sequences pattern 2025-03-11 14:43:10 +01:00
harisreedhar e267f9ffd5 reduce layer 2025-03-11 14:43:10 +01:00
harisreedhar 27440ef023 changes 2025-03-11 14:43:10 +01:00
harisreedhar c2a639229f changes 2025-03-11 14:43:10 +01:00
henryruhs f9d105ea2b Some code review 2025-03-11 14:43:10 +01:00
henryruhs d2efb2fd08 Fix README.md 2025-03-11 14:43:10 +01:00
harisreedhar 64464b6f1c fix 2025-03-11 14:43:10 +01:00
harisreedhar 606bf42089 fix 2025-03-11 14:43:10 +01:00
harisreedhar a3ac4d5ddd changes 2025-03-11 14:43:10 +01:00
harisreedhar 1659805b08 changes 2025-03-11 14:43:10 +01:00
harisreedhar 8f1f002c64 add masknet 2025-03-11 14:43:10 +01:00
harisreedhar 4af22832db fix discriminator & restore grad for ID embedder 2025-03-11 14:43:10 +01:00
henryruhs 9ff30a0268 Use getint 2025-03-11 14:43:10 +01:00
henryruhs dbe931e950 Fix naming 2025-03-11 14:43:10 +01:00
henryruhs 862dce7bc6 Revert the config dicts 2025-03-11 14:43:10 +01:00
henryruhs 8101b15e1c Revert the config dicts 2025-03-11 14:43:10 +01:00
henryruhs 1dfd230fc5 Revert the config dicts 2025-03-11 14:43:10 +01:00
henryruhs e5f983b2bf Minor adjustment for test 2025-03-11 14:43:10 +01:00
henryruhs 7e938c2ec9 Use a lot of ignores 2025-03-11 14:43:10 +01:00
henryruhs c7d55d0d17 Use a lot of ignores 2025-03-11 14:43:10 +01:00
henryruhs 6bc44ad3d8 Fix CI 2025-03-11 14:43:10 +01:00
henryruhs b829d5e42c Migrate most to self.config and self.context 2025-03-11 14:43:10 +01:00
henryruhs ab3b699124 Rework on config 2025-03-11 14:43:10 +01:00
henryruhs a5d99c139e Fix bug 2025-03-11 14:43:10 +01:00
henryruhs bfa9924b40 Remove types 2025-03-11 14:43:10 +01:00
henryruhs a0c42bedbe Improve imports 2025-03-11 14:43:10 +01:00
henryruhs d215b6f98b Fix CI 2025-03-11 14:43:10 +01:00
henryruhs 01278d679f Forward config parser 2025-03-11 14:43:10 +01:00
henryruhs a2a9b78dac More adjustments 2025-03-11 14:43:10 +01:00
henryruhs cf26f66e36 More adjustments 2025-03-11 14:43:10 +01:00
henryruhs f4a1e18ca9 More adjustments 2025-03-11 14:43:10 +01:00
henryruhs 7ab7efbbf4 More adjustments 2025-03-11 14:43:10 +01:00
henryruhs 6a5f81e5fe More adjustments 2025-03-11 14:43:10 +01:00
henryruhs 6a11603e7e More adjustments 2025-03-11 14:43:10 +01:00
henryruhs ff9b777b28 More adjustments 2025-03-11 14:43:10 +01:00
henryruhs 5bacb048dd More adjustments 2025-03-11 14:43:10 +01:00
henryruhs 847579f925 More adjustments 2025-03-11 14:43:10 +01:00
henryruhs 57aad5204e More adjustments 2025-03-11 14:43:10 +01:00
henryruhs d3b0051912 More adjustments 2025-03-11 14:43:10 +01:00
henryruhs d944d95bfd Fix CI 2025-03-11 14:43:10 +01:00
henryruhs 8f1f63f2ef Fix CI 2025-03-11 14:43:10 +01:00
henryruhs b59e172fa3 Always config injection 2025-03-11 14:43:10 +01:00
henryruhs 368da824aa Fix CI 2025-03-11 14:43:10 +01:00
henryruhs e61e470432 Fix CI 2025-03-11 14:43:10 +01:00
henryruhs c8953ce8a1 Apply new config approach for embedding converter 2025-03-11 14:43:10 +01:00
harisreedhar e428ae04e3 add 128px 2025-03-11 14:43:10 +01:00
harisreedhar a89e51c2f8 remove grad for inference only models 2025-03-11 14:43:10 +01:00
harisreedhar 61f48d9246 changes 2025-03-11 14:43:10 +01:00
henryruhs dfd018a897 Fix CI 2025-03-11 14:43:10 +01:00
henryruhs b63562abad Simplify infer() 2025-03-11 14:43:10 +01:00
henryruhs e9ea9cd9e5 Clean GazeLoss 2025-03-11 14:43:10 +01:00
henryruhs abdc770892 Improve UNet 2025-03-11 14:43:10 +01:00
harisreedhar 7cf5609c1f changes 2025-03-11 14:43:10 +01:00
henryruhs 6388727262 Fix CI 2025-03-11 14:43:10 +01:00
henryruhs aedaa20d78 Remove UnetPro, make float values visible in README 2025-03-11 14:43:10 +01:00
henryruhs f9de4ce78a Remove UnetPro, make float values visible in README 2025-03-11 14:43:10 +01:00
henryruhs 6d805438ad Adjust networks for 512 2025-03-11 14:43:10 +01:00
henryruhs 7cc893c32e Adjust networks for 512 2025-03-11 14:43:10 +01:00
henryruhs 3e69d5a9a9 Adjust networks for 512 2025-03-11 14:43:10 +01:00
henryruhs f791178ded Fix CI 2025-03-11 14:43:10 +01:00
henryruhs 866019d44f Debug path 2025-03-11 14:43:10 +01:00
henryruhs 94ad33cb1e Debug path 2025-03-11 14:43:10 +01:00
henryruhs de72e50233 Only ubuntu needed for test 2025-03-11 14:43:10 +01:00
henryruhs 5f053f9f69 Crazy fix 2025-03-11 14:43:10 +01:00
henryruhs f678aa8f7e Crazy fix 2025-03-11 14:43:10 +01:00
henryruhs 0e8207ccc8 Crazy fix 2025-03-11 14:43:10 +01:00
henryruhs 4a2559c866 Crazy fix 2025-03-11 14:43:10 +01:00
henryruhs 3c5554c1c5 Crazy fix 2025-03-11 14:43:10 +01:00
henryruhs dbf5687bcd Crazy fix 2025-03-11 14:43:10 +01:00
henryruhs 2148e9b701 Prepare test for 512 2025-03-11 14:43:10 +01:00
henryruhs 64ebfa7b84 Add basic test for aad and unet 2025-03-11 14:43:10 +01:00
henryruhs daeec46e36 Add basic test for aad and unet 2025-03-11 14:43:10 +01:00
henryruhs 13d15029b7 Add basic test for aad and unet 2025-03-11 14:43:10 +01:00
henryruhs d5c51a90e8 Add basic test for aad and unet 2025-03-11 14:43:10 +01:00
henryruhs 6fa8d6b6eb Add basic test for aad and unet 2025-03-11 14:43:10 +01:00
henryruhs 72371b9f11 Add basic test for aad and unet 2025-03-11 14:43:10 +01:00
henryruhs 786adf73a2 Fix UnetPro 2025-03-11 14:43:10 +01:00
henryruhs dcc5ccccd7 Extend Unet with more layers 2025-03-11 14:43:10 +01:00
henryruhs 5056b8df75 Variable AAD layer according to output size 2025-03-11 14:43:10 +01:00
henryruhs 430c71d031 Adjust config and namings 2025-03-11 14:43:10 +01:00
harisreedhar 176dced1f6 changes 2025-03-11 14:43:10 +01:00
henryruhs 18a605e1a3 Change the weight config order 2025-03-11 14:43:10 +01:00
henryruhs fea75ff949 Remove useless comma 2025-03-11 14:43:10 +01:00
harisreedhar df5895e266 changes 2025-03-11 14:43:10 +01:00
harisreedhar f2d3f8a19f changes 2025-03-11 14:43:10 +01:00
henryruhs ceb3c0cfdf Rename to AugmentTransform 2025-03-11 14:43:10 +01:00
henryruhs dd0a2fe649 Add albumentations 2025-03-11 14:43:10 +01:00
henryruhs e5a4a54e61 Introduce Transform classes, Add albumentations 2025-03-11 14:43:10 +01:00
henryruhs 9dc1031fa5 Fix same and equal 2025-03-11 14:43:10 +01:00
henryruhs d93daf9e5f Introduce batch mode via config for equal and same batches 2025-03-11 14:43:10 +01:00
henryruhs 3a61da8bab Introduce batch mode via config for equal and same batches 2025-03-11 14:43:10 +01:00
harisreedhar 83c20f8331 changes 2025-03-11 14:43:10 +01:00
harisreedhar dfd9e99aed changes 2025-03-11 14:43:10 +01:00
harisreedhar 2fb0b4289d changes 2025-03-11 14:43:10 +01:00
harisreedhar e6ea454360 supress warning 2025-03-11 14:43:10 +01:00
henryruhs 80d1694e23 Replace naming warp_matrix with warp_template 2025-03-11 14:43:10 +01:00
henryruhs 5e68de9170 Add config to dataloader alignment 2025-03-11 14:43:10 +01:00
henryruhs 6ca68f1408 Add config to dataloader alignment 2025-03-11 14:43:10 +01:00
henryruhs 43f99db5d7 Add config to dataloader alignment 2025-03-11 14:43:10 +01:00
henryruhs c9e70ebc18 Add config to dataloader alignment 2025-03-11 14:43:10 +01:00
henryruhs 1ab57b1d3f Add config to dataloader alignment 2025-03-11 14:43:10 +01:00
henryruhs cc9a0ba83e Add config to dataloader alignment 2025-03-11 14:43:10 +01:00
henryruhs 589568bfb5 Add config to dataloader alignment 2025-03-11 14:43:10 +01:00
harisreedhar 34a7f3ef55 changes 2025-03-11 14:43:10 +01:00
harisreedhar c35b0a6f4c changes 2025-03-11 14:43:10 +01:00
harisreedhar 23ab7dc89d fix 2025-03-11 14:43:10 +01:00
henryruhs 8443f3512a torch.float32 is the default 2025-03-11 14:43:10 +01:00
henryruhs d3a2035d7a Fix order 2025-03-11 14:43:10 +01:00
henryruhs 16c8b32269 Fix some naming and types 2025-03-11 14:43:10 +01:00
harisreedhar 65ab796835 changes 2025-03-11 14:43:10 +01:00
henryruhs 56be3f0b9b Rename to validation_score 2025-03-11 14:43:10 +01:00
henryruhs a22adaf51f Adjust some namings 2025-03-11 14:43:10 +01:00
henryruhs 0055c0c97f Adjust some namings 2025-03-11 14:43:10 +01:00
harisreedhar 35b779b1ed change pretrained models mode to eval 2025-03-11 14:43:10 +01:00
harisreedhar ea1b0205f0 constructor injection 2025-03-11 14:43:10 +01:00
harisreedhar a5eb7d6aa1 change Adam to AdamW 2025-03-11 14:43:10 +01:00
harisreedhar b27b8663e5 changes 2025-03-11 14:43:10 +01:00
harisreedhar d87f6c0b15 changes 2025-03-11 14:43:10 +01:00
harisreedhar 5d1b90ff19 changes 2025-03-11 14:43:10 +01:00
harisreedhar 2ddcf52b66 changes 2025-03-11 14:43:10 +01:00
henryruhs c8801ececd Fix export using Trainer 2025-03-11 14:43:10 +01:00
henryruhs 84b4451366 Fix export using Trainer 2025-03-11 14:43:10 +01:00
henryruhs cadbe9cf76 Follow the concept of layers and sequences 2025-03-11 14:43:10 +01:00
henryruhs 58a85a80bb Use high float32 matmul precision 2025-03-11 14:43:10 +01:00
harisreedhar ab0a59fb74 changes 2025-03-11 14:43:10 +01:00
henryruhs 578b07a7f4 Add StatefulDataloader, Manual trigger scheduler 2025-03-11 14:43:10 +01:00
henryruhs 7ce9d27097 Different naming 2025-03-11 14:43:10 +01:00
henryruhs 8c24c9ec27 Fix CI 2025-03-11 14:43:10 +01:00
henryruhs 0ad2556c4c Improve optimizer configs 2025-03-11 14:43:10 +01:00
henryruhs 484a49c27d Restore dataset behaviour for same person 2025-03-11 14:43:10 +01:00
harisreedhar e8cc2bfff1 add gaze loss 2025-03-11 14:43:10 +01:00
henryruhs a951d700fc Restore RGB normalization 2025-03-11 14:43:10 +01:00
henryruhs 6eff69a41a There is no need to make directories 2025-03-11 14:43:10 +01:00
henryruhs 5b3b2abdd7 Remove annoying Tuner, Match both trainer configs a bit more 2025-03-11 14:43:10 +01:00
henryruhs bfcbd6bf95 Fix CI 2025-03-11 14:43:10 +01:00
henryruhs 607c55ff1f Remove Numpy and CV2 to fully use Tensors 2025-03-11 14:43:10 +01:00
henryruhs 93cbbf52d0 Fix CI 2025-03-11 14:43:10 +01:00
henryruhs f5cd6b6336 This should be a StaticDataset, Fix learning rate finder 2025-03-11 14:43:10 +01:00
henryruhs 257e5e56a4 Rename split ratio 2025-03-11 14:43:10 +01:00
henryruhs f19908ccd6 Fix spacing for lambda 2025-03-11 14:43:10 +01:00
henryruhs 18a2531b54 Fix naming in dataset.py 2025-03-11 14:43:10 +01:00
henryruhs 84be7d1ffb Remove deprecated argument 2025-03-11 14:43:10 +01:00
henryruhs 8b53c76a0a Remove deprecated argument 2025-03-11 14:43:10 +01:00
henryruhs 7d8cb146a4 Use latest numpy 2025-03-11 14:43:10 +01:00
henryruhs 303cbfa024 Final refactoring for AAD done 2025-03-11 14:43:10 +01:00
henryruhs bc174186eb Final refactoring for AAD done 2025-03-11 14:43:10 +01:00
henryruhs d7158749c2 Improve naming 2025-03-11 14:43:10 +01:00
henryruhs 8b2b6892aa Improve naming 2025-03-11 14:43:10 +01:00
henryruhs 5bba2a1c69 Remove the condition from reconstruction loss 2025-03-11 14:43:10 +01:00
henryruhs 94480e16eb Rename to AAD as this is not a full AIENet 2025-03-11 14:43:10 +01:00
henryruhs bbcb1c35f0 Modernize AIENet 2025-03-11 14:43:10 +01:00
henryruhs ee3fc40e83 Move output channels to config 2025-03-11 14:43:10 +01:00
henryruhs 335d597e53 Rename to aienet.py 2025-03-11 14:43:10 +01:00
henryruhs 14bbece850 Rename id_ to identity_ 2025-03-11 14:43:10 +01:00
henryruhs fad38da864 Rename temp1 with positive and temp2 with negative 2025-03-11 14:43:10 +01:00
henryruhs e75a3c58f9 Replace calc() with forward(), Rename temp1 with positive and temp2 with negative 2025-03-11 14:43:10 +01:00
harisreedhar 6fed877d33 reconstruction-loss fix 2025-03-11 14:43:10 +01:00
henryruhs de6cfbc35b Fix CI 2025-03-11 14:43:10 +01:00
henryruhs ed0f6ae897 Use new loss code, Remove unused code, Remove old types, Ban VisionTensor naming 2025-03-11 14:43:10 +01:00
henryruhs 63e4bea3cd Use new config for old code 2025-03-11 14:43:10 +01:00
henryruhs 14b9bccafe Introduce new DiscriminatorLoss class, Remove useless super call params 2025-03-11 14:43:10 +01:00
henryruhs 579d3ef51c Introduce new GazeLoss class (switched to mean) 2025-03-11 14:43:10 +01:00
henryruhs a797548329 Introduce new PoseLoss class (switched to mean) 2025-03-11 14:43:10 +01:00
henryruhs 6eabcad1d0 Introduce new AttributeLoss class 2025-03-11 14:43:10 +01:00
henryruhs 7848d28b02 Introduce new AdversarialLoss class fix 2025-03-11 14:43:10 +01:00
henryruhs 30e787129a Introduce new AdversarialLoss class 2025-03-11 14:43:09 +01:00
henryruhs 38211f0340 Introduce new AdversarialLoss class 2025-03-11 14:43:09 +01:00
henryruhs f2833a32c3 Introduce new AdversarialLoss class 2025-03-11 14:43:09 +01:00
henryruhs 3b7d3b6688 Introduce new ReconstructionLoss class 2025-03-11 14:43:09 +01:00
henryruhs 086d9eed87 Introduce new ReconstructionLoss class 2025-03-11 14:43:09 +01:00
henryruhs 085c493e18 Rename id_embedder to embedder, Tons of naming in training step, Introduce new IdentityLoss class 2025-03-11 14:43:09 +01:00
henryruhs 2220f5ef08 Add missing encoder_type to config 2025-03-11 14:43:09 +01:00
henryruhs f482d46798 Remove VisionTensor from Discriminator 2025-03-11 14:43:09 +01:00
henryruhs a6e1405c70 Restore map_location = 'cpu' 2025-03-11 14:43:09 +01:00
henryruhs ac41bab3a2 Restore map_location = 'cpu' 2025-03-11 14:43:09 +01:00
henryruhs 206a1411d1 Fix Generator 2025-03-11 14:43:09 +01:00
henryruhs dc0abff0ce Minor typo 2025-03-11 14:43:09 +01:00
henryruhs 10ce04ed58 Introduce feature flag for Unet 2025-03-11 14:43:09 +01:00
henryruhs 094d5cea9e Fix CI 2025-03-11 14:43:09 +01:00
henryruhs f6c59257d9 Partial use Resnet34 as a DownSample replacement 2025-03-11 14:43:09 +01:00
henryruhs 83ef075b1d Remove map_location 2025-03-11 14:43:09 +01:00
henryruhs 575f215408 lr only in optimizer 2025-03-11 14:43:09 +01:00
henryruhs d153c68813 Better split_ratio 2025-03-11 14:43:09 +01:00
henryruhs 00dccf07b9 Fix tensor foobar 2025-03-11 14:43:09 +01:00
henryruhs 5a6e3393e2 Revert loss for the moment 2025-03-11 14:43:09 +01:00
henryruhs c17378f3c7 Let's call it batch ratio 2025-03-11 14:43:09 +01:00
henryruhs 84503761b9 follow the not invented here syndrome 2025-03-11 14:43:09 +01:00
henryruhs 723e9fde78 follow the not invented here syndrome 2025-03-11 14:43:09 +01:00
henryruhs 09e913233b Follow convention of the other project 2025-03-11 14:43:09 +01:00
henryruhs b4bbd862e2 Follow convention of the other project 2025-03-11 14:43:09 +01:00
harisreedhar 4d8433f54a changes 2025-03-11 14:43:09 +01:00
harisreedhar db44c91dd8 changes 2025-03-11 14:43:09 +01:00
henryruhs b33281425a This is not a TensorDataset 2025-03-11 14:43:09 +01:00
henryruhs 3d9ff4add0 Adjust validation to one time per epoch 2025-03-11 14:43:09 +01:00
henryruhs 5934b47961 Add cosine_similarity 2025-03-11 14:43:09 +01:00
henryruhs 04eaa831ea Rename to resume_path only 2025-03-11 14:43:09 +01:00
henryruhs c1bed34c27 More tweaks 2025-03-11 14:43:09 +01:00
henryruhs 4078681031 Fix typing 2025-03-11 14:43:09 +01:00
henryruhs 15ee6fa763 Simplify sizes 2025-03-11 14:43:09 +01:00
henryruhs 251e610f0e Follow the lightning naming and call this dataset, Improve config and types 2025-03-11 14:43:09 +01:00
henryruhs da51c5336d This can be marked static 2025-03-11 14:43:09 +01:00
henryruhs 9bd68c3d14 Similarity validation for embedding converter 2025-03-11 14:43:09 +01:00
henryruhs 0a50e2d706 Rename data loader 2025-03-11 14:43:09 +01:00
henryruhs 40dcef7fc7 Modernize data loader, remove read image helper 2025-03-11 14:43:09 +01:00
henryruhs c041073953 Modernize data loader, remove read image helper 2025-03-11 14:43:09 +01:00
henryruhs 354315502b Modernize data loader, remove read image helper 2025-03-11 14:43:09 +01:00
henryruhs 254bc17c98 Modernize data loader, remove read image helper 2025-03-11 14:43:09 +01:00
henryruhs 39c0313202 Modernize data loader, remove read image helper 2025-03-11 14:43:09 +01:00
henryruhs 3cf9711df0 Modernize data loader, remove read image helper 2025-03-11 14:43:09 +01:00
harisreedhar d1bf54276d changes 2025-03-11 14:43:09 +01:00
harisreedhar b47c6b72ee changes 2025-03-11 14:43:09 +01:00
henryruhs dcf19634d1 Brain fart 2025-03-11 14:43:09 +01:00
henryruhs 857365770f Make validation step more solid, failed on empty checksums 2025-03-11 14:43:09 +01:00
henryruhs d25f2865a9 Normalize validation output 2025-03-11 14:43:09 +01:00
henryruhs 0d45568bd1 Uniform validation, Add cosine_similarity validation to face swapper 2025-03-11 14:43:09 +01:00
henryruhs ccf6fa7f43 Uniform validation, Add cosine_similarity validation to face swapper 2025-03-11 14:43:09 +01:00
henryruhs 28977d37d6 Uniform validation, Add cosine_similarity validation to face swapper 2025-03-11 14:43:09 +01:00
henryruhs c161da2f25 Adjust naming and typing 2025-03-11 14:43:09 +01:00
henryruhs 39818a16df Fix CI 2025-03-11 14:43:09 +01:00
henryruhs e1ba81f220 Improve generator namings, Flip args to source then target 2025-03-11 14:43:09 +01:00
henryruhs bf696be097 Improve generator namings, Flip args to source then target 2025-03-11 14:43:09 +01:00
henryruhs f63bc788ac Fix typo 2025-03-11 14:43:09 +01:00
henryruhs 11bb9065ba Uniform resume checkpoint approach 2025-03-11 14:43:09 +01:00
harisreedhar 7b2b8f0f85 fix discriminator training 2025-03-11 14:43:09 +01:00
henryruhs 999f2c9cbe Rename variable of EmbeddingConverterTrainer 2025-03-11 14:43:09 +01:00
henryruhs 254f3efe68 Remove live portrait submodule 2025-03-11 14:43:09 +01:00
henryruhs 5bb41ecbb2 Adjust README 2025-03-11 14:43:09 +01:00
henryruhs 777a8384c2 Extend README with TensorBoard guide 2025-03-11 14:43:09 +01:00
henryruhs 0949c1358b Fix CI 2025-03-11 14:43:09 +01:00
henryruhs d9e10a9f7c Use lightning over pytorch_lightning import, Configure tensorboard 2025-03-11 14:43:09 +01:00
henryruhs cd4b10c832 Fix imports and update dependencies 2025-03-11 14:43:09 +01:00
henryruhs 0b7d25a36e Adjust licenses 2025-03-11 14:43:09 +01:00
henryruhs 026bcf0c97 Adjust licenses 2025-03-11 14:43:09 +01:00
henryruhs aa5094e576 We need MS license 2025-03-11 14:43:09 +01:00
harisreedhar 8f8dfecdbc some fixes 2025-03-11 14:43:09 +01:00
harisreedhar 030d912c1b some fixes 2025-03-11 14:43:09 +01:00
henryruhs 0e148845af Extend gitignore 2025-03-11 14:43:09 +01:00
henryruhs e1e0c11bb5 Fix CI 2025-03-11 14:43:09 +01:00
henryruhs 88c4e53192 Use skip_tensor variable 2025-03-11 14:43:09 +01:00
henryruhs 650551c19b Simplify Batch type 2025-03-11 14:43:09 +01:00
henryruhs a971506271 Revert changes in generator 2025-03-11 14:43:09 +01:00
henryruhs b69f69d015 Fix UNet 2025-03-11 14:43:09 +01:00
henryruhs 953525e6b0 Fix typo 2025-03-11 14:43:09 +01:00
henryruhs 58ad6af619 Add Sample suffix again 2025-03-11 14:43:09 +01:00
henryruhs 7264884ff9 Fix CI 2025-03-11 14:43:09 +01:00
henryruhs 1872f99584 Refacto UNet 2025-03-11 14:43:09 +01:00
henryruhs 29e82f909a Move nld to networks 2025-03-11 14:43:09 +01:00
henryruhs 257ab668ee Name what it is 2025-03-11 14:43:09 +01:00
henryruhs 62d897f9d8 Direct return layers 2025-03-11 14:43:09 +01:00
henryruhs f05ff6cdb1 Enforce IR version for older onnxruntime 2025-03-11 14:43:09 +01:00
henryruhs 3fe32d7832 Fix example config 2025-03-11 14:43:09 +01:00
henryruhs 3be8368eaa Improve lot of types, imports and names 2025-03-11 14:43:09 +01:00
henryruhs b6b4f9f65b Improve lot of types, imports and names 2025-03-11 14:43:09 +01:00
henryruhs e33bc0d52a Modernize to use ModuleList, Fix some types 2025-03-11 14:43:09 +01:00
henryruhs f3409c5ade Modernize to use ModuleList, Fix some types 2025-03-11 14:43:09 +01:00
henryruhs c6d16c0cf6 Modernize to use ModuleList, Fix some types 2025-03-11 14:43:09 +01:00
henryruhs dfe7ab3b6f Modernize to use ModuleList, Fix some types 2025-03-11 14:43:09 +01:00
henryruhs 3c6dfa4efe Modernize to use ModuleList, Fix some types 2025-03-11 14:43:09 +01:00
henryruhs 34d0bc10ed Fix append 2025-03-11 14:43:09 +01:00
henryruhs b785525f3b Fix append 2025-03-11 14:43:09 +01:00
henryruhs dd320ea5be Fix self 2025-03-11 14:43:09 +01:00
henryruhs 71e0ae34c0 Fix self 2025-03-11 14:43:09 +01:00
henryruhs 32dfdcf1b3 Remove unsued 2025-03-11 14:43:09 +01:00
henryruhs 494b84aecb Refactor discriminator to use ModuleList, Reduce complexity of layer creation 2025-03-11 14:43:09 +01:00
henryruhs 860771e482 Fix weight identity 2025-03-11 14:43:09 +01:00
henryruhs b42d2b06e7 Improve generate_preview 2025-03-11 14:43:09 +01:00
henryruhs 11c038cb81 Fix monitor 2025-03-11 14:43:09 +01:00
henryruhs 746cc86d52 Fix monitor 2025-03-11 14:43:09 +01:00
henryruhs 58c81cd646 Move magic methods up 2025-03-11 14:43:09 +01:00
henryruhs 4d2038d4ce Rename loss_id to loss_identity 2025-03-11 14:43:09 +01:00
henryruhs 67ad9badac Small typo 2025-03-11 14:43:09 +01:00
henryruhs 1f4405be44 Fix CI 2025-03-11 14:43:09 +01:00
henryruhs 1b6e7a6ca5 Rename ArcFace Converter to Embedding Converter, Add EmbeddingDataset, Add learning rate to config 2025-03-11 14:43:06 +01:00
henryruhs 62a69cddd2 Fix import order 2025-03-11 14:43:03 +01:00
henryruhs 611618e413 Make Embedding great again 2025-03-11 14:43:03 +01:00
harisreedhar 6381e755d7 list join 2025-03-11 14:43:03 +01:00
harisreedhar 2ed558a873 cleanup 2025-03-11 14:43:03 +01:00
henryruhs b7e2d3ccd7 It should be models and networks 2025-03-11 14:43:03 +01:00
harisreedhar 66af7d7957 add face swapper preview 2025-03-11 14:43:03 +01:00
harisreedhar dc0ef53668 add face swapper preview 2025-03-11 14:43:03 +01:00
henryruhs a17d050648 Mixed bag of cleanups 2025-03-11 14:43:03 +01:00
henryruhs 989a81c751 Mixed bag of cleanups 2025-03-11 14:43:03 +01:00
henryruhs a1cd025f81 Mixed bag of cleanups 2025-03-11 14:43:03 +01:00
henryruhs 9c15f584aa Mixed bag of cleanups 2025-03-11 14:43:03 +01:00
henryruhs 5892460c3d Add the final license 2025-03-11 14:43:03 +01:00
harisreedhar bb0e3b4a8a remove liveportrait and update requirements 2025-03-11 14:43:03 +01:00
harisreedhar fcb3390796 again again cleaning 2025-03-11 14:43:03 +01:00
harisreedhar 3b8b6442fc add infer and some cleaning 2025-03-11 14:43:03 +01:00
harisreedhar cfcf0ee2bd add glob image pattern to config 2025-03-11 14:43:03 +01:00
harisreedhar 4260bd6c28 fix type 2025-03-11 14:43:03 +01:00
harisreedhar d018dd0633 some cleaning and add exporting 2025-03-11 14:43:03 +01:00
harisreedhar 798ba48a52 clean data_loader 2025-03-11 14:43:03 +01:00
harisreedhar e45f46d355 cleaning 2025-03-11 14:43:03 +01:00
harisreedhar 008a221f55 cleaning 2025-03-11 14:43:03 +01:00
henryruhs 10826558b4 Fix indent to tabs 2025-03-11 14:43:03 +01:00
henryruhs 23ac63d55b Fix more typing 2025-03-11 14:43:03 +01:00
henryruhs 2bbba3563b Use mypy and flake8 for face swapper 2025-03-11 14:43:03 +01:00
henryruhs e3273221c9 Delete unused file 2025-03-11 14:43:03 +01:00
harisreedhar 9e1c71b498 clean 2025-03-11 14:43:03 +01:00
harisreedhar ef313042c6 clean generator, discriminator and typing 2025-03-11 14:43:03 +01:00
harisreedhar a1fd382659 ugly training code 2025-03-11 14:43:03 +01:00
harisreedhar 650268c06b ugly training code 2025-03-11 14:43:03 +01:00
harisreedhar fe28b6fffe ugly training code 2025-03-11 14:43:02 +01:00
harisreedhar e6c2a64256 ugly training code 2025-03-11 14:43:02 +01:00
henryruhs 7bef17b551 Fix mypy 2025-03-11 14:43:02 +01:00
harisreedhar 8e53c6bc9f new swapper 2025-03-11 14:43:02 +01:00
henryruhs a461d9c389 Remove shebang 2025-01-15 09:28:00 +01:00
henryruhs f3e1c3feaa Improve naming for arcface converter trainer 2025-01-15 09:26:21 +01:00
Henry Ruhs abce63007d Update FUNDING.yml 2024-12-29 01:08:19 +01:00
henryruhs 5fa53dabf2 Move license to model directory 2024-12-06 11:55:40 +01:00
Henry Ruhs bd4725b52f Update dependencies 2024-11-11 21:54:04 +01:00
Henry Ruhs ef38a0e23f Modernize .gitignore 2024-10-30 08:21:30 +01:00
Henry Ruhs e1348dd596 Update FUNDING.yml 2024-10-27 22:30:08 +01:00
54 changed files with 2187 additions and 385 deletions
+2 -4
View File
@@ -1,7 +1,5 @@
[flake8] [flake8]
select = E3, E4, F, I1, I2 select = E22, E23, E24, E27, E3, E4, E7, F, I1, I2
plugins = flake8-import-order plugins = flake8-import-order
application_import_names = arcface_converter application_import_names = crossface, hyperswap
import-order-style = pycharm import-order-style = pycharm
per-file-ignores = preparing.py:E402
+1 -2
View File
@@ -1,2 +1 @@
github: henryruhs custom: [ buymeacoffee.com/facefusion, ko-fi.com/facefusion ]
custom: [ buymeacoffee.com/henryruhs, paypal.me/henryruhs ]
Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

+16 -4
View File
@@ -8,12 +8,24 @@ jobs:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Python 3.10 - name: Set up Python 3.12
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: '3.10' python-version: '3.12'
- run: pip install flake8 - run: pip install flake8
- run: pip install flake8-import-order - run: pip install flake8-import-order
- run: pip install mypy - run: pip install mypy
- run: flake8 arcface_converter - run: flake8 crossface hyperswap
- run: mypy arcface_converter - run: mypy crossface hyperswap
test:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
- run: pip install torch torchvision
- run: pip install pytest
- run: PYTHONPATH=/home/runner/work/facefusion-labs/facefusion-labs pytest
+9
View File
@@ -1,2 +1,11 @@
__pycache__
.assets
.claude
.datasets
.idea .idea
.inputs
.exports
.logs
.models
.outputs
.vscode .vscode
-3
View File
@@ -1,3 +0,0 @@
MIT license
Copyright (c) 2024 Henry Ruhs
-1
View File
@@ -4,4 +4,3 @@ FaceFusion Labs
> Industry leading face manipulation platform. > Industry leading face manipulation platform.
[![Build Status](https://img.shields.io/github/actions/workflow/status/facefusion/facefusion-labs/ci.yml.svg?branch=master)](https://github.com/facefusion/facefusion-labs/actions?query=workflow:ci) [![Build Status](https://img.shields.io/github/actions/workflow/status/facefusion/facefusion-labs/ci.yml.svg?branch=master)](https://github.com/facefusion/facefusion-labs/actions?query=workflow:ci)
![License](https://img.shields.io/badge/license-MIT-green)
-91
View File
@@ -1,91 +0,0 @@
ArcFace Converter
=================
> Convert face embeddings between various ArcFace models.
Preview
-------
![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/master/.github/preview_arcface_converter.png?sanitize=true)
Installation
------------
```
pip install -r requirements.txt
```
Example
-------
This example utilizes the MegaFace dataset to train an ArcFace Converter for SimSwap.
```
[preparing.dataset]
dataset_path = datasets/megaface/train.rec
crop_size = 112
process_limit = 650000
[preparing.model]
source_path = models/arcface_w600k_r50.onnx
target_path = models/arcface_simswap.onnx
[preparing.input]
directory_path = inputs
source_path = inputs/arcface_w600k_r50.npy
target_path = inputs/arcface_simswap.npy
[training.loader]
split_ratio = 0.8
batch_size = 51200
num_workers = 8
[training.trainer]
max_epochs = 4096
[training.output]
directory_path = outputs
file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f}
[exporting]
directory_path = exports
source_path = outputs/last.ckpt
target_path = exports/arcface_converter_simswap.onnx
opset_version = 15
[execution]
providers = CUDAExecutionProvider
```
Preparing
---------
Prepare the face embedding pairs.
```
python prepare.py
```
Training
--------
Train the ArcFace converter model.
```
python train.py
```
Exporting
---------
Export the model to ONNX.
```
python export.py
```
-22
View File
@@ -1,22 +0,0 @@
import configparser
from os import makedirs
import torch
from .training import ArcFaceConverterTrainer
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def export() -> None:
directory_path = CONFIG.get('exporting', 'directory_path')
source_path = CONFIG.get('exporting', 'source_path')
target_path = CONFIG.get('exporting', 'target_path')
opset_version = CONFIG.getint('exporting', 'opset_version')
makedirs(directory_path, exist_ok = True)
model = ArcFaceConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
model.eval()
input_tensor = torch.randn(1, 512)
torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
-21
View File
@@ -1,21 +0,0 @@
import torch
import torch.nn as nn
from torch import Tensor
class ArcFaceConverter(nn.Module):
def __init__(self) -> None:
super(ArcFaceConverter, self).__init__()
self.fc1 = nn.Linear(512, 1024)
self.fc2 = nn.Linear(1024, 2048)
self.fc3 = nn.Linear(2048, 1024)
self.fc4 = nn.Linear(1024, 512)
self.activation = nn.LeakyReLU()
def forward(self, inputs : Tensor) -> Tensor:
norm_inputs = inputs / torch.norm(inputs)
outputs = self.activation(self.fc1(norm_inputs))
outputs = self.activation(self.fc2(outputs))
outputs = self.activation(self.fc3(outputs))
outputs = self.fc4(outputs)
return outputs
-81
View File
@@ -1,81 +0,0 @@
#!/usr/bin/env python3
import configparser
from os import makedirs
from os.path import isfile
from typing import List
import numpy
numpy.bool = numpy.bool_
from mxnet.io import ImageRecordIter
from onnxruntime import InferenceSession
from tqdm import tqdm
from .typing import Embedding, EmbeddingPairs, VisionFrame
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def prepare_crop_vision_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
crop_vision_frame = crop_vision_frame.astype(numpy.float32) / 255
crop_vision_frame = (crop_vision_frame - 0.5) * 2
return crop_vision_frame
def create_inference_session(model_path : str, execution_providers : List[str]) -> InferenceSession:
inference_session = InferenceSession(model_path, providers = execution_providers)
return inference_session
def forward(inference_session : InferenceSession, crop_vision_frame : VisionFrame) -> Embedding:
embedding = inference_session.run(None,
{
'input': crop_vision_frame
})[0]
return embedding
def process_embeddings(dataset_reader : ImageRecordIter, source_inference_session : InferenceSession, target_inference_session : InferenceSession) -> EmbeddingPairs:
dataset_process_limit = CONFIG.getint('preparing.dataset', 'process_limit')
embedding_pairs = []
with tqdm(total = dataset_process_limit) as progress:
for batch in dataset_reader:
crop_vision_frame = batch.data[0].asnumpy()
crop_vision_frame = prepare_crop_vision_frame(crop_vision_frame)
source_embedding = forward(source_inference_session, crop_vision_frame)
target_embedding = forward(target_inference_session, crop_vision_frame)
embedding_pairs.append([ source_embedding, target_embedding ])
progress.update()
if progress.n == dataset_process_limit:
return numpy.concatenate(embedding_pairs, axis = 1).T
return numpy.concatenate(embedding_pairs, axis = 1).T
def prepare() -> None:
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
dataset_crop_size = CONFIG.getint('preparing.dataset', 'crop_size')
model_source_path = CONFIG.get('preparing.model', 'source_path')
model_target_path = CONFIG.get('preparing.model', 'target_path')
input_directory_path = CONFIG.get('preparing.input', 'directory_path')
input_source_path = CONFIG.get('preparing.input', 'source_path')
input_target_path = CONFIG.get('preparing.input', 'target_path')
execution_providers = CONFIG.get('execution', 'providers').split(' ')
makedirs(input_directory_path, exist_ok = True)
if isfile(dataset_path) and isfile(model_source_path) and isfile(model_target_path):
dataset_reader = ImageRecordIter(
path_imgrec = dataset_path,
data_shape = (3, dataset_crop_size, dataset_crop_size),
batch_size = 1,
shuffle = False
)
source_inference_session = create_inference_session(model_source_path, execution_providers)
target_inference_session = create_inference_session(model_target_path, execution_providers)
embedding_pairs = process_embeddings(dataset_reader, source_inference_session, target_inference_session)
numpy.save(input_source_path, embedding_pairs[..., 0].T)
numpy.save(input_target_path, embedding_pairs[..., 1].T)
-118
View File
@@ -1,118 +0,0 @@
#!/usr/bin/env python3
import configparser
from typing import Any, Tuple
import numpy
import pytorch_lightning
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.tuner.tuning import Tuner
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from .model import ArcFaceConverter
from .typing import Batch, Loader
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class ArcFaceConverterTrainer(pytorch_lightning.LightningModule):
def __init__(self) -> None:
super(ArcFaceConverterTrainer, self).__init__()
self.model = ArcFaceConverter()
self.loss_fn = torch.nn.MSELoss()
self.lr = 0.001
def forward(self, source_embedding : Tensor) -> Tensor:
return self.model(source_embedding)
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source, target = batch
output = self(source)
loss = self.loss_fn(output, target)
self.log('train_loss', loss, prog_bar = True, logger = True)
return loss
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source, target = batch
output = self(source)
loss = self.loss_fn(output, target)
self.log('val_loss', loss, prog_bar = True, logger = True)
return loss
def configure_optimizers(self) -> Any:
optimizer = torch.optim.Adam(self.parameters(), lr = self.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
return\
{
'optimizer': optimizer,
'lr_scheduler':
{
'scheduler': scheduler,
'monitor': 'train_loss',
'interval': 'epoch',
'frequency': 1
}
}
def create_loaders() -> Tuple[Loader, Loader]:
loader_batch_size = CONFIG.getint('training.loader', 'batch_size')
loader_num_workers = CONFIG.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset()
training_loader = DataLoader(training_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = True, pin_memory = True)
validation_loader = DataLoader(validate_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = False, pin_memory = True)
return training_loader, validation_loader
def split_dataset() -> Tuple[Dataset[Any], Dataset[Any]]:
input_source_path = CONFIG.get('preparing.input', 'source_path')
input_target_path = CONFIG.get('preparing.input', 'target_path')
loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
source_input = torch.from_numpy(numpy.load(input_source_path)).float()
target_input = torch.from_numpy(numpy.load(input_target_path)).float()
dataset = TensorDataset(source_input, target_input)
dataset_size = len(dataset)
training_size = int(loader_split_ratio * len(dataset))
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
return training_dataset, validate_dataset
def create_trainer() -> Trainer:
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
output_directory_path = CONFIG.get('training.output', 'directory_path')
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
return Trainer(
max_epochs = trainer_max_epochs,
callbacks =
[
ModelCheckpoint(
monitor = 'train_loss',
dirpath = output_directory_path,
filename = output_file_pattern,
every_n_epochs = 10,
save_top_k = 3,
save_last = True
)
],
enable_progress_bar = True,
log_every_n_steps = 2
)
def train() -> None:
trainer = create_trainer()
training_loader, validation_loader = create_loaders()
model = ArcFaceConverterTrainer()
tuner = Tuner(trainer)
tuner.lr_find(model, training_loader, validation_loader)
trainer.fit(model, training_loader, validation_loader)
-13
View File
@@ -1,13 +0,0 @@
from typing import Any, Tuple
from numpy.typing import NDArray
from torch import Tensor
from torch.utils.data import DataLoader
Batch = Tuple[Tensor, Tensor]
Loader = DataLoader[Tuple[Tensor, ...]]
Embedding = NDArray[Any]
EmbeddingPairs = NDArray[Any]
FaceLandmark5 = NDArray[Any]
VisionFrame = NDArray[Any]
+3
View File
@@ -0,0 +1,3 @@
OpenRAIL-MS license
Copyright (c) 2025 Henry Ruhs
+104
View File
@@ -0,0 +1,104 @@
CrossFace
=========
> Seamless face embedding across various models.
![License](https://img.shields.io/badge/license-OpenRAIL--MS-green)
Preview
-------
![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/master/.github/previews/crossface.png?sanitize=true)
Installation
------------
```
pip install -r requirements.txt
```
Setup
-----
This `config.ini` utilizes the MegaFace dataset to train the CrossFace model for SimSwap.
```
[training.dataset]
file_pattern = .datasets/megaface/**/*.jpg
```
```
[training.loader]
batch_size = 128
num_workers = 8
split_ratio = 0.95
```
```
[training.model]
source_path = .models/arcface_w600k_r50.pt
target_path = .models/arcface_simswap.pt
```
```
[training.trainer]
max_epochs = 4096
strategy = auto
precision = 16-mixed
```
```
[training.optimizer]
learning_rate = 0.001
```
```
[training.logger]
logger_path = .logs
logger_name = crossface_simswap
```
```
[training.output]
directory_path = .outputs
file_pattern = crossface_simswap_{epoch}_{step}
resume_path = .outputs/last.ckpt
```
```
[exporting]
directory_path = .exports
source_path = .outputs/last.ckpt
target_path = .exports/crossface_simswap.onnx
ir_version = 10
opset_version = 15
```
Training
--------
Train the model.
```
python train.py
```
Launch the TensorBoard to monitor the training.
```
tensorboard --logdir .logs
```
Exporting
---------
Export the model to ONNX.
```
python export.py
```
@@ -1,34 +1,35 @@
[preparing.dataset] [training.dataset]
dataset_path = file_pattern =
crop_size =
process_limit =
[preparing.model]
source_path =
target_path =
[preparing.input]
directory_path =
source_path =
target_path =
[training.loader] [training.loader]
split_ratio =
batch_size = batch_size =
num_workers = num_workers =
split_ratio =
[training.model]
source_path =
target_path =
[training.trainer] [training.trainer]
max_epochs = max_epochs =
strategy =
precision =
[training.optimizer]
learning_rate =
[training.logger]
logger_path =
logger_name =
[training.output] [training.output]
directory_path = directory_path =
file_pattern = file_pattern =
resume_path =
[exporting] [exporting]
directory_path = directory_path =
source_path = source_path =
target_path = target_path =
ir_version =
opset_version = opset_version =
[execution]
providers =
+34
View File
@@ -0,0 +1,34 @@
import glob
from configparser import ConfigParser
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
from .types import Batch
class StaticDataset(Dataset[Tensor]):
def __init__(self, config_parser : ConfigParser) -> None:
self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
self.file_paths = glob.glob(self.config_file_pattern)
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
file_path = self.file_paths[index]
temp_tensor = io.read_image(file_path)
return self.transforms(temp_tensor)
def __len__(self) -> int:
return len(self.file_paths)
@staticmethod
def compose_transforms() -> transforms:
return transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
+23
View File
@@ -0,0 +1,23 @@
import os
from configparser import ConfigParser
import torch
from .training import CrossFaceTrainer
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
def export() -> None:
config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
config_target_path = CONFIG_PARSER.get('exporting', 'target_path')
config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
os.makedirs(config_directory_path, exist_ok = True)
model = CrossFaceTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
model.ir_version = torch.tensor(config_ir_version)
input_tensor = torch.randn(1, 512)
torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version)
View File
+37
View File
@@ -0,0 +1,37 @@
from torch import Tensor, nn
class CrossFace(nn.Module):
def __init__(self) -> None:
super().__init__()
self.sequence = self.create_sequence()
self.linear = nn.Linear(512, 512)
self.apply(init_weight)
@staticmethod
def create_sequence() -> nn.Sequential:
return nn.Sequential(
nn.Linear(512, 1024),
nn.LayerNorm(1024),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(1024, 2048),
nn.LayerNorm(2048),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(2048, 1024),
nn.LayerNorm(1024),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(1024, 512)
)
def forward(self, input_tensor : Tensor) -> Tensor:
temp_tensor = nn.functional.normalize(input_tensor, p = 2, dim = -1)
return self.sequence(temp_tensor) + 0.2 * self.linear(temp_tensor)
def init_weight(module : nn.Module) -> None:
if isinstance(module, nn.Linear):
nn.init.xavier_normal_(module.weight)
nn.init.constant_(module.bias, 0.01)
+145
View File
@@ -0,0 +1,145 @@
import os
import shutil
from configparser import ConfigParser
from pathlib import Path
from typing import Tuple, cast
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import StaticDataset
from .models.crossface import CrossFace
from .types import Batch, Embedding, OptimizerSet, TrainerPrecision, TrainerStrategy
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
class CrossFaceTrainer(LightningModule):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_source_path = config_parser.get('training.model', 'source_path')
self.config_target_path = config_parser.get('training.model', 'target_path')
self.config_learning_rate = config_parser.getfloat('training.optimizer', 'learning_rate')
self.crossface = CrossFace()
self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval()
self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval()
self.mse_loss = nn.MSELoss()
def forward(self, source_embedding : Embedding) -> Embedding:
return self.crossface(source_embedding)
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
with torch.no_grad():
source_embedding = self.source_embedder(batch)
target_embedding = self.target_embedder(batch)
output_embedding = self(source_embedding)
training_loss = self.mse_loss(output_embedding, target_embedding)
self.log('training_loss', training_loss, prog_bar = True)
return training_loss
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
with torch.no_grad():
source_embedding = self.source_embedder(batch)
output_embedding = self(source_embedding)
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
return validation_score
def configure_optimizers(self) -> OptimizerSet:
optimizer = torch.optim.AdamW(self.parameters(), lr = self.config_learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
optimizer_set =\
{
'optimizer': optimizer,
'lr_scheduler':
{
'scheduler': scheduler,
'monitor': 'training_loss',
'interval': 'epoch',
'frequency': 1
}
}
return optimizer_set
class ModelWithConfigCheckpoint(ModelCheckpoint):
def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None:
super()._save_checkpoint(trainer, checkpoint_path)
config_path = Path(checkpoint_path).with_suffix('.ini')
shutil.copy('config.ini', config_path)
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
dataset_size = len(dataset) # type:ignore[arg-type]
training_size = int(dataset_size * config_split_ratio)
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
return training_dataset, validate_dataset
def create_trainer() -> Trainer:
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy'))
config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision'))
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
logger = TensorBoardLogger(config_logger_path, config_logger_name)
return Trainer(
logger = logger,
log_every_n_steps = 10,
max_epochs = config_max_epochs,
strategy = config_strategy,
precision = config_precision,
callbacks =
[
ModelWithConfigCheckpoint(
monitor = 'training_loss',
dirpath = config_directory_path,
filename = config_file_pattern,
every_n_epochs = 1000,
save_top_k = 5,
save_last = True
),
StochasticWeightAveraging(swa_lrs = 1e-2)
],
val_check_interval = 1000
)
def train() -> None:
config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = StaticDataset(CONFIG_PARSER)
training_loader, validation_loader = create_loaders(dataset)
crossface_trainer = CrossFaceTrainer(CONFIG_PARSER)
trainer = create_trainer()
if os.path.exists(config_resume_path):
trainer.fit(crossface_trainer, training_loader, validation_loader, ckpt_path = config_resume_path)
else:
trainer.fit(crossface_trainer, training_loader, validation_loader)
+11
View File
@@ -0,0 +1,11 @@
from typing import Any, Literal, TypeAlias
from torch import Tensor
Batch : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
OptimizerSet : TypeAlias = Any
TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true']
TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16']
+3
View File
@@ -0,0 +1,3 @@
ResearchRAIL-MS license
Copyright (c) 2025 Henry Ruhs
+189
View File
@@ -0,0 +1,189 @@
HyperSwap
=========
> Hyper accurate face swapping for everyone.
![License](https://img.shields.io/badge/license-ResearchRAIL--MS-orange)
Preview
-------
![Preview](https://raw.githubusercontent.com/facefusion/facefusion-labs/master/.github/previews/hyperswap.png?sanitize=true)
Installation
------------
```
pip install -r requirements.txt
```
Setup
-----
This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model.
```
[training.dataset]
file_pattern = .datasets/vggface2/**/*.jpg
convert_template = vggfacehq_512_to_arcface_128
multiplier = 1
transform_size = 256
usage_mode = both
batch_mode = same
batch_ratio = 0.2
```
```
[training.loader]
batch_size = 8
num_workers = 8
split_ratio = 0.9995
```
```
[training.model]
generator_embedder_path = .models/blendface.pt
loss_embedder_path = .models/arcface.pt
gazer_path = .models/gazer.pt
face_masker_path = .models/face_masker.pt
```
```
[training.model.generator]
source_channels = 512
output_size = 256
num_blocks = 2
```
```
[training.model.discriminator]
input_channels = 3
num_filters = 64
num_layers = 5
num_discriminators = 3
kernel_size = 4
```
```
[training.model.masker]
input_channels = 67
output_channels = 1
num_filters = 16
```
```
[training.losses]
adversarial_weight = 1.0
cycle_weight = 1.0
feature_weight = 10.0
reconstruction_weight = 10.0
identity_weight = 20.0
gaze_weight = 0.05
mask_weight = 5.0
```
```
[training.trainer]
accumulate_size = 4
discriminator_ratio = 0.4
gradient_clip = 20.0
max_epochs = 50
strategy = auto
precision = 16-mixed
sync_batchnorm = false
preview_frequency = 100
```
```
[training.modifier]
mask_factor = 0.01
noise_factor = 0.05
```
```
[training.optimizer.generator]
learning_rate = 0.0004
momentum = 0.5
scheduler_factor = 0.7
scheduler_patience = 2000
```
```
[training.optimizer.discriminator]
learning_rate = 0.0002
momentum = 0.5
scheduler_factor = 0.7
scheduler_patience = 2000
```
```
[training.logger]
logger_path = .logs
logger_name = hyperswap
```
```
[training.output]
directory_path = .outputs
file_pattern = hyperswap_{epoch}_{step}
resume_path = .outputs/last.ckpt
```
```
[exporting]
directory_path = .exports
source_path = .outputs/last.ckpt
target_path = .exports/hyperswap_256.onnx
target_size = 256
ir_version = 10
opset_version = 15
precision = full
```
```
[inferencing]
generator_path = .outputs/last.ckpt
embedder_path = .models/arcface.pt
source_path = .assets/source.jpg
target_path = .assets/target.jpg
output_path = .outputs/output.jpg
```
Training
--------
Train the model.
```
python train.py
```
Launch the TensorBoard to monitor the training.
```
tensorboard --logdir .logs
```
Exporting
---------
Export the model to ONNX.
```
python export.py
```
Inferencing
-----------
Inference the model.
```
python infer.py
```
View File
+96
View File
@@ -0,0 +1,96 @@
[training.dataset]
file_pattern =
convert_template =
multiplier =
transform_size =
usage_mode =
batch_mode =
batch_ratio =
[training.loader]
batch_size =
num_workers =
split_ratio =
[training.model]
generator_embedder_path =
loss_embedder_path =
gazer_path =
face_masker_path =
[training.model.generator]
source_channels =
output_size =
num_blocks =
[training.model.discriminator]
input_channels =
num_filters =
num_layers =
num_discriminators =
kernel_size =
[training.model.masker]
input_channels =
output_channels =
num_filters =
[training.losses]
adversarial_weight =
cycle_weight =
feature_weight =
reconstruction_weight =
identity_weight =
gaze_weight =
mask_weight =
[training.trainer]
accumulate_size =
discriminator_ratio =
gradient_clip =
max_epochs =
strategy =
precision =
sync_batchnorm =
preview_frequency =
[training.modifier]
mask_factor =
noise_factor =
[training.optimizer.generator]
learning_rate =
momentum =
scheduler_factor =
scheduler_patience =
[training.optimizer.discriminator]
learning_rate =
momentum =
scheduler_factor =
scheduler_patience =
[training.logger]
logger_path =
logger_name =
[training.output]
directory_path =
file_pattern =
resume_path =
[exporting]
directory_path =
source_path =
target_path =
target_size =
ir_version =
opset_version =
precision =
[inferencing]
generator_path =
embedder_path =
source_path =
target_path =
output_path =
@@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from src.preparing import prepare from src.exporting import export
if __name__ == '__main__': if __name__ == '__main__':
prepare() export()
+6
View File
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
from src.inferencing import infer
if __name__ == '__main__':
infer()
View File
+154
View File
@@ -0,0 +1,154 @@
import os
import random
from configparser import ConfigParser
from typing import cast
import albumentations
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
from .helper import convert_tensor, resolve_static_file_pattern
from .types import Batch, BatchMode, ConvertTemplate, UsageMode
class DynamicDataset(Dataset[Tensor]):
def __init__(self, config_parser : ConfigParser) -> None:
self.config_file_pattern = config_parser.get('training.dataset.current', 'file_pattern')
self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset.current', 'convert_template'))
self.config_transform_size = config_parser.getint('training.dataset.current', 'transform_size')
self.config_usage_mode = cast(UsageMode, config_parser.get('training.dataset.current', 'usage_mode'))
self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset.current', 'batch_mode'))
self.config_batch_ratio = config_parser.getfloat('training.dataset.current', 'batch_ratio')
self.config_parser = config_parser
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
file_path = resolve_static_file_pattern(self.config_file_pattern)[index]
if random.random() < self.config_batch_ratio:
if self.config_batch_mode == 'equal':
return self.prepare_equal_batch(file_path)
if self.config_batch_mode == 'same':
return self.prepare_same_batch(file_path)
if self.config_usage_mode == 'source':
return self.prepare_source_batch(file_path)
if self.config_usage_mode == 'target':
return self.prepare_target_batch(file_path)
return self.prepare_different_batch(file_path)
def __len__(self) -> int:
return len(resolve_static_file_pattern(self.config_file_pattern))
def prepare_equal_batch(self, source_path : str) -> Batch:
return self.create_batch(source_path, source_path, self.config_convert_template, self.config_convert_template)
def prepare_same_batch(self, source_path : str) -> Batch:
target_directory_path = os.path.dirname(source_path)
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
return self.create_batch(source_path, target_path, self.config_convert_template, self.config_convert_template)
def prepare_source_batch(self, source_path : str) -> Batch:
config_parser = self.filter_config_by_usage_mode('both')
config_section = random.choice(config_parser.sections())
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
target_path = random.choice(resolve_static_file_pattern(config_file_pattern))
return self.create_batch(source_path, target_path, self.config_convert_template, config_convert_template)
def prepare_target_batch(self, target_path : str) -> Batch:
config_parser = self.filter_config_by_usage_mode('both')
config_section = random.choice(config_parser.sections())
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
source_path = random.choice(resolve_static_file_pattern(config_file_pattern))
return self.create_batch(source_path, target_path, config_convert_template, self.config_convert_template)
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern))
return self.create_batch(source_path, target_path, self.config_convert_template, self.config_convert_template)
def compose_transforms(self) -> transforms:
return transforms.Compose(
[
AugmentTransform(),
transforms.ToPILImage(),
transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def filter_config_by_usage_mode(self, usage_mode : UsageMode) -> ConfigParser:
config_parser = ConfigParser()
for config_section in self.config_parser.sections():
if config_section.startswith('training.dataset'):
current_usage_mode = cast(UsageMode, self.config_parser.get(config_section, 'usage_mode'))
if current_usage_mode == usage_mode:
config_parser.add_section(config_section)
for key, value in self.config_parser.items(config_section):
config_parser.set(config_section, key, value)
return config_parser
def create_batch(self, source_path : str, target_path : str, source_convert_template : ConvertTemplate, target_convert_template : ConvertTemplate) -> Batch:
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, source_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, target_convert_template)
return source_tensor, target_tensor
@staticmethod
def conditional_convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
if convert_template:
temp_tensor = input_tensor.unsqueeze(0)
return convert_tensor(temp_tensor, convert_template).squeeze(0)
return input_tensor
class AugmentTransform:
def __init__(self) -> None:
self.transforms = self.compose_transforms()
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
return self.transforms(image = temp_tensor).get('image')
@staticmethod
def compose_transforms() -> albumentations.Compose:
return albumentations.Compose(
[
albumentations.HorizontalFlip(p = 0.5),
albumentations.OneOf(
[
albumentations.MotionBlur(),
albumentations.ZoomBlur(max_factor = (1.0, 1.2))
], p = 0.1),
albumentations.OneOf(
[
albumentations.RandomGamma(),
albumentations.RandomBrightnessContrast(),
albumentations.Illumination()
], p = 0.2),
albumentations.OneOf(
[
albumentations.ColorJitter(),
albumentations.RGBShift(),
albumentations.HueSaturationValue()
], p = 0.2),
albumentations.Affine(
translate_percent = (-0.05, 0.05),
scale = (0.95, 1.05),
rotate = (-2, 2),
border_mode = 1,
p = 0.2
)
])
+47
View File
@@ -0,0 +1,47 @@
import os
from configparser import ConfigParser
from typing import Tuple
import torch
from torch import Tensor, nn
from .training import HyperSwapTrainer
from .types import Embedding, Mask, Module
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
class HalfPrecision(nn.Module):
def __init__(self, model : Module) -> None:
super().__init__()
self.model = model.half()
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
source_embedding = source_embedding.half()
target_tensor = target_tensor.half()
output_tensor, output_mask = self.model(source_embedding, target_tensor)
output_tensor = output_tensor.float()
output_mask = output_mask.float()
return output_tensor, output_mask
def export() -> None:
config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
config_target_path = CONFIG_PARSER.get('exporting', 'target_path')
config_target_size = CONFIG_PARSER.getint('exporting', 'target_size')
config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')
config_precision = CONFIG_PARSER.get('exporting', 'precision')
os.makedirs(config_directory_path, exist_ok = True)
model = HyperSwapTrainer.load_from_checkpoint(config_source_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
if config_precision == 'half':
model = HalfPrecision(model).eval()
model.ir_version = torch.tensor(config_ir_version)
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, config_target_size, config_target_size)
torch.onnx.export(model, (source_tensor, target_tensor), config_target_path, input_names = [ 'source', 'target' ], output_names = [ 'output', 'mask' ], opset_version = config_opset_version)
+82
View File
@@ -0,0 +1,82 @@
import glob
from functools import lru_cache
from typing import List
import torch
from torch import Tensor, nn
from .types import ConvertTemplate, ConvertTemplateSet, EmbedderModule, Embedding, Mask, Padding
CONVERT_TEMPLATE_SET : ConvertTemplateSet =\
{
'arcface_128_to_arcface_112_v2': torch.tensor(
[
[ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ],
[ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ]
]),
'ffhq_512_to_arcface_128': torch.tensor(
[
[ 8.50048894e-01, -1.29486822e-04, 1.90956388e-03 ],
[ 1.29486822e-04, 8.50048894e-01, 9.56254653e-02 ]
]),
'vggfacehq_512_to_arcface_128': torch.tensor(
[
[ 1.01305414, -0.00140513, -0.00585911 ],
[ 0.00140513, 1.01305414, 0.11169602 ]
])
}
def convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
convert_matrix = CONVERT_TEMPLATE_SET.get(convert_template).repeat(input_tensor.shape[0], 1, 1)
affine_grid = nn.functional.affine_grid(convert_matrix.to(input_tensor.device), list(input_tensor.shape))
output_tensor = nn.functional.grid_sample(input_tensor, affine_grid, padding_mode = 'reflection')
return output_tensor
def calculate_face_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = convert_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
crop_tensor[:, :, 112 - padding[1]:, :] = 0
crop_tensor[:, :, :, :padding[2]] = 0
crop_tensor[:, :, :, 112 - padding[3]:] = 0
face_embedding = embedder(crop_tensor)
face_embedding = nn.functional.normalize(face_embedding, p = 2)
return face_embedding
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
overlay_tensor[:, 2, :, :] = 1
input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
output_tensor = input_tensor * (1 - input_mask) + overlay_tensor * input_mask
return output_tensor
def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
padding = int(input_tensor.shape[2] * factor + 0.5)
kernel_size = 1 + 2 * padding
temp_tensor = nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
output_tensor = nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
return output_tensor
def erode_mask(input_tensor : Tensor, factor : float) -> Tensor:
padding = int(input_tensor.shape[2] * factor + 0.5)
kernel_size = 1 + 2 * padding
temp_tensor = 1 - nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
output_tensor = 1 - nn.functional.max_pool2d(temp_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
return output_tensor
def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
noise_tensor = torch.randn_like(input_tensor) * factor
output_tensor = input_tensor + noise_tensor
return output_tensor
@lru_cache(maxsize = None)
def resolve_static_file_pattern(file_pattern : str) -> List[str]:
return sorted(glob.glob(file_pattern))
+27
View File
@@ -0,0 +1,27 @@
import configparser
import torch
from torchvision import io
from .helper import calculate_face_embedding
from .training import HyperSwapTrainer
CONFIG_PARSER = configparser.ConfigParser()
CONFIG_PARSER.read('config.ini')
def infer() -> None:
config_generator_path = CONFIG_PARSER.get('inferencing', 'generator_path')
config_embedder_path = CONFIG_PARSER.get('inferencing', 'embedder_path')
config_source_path = CONFIG_PARSER.get('inferencing', 'source_path')
config_target_path = CONFIG_PARSER.get('inferencing', 'target_path')
config_output_path = CONFIG_PARSER.get('inferencing', 'output_path')
generator = HyperSwapTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location ='cpu').eval()
embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval()
source_tensor = io.read_image(config_source_path)
target_tensor = io.read_image(config_target_path)
source_embedding = calculate_face_embedding(embedder, source_tensor, (0, 0, 0, 0))
output_tensor, _ = generator(source_embedding, target_tensor)
io.write_jpeg(output_tensor, config_output_path)
View File
+35
View File
@@ -0,0 +1,35 @@
from configparser import ConfigParser
from typing import List
from torch import Tensor, nn
from ..networks.nld import NLD
class Discriminator(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
self.config_parser = config_parser
self.discriminators = self.create_discriminators()
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
def create_discriminators(self) -> nn.ModuleList:
discriminators = nn.ModuleList()
for _ in range(self.config_num_discriminators):
discriminator = NLD(self.config_parser).sequences
discriminators.append(discriminator)
return discriminators
def forward(self, input_tensor : Tensor) -> List[Tensor]:
temp_tensor = input_tensor
output_tensors = []
for discriminator in self.discriminators:
output_tensor = discriminator(temp_tensor)
output_tensors.append(output_tensor)
temp_tensor = self.avg_pool(temp_tensor)
return output_tensors
+42
View File
@@ -0,0 +1,42 @@
from configparser import ConfigParser
from typing import Tuple
from torch import Tensor, nn
from ..networks.aad import AAD
from ..networks.masknet import MaskNet
from ..networks.unet import UNet
from ..types import Embedding, Feature, Mask
class Generator(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.encoder = UNet(config_parser)
self.generator = AAD(config_parser)
self.masker = MaskNet(config_parser)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
self.masker.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask]:
output_tensor = self.generator(source_embedding, target_features)
target_feature = target_features[-1]
output_mask = self.masker(target_tensor, target_feature)
output_tensor = output_tensor * output_mask + target_tensor * (1 - output_mask)
return output_tensor, output_mask
def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]:
return self.encoder(input_tensor)
def init_weight(module : nn.Module) -> None:
if isinstance(module, nn.Linear):
module.weight.data.normal_(std = 0.001)
module.bias.data.zero_()
if isinstance(module, nn.Conv2d):
nn.init.xavier_normal_(module.weight.data)
if isinstance(module, nn.ConvTranspose2d):
nn.init.xavier_normal_(module.weight.data)
+192
View File
@@ -0,0 +1,192 @@
from configparser import ConfigParser
from typing import List, Tuple
import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from torchvision import transforms
from ..helper import calculate_face_embedding, dilate_mask
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
class DiscriminatorLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, discriminator_real_tensors : List[Tensor], discriminator_fake_tensors : List[Tensor]) -> Loss:
positive_tensors = []
negative_tensors = []
for discriminator_real_tensor in discriminator_real_tensors:
positive_tensor = torch.relu(1 - discriminator_real_tensor).mean(dim = [ 1, 2, 3 ])
positive_tensors.append(positive_tensor)
for discriminator_fake_tensor in discriminator_fake_tensors:
negative_tensor = torch.relu(discriminator_fake_tensor + 1).mean(dim = [ 1, 2, 3 ])
negative_tensors.append(negative_tensor)
positive_loss = torch.stack(positive_tensors).mean()
negative_loss = torch.stack(negative_tensors).mean()
discriminator_loss = (positive_loss + negative_loss) * 0.5
return discriminator_loss
class AdversarialLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight')
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Loss, Loss]:
temp_tensors = []
for discriminator_output_tensor in discriminator_output_tensors:
temp_tensor = torch.relu(1 - discriminator_output_tensor).mean(dim = [ 1, 2, 3 ]).mean()
temp_tensors.append(temp_tensor)
adversarial_loss = torch.stack(temp_tensors).mean()
weighted_adversarial_loss = adversarial_loss * self.config_adversarial_weight
return adversarial_loss, weighted_adversarial_loss
class CycleLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
self.l1_loss = nn.L1Loss()
def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_feature, output_feature in zip(target_features, cycle_features):
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
temp_tensors.append(temp_tensor)
feature_loss = torch.stack(temp_tensors).mean()
reconstruction_loss = self.l1_loss(target_tensor, cycle_tensor)
cycle_loss = (feature_loss + reconstruction_loss) * 0.5
weighted_feature_loss = cycle_loss * self.config_cycle_weight
return cycle_loss, weighted_feature_loss
class FeatureLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_feature_weight = config_parser.getfloat('training.losses', 'feature_weight')
def forward(self, target_features : Tuple[Feature, ...], output_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_feature, output_feature in zip(target_features, output_features):
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
temp_tensors.append(temp_tensor)
feature_loss = torch.stack(temp_tensors).mean() * 0.5
weighted_feature_loss = feature_loss * self.config_feature_weight
return feature_loss, weighted_feature_loss
class ReconstructionLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
super().__init__()
self.config_reconstruction_weight = config_parser.getfloat('training.losses', 'reconstruction_weight')
self.embedder = embedder
self.mse_loss = nn.MSELoss()
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
with torch.no_grad():
source_embedding = calculate_face_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calculate_face_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = 2.0)
visual_loss = (visual_loss * has_similar_identity).mean()
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss
class IdentityLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
super().__init__()
self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
self.embedder = embedder
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
output_embedding = calculate_face_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
source_embedding = calculate_face_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
weighted_identity_loss = identity_loss * self.config_identity_weight
return identity_loss, weighted_identity_loss
class GazeLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None:
super().__init__()
self.config_gaze_weight = config_parser.getfloat('training.losses', 'gaze_weight')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.gazer = gazer
self.l1_loss = nn.L1Loss()
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
output_pitch, output_yaw = self.detect_gaze(output_tensor)
target_pitch, target_yaw = self.detect_gaze(target_tensor)
pitch_loss = self.l1_loss(output_pitch, target_pitch)
yaw_loss = self.l1_loss(output_yaw, target_yaw)
gaze_loss = (pitch_loss + yaw_loss) * 0.5
weighted_gaze_loss = gaze_loss * self.config_gaze_weight
return gaze_loss, weighted_gaze_loss
def detect_gaze(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor]:
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config_output_size).int()
crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]]
crop_tensor = (crop_tensor + 1) * 0.5
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
with torch.no_grad():
pitch, yaw = self.gazer(crop_tensor)
return pitch, yaw
class MaskLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
super().__init__()
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.face_masker = face_masker
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
target_mask = self.calculate_mask(target_tensor)
if self.config_mask_factor > 0:
target_mask = dilate_mask(target_mask, self.config_mask_factor)
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
mask_loss = self.mse_loss(target_mask, output_mask)
weighted_mask_loss = mask_loss * self.config_mask_weight
return mask_loss, weighted_mask_loss
def calculate_mask(self, target_tensor : Tensor) -> Tensor:
target_tensor = torch.nn.functional.interpolate(target_tensor, (256, 256), mode = 'bilinear')
target_tensor = (target_tensor.clip(-1, 1) + 1) * 0.5
with torch.no_grad():
output_tensor = self.face_masker(target_tensor)
output_tensor = output_tensor.clamp(0, 1)
output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')
return output_tensor
View File
+190
View File
@@ -0,0 +1,190 @@
from configparser import ConfigParser
from typing import Tuple
import torch
from torch import Tensor, nn
from ..types import Embedding, Feature
class AAD(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, 4096)
self.layers = self.create_layers()
def create_layers(self) -> nn.ModuleList:
layers = nn.ModuleList()
if self.config_output_size == 128:
layers.extend(
[
AdaptiveFeatureModulation(1024, 1024, 512, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
if self.config_output_size == 256:
layers.extend(
[
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
if self.config_output_size == 512:
layers.extend(
[
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1536, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 768, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
if self.config_output_size == 1024:
layers.extend(
[
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 3072, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1536, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 1024, 768, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
layers.extend(
[
AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(256, 128, 128, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(128, 64, 64, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(64, 3, 64, self.config_source_channels, self.config_num_blocks)
])
return layers
def forward(self, source_embedding : Embedding, target_features : Tuple[Feature, ...]) -> Tensor:
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
for index, layer in enumerate(self.layers[:-1]):
target_feature = target_features[index]
temp_tensor = layer(temp_tensors, source_embedding, target_feature)
temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
target_feature = target_features[-1]
temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_feature)
output_tensor = torch.tanh(temp_tensors)
return output_tensor
class AdaptiveFeatureModulation(nn.Module):
def __init__(self, input_channels : int, output_channels : int, target_channels : int, source_channels : int, num_blocks : int) -> None:
super().__init__()
self.context_input_channels = input_channels
self.context_output_channels = output_channels
self.context_target_channels = target_channels
self.context_source_channels = source_channels
self.context_num_blocks = num_blocks
self.primary_layers = self.create_primary_layers()
self.shortcut_layers = self.create_shortcut_layers()
def create_primary_layers(self) -> nn.ModuleList:
primary_layers = nn.ModuleList()
for index in range(self.context_num_blocks):
primary_layers.extend(
[
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
nn.ReLU()
])
if index < self.context_num_blocks - 1:
primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_input_channels, kernel_size = 3, padding = 1, bias = False))
else:
primary_layers.append(nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False))
return primary_layers
def create_shortcut_layers(self) -> nn.ModuleList:
shortcut_layers = nn.ModuleList()
if self.context_input_channels > self.context_output_channels:
shortcut_layers.extend(
[
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
nn.ReLU(),
nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
])
return shortcut_layers
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
primary_tensor = input_tensor
for primary_layer in self.primary_layers:
if isinstance(primary_layer, FeatureModulation):
primary_tensor = primary_layer(primary_tensor, source_embedding, target_feature)
else:
primary_tensor = primary_layer(primary_tensor)
if self.context_input_channels > self.context_output_channels:
shortcut_tensor = input_tensor
for shortcut_layer in self.shortcut_layers:
if isinstance(shortcut_layer, FeatureModulation):
shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_feature)
else:
shortcut_tensor = shortcut_layer(shortcut_tensor)
input_tensor = shortcut_tensor
return primary_tensor + input_tensor
class FeatureModulation(nn.Module):
def __init__(self, input_channels : int, target_channels : int, source_channels : int) -> None:
super().__init__()
self.context_input_channels = input_channels
self.conv1 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
self.conv2 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
self.linear1 = nn.Linear(source_channels, input_channels)
self.linear2 = nn.Linear(source_channels, input_channels)
self.instance_norm = nn.InstanceNorm2d(input_channels)
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
temp_tensor = self.instance_norm(input_tensor)
source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
source_modulation = source_scale * temp_tensor + source_shift
target_scale = self.conv1(target_feature)
target_shift = self.conv2(target_feature)
target_modulation = target_scale * temp_tensor + target_shift
temp_mask = torch.sigmoid(self.conv3(temp_tensor))
output_tensor = (1 - temp_mask) * target_modulation + temp_mask * source_modulation
return output_tensor
class PixelShuffleUpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1),
nn.PixelShuffle(upscale_factor = 2)
)
def forward(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.view(input_tensor.shape[0], -1, 1, 1)
output_tensor = self.sequences(temp_tensor)
return output_tensor
+111
View File
@@ -0,0 +1,111 @@
from configparser import ConfigParser
import torch
from torch import Tensor, nn
from ..types import Feature, Mask
class MaskNet(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels')
self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels')
self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
self.up_samples = self.create_up_samples(self.config_num_filters)
self.bottleneck = BottleNeck(self.config_num_filters * 4)
self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
self.sigmoid = nn.Sigmoid()
@staticmethod
def create_down_samples(input_channels : int, num_filters : int) -> nn.ModuleList:
return nn.ModuleList(
[
DownSample(input_channels, num_filters),
DownSample(num_filters, num_filters * 2),
DownSample(num_filters * 2, num_filters * 4)
])
@staticmethod
def create_up_samples(num_filters : int) -> nn.ModuleList:
return nn.ModuleList(
[
UpSample(num_filters * 4, num_filters * 2),
UpSample(num_filters * 2, num_filters),
UpSample(num_filters, num_filters)
])
def forward(self, input_tensor : Tensor, input_feature : Feature) -> Mask:
output_mask = torch.cat([ input_tensor, input_feature ], dim = 1)
for down_sample in self.down_samples:
output_mask = down_sample(output_mask)
output_mask = self.bottleneck(output_mask)
for up_sample in self.up_samples:
output_mask = up_sample(output_mask)
output_mask = self.conv(output_mask)
output_mask = self.sigmoid(output_mask)
return output_mask
class BottleNeck(nn.Module):
def __init__(self, num_filters : int):
super().__init__()
self.sequences = self.create_sequences(num_filters)
self.relu = nn.ReLU()
@staticmethod
def create_sequences(num_filters : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(num_filters),
nn.ReLU(),
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(num_filters),
nn.ReLU()
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor) + input_tensor
output_tensor = self.relu(output_tensor)
return output_tensor
class UpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2),
nn.ReLU()
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor)
return output_tensor
class DownSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(output_channels),
nn.ReLU(),
nn.MaxPool2d(2)
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor)
return output_tensor
+48
View File
@@ -0,0 +1,48 @@
import math
from configparser import ConfigParser
from torch import Tensor, nn
class NLD(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_input_channels = config_parser.getint('training.model.discriminator', 'input_channels')
self.config_num_filters = config_parser.getint('training.model.discriminator', 'num_filters')
self.config_kernel_size = config_parser.getint('training.model.discriminator', 'kernel_size')
self.config_num_layers = config_parser.getint('training.model.discriminator', 'num_layers')
self.layers = self.create_layers()
self.sequences = nn.Sequential(*self.layers)
def create_layers(self) -> nn.ModuleList:
padding = math.ceil((self.config_kernel_size - 1) / 2)
current_filters = self.config_num_filters
layers = nn.ModuleList(
[
nn.Conv2d(self.config_input_channels, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
nn.LeakyReLU(0.2)
])
for _ in range(1, self.config_num_layers):
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2)
]
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2),
nn.Conv2d(current_filters, 1, kernel_size = self.config_kernel_size, padding = padding)
]
return layers
def forward(self, input_tensor : Tensor) -> Tensor:
return self.sequences(input_tensor)
+160
View File
@@ -0,0 +1,160 @@
from configparser import ConfigParser
from typing import Tuple
import torch
from torch import Tensor, nn
from ..types import Feature
class UNet(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.down_samples = self.create_down_samples()
self.up_samples = self.create_up_samples()
def create_down_samples(self) -> nn.ModuleList:
down_samples = nn.ModuleList(
[
DownSample(3, 32),
DownSample(32, 64),
DownSample(64, 128),
DownSample(128, 256),
DownSample(256, 512)
])
if self.config_output_size == 128:
down_samples.extend(
[
DownSample(512, 512)
])
if self.config_output_size == 256:
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 1024)
])
if self.config_output_size == 512:
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 1024),
DownSample(1024, 1024)
])
if self.config_output_size == 1024:
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 2048),
DownSample(2048, 2048),
DownSample(2048, 2048)
])
return down_samples
def create_up_samples(self) -> nn.ModuleList:
up_samples = nn.ModuleList()
if self.config_output_size == 128:
up_samples.extend(
[
UpSample(512, 512),
UpSample(1024, 256)
])
if self.config_output_size == 256:
up_samples.extend(
[
UpSample(1024, 1024),
UpSample(2048, 512),
UpSample(1024, 256)
])
if self.config_output_size == 512:
up_samples.extend(
[
UpSample(1024, 1024),
UpSample(2048, 512),
UpSample(1536, 256),
UpSample(768, 256)
])
if self.config_output_size == 1024:
up_samples.extend(
[
UpSample(2048, 2048),
UpSample(4096, 1024),
UpSample(3072, 512),
UpSample(1536, 256),
UpSample(768, 256)
])
up_samples.extend(
[
UpSample(512, 128),
UpSample(256, 64),
UpSample(128, 32)
])
return up_samples
def forward(self, target_tensor : Tensor) -> Tuple[Feature, ...]:
down_features = []
up_features = []
temp_feature = target_tensor
for down_sample in self.down_samples:
temp_feature = down_sample(temp_feature)
down_features.append(temp_feature)
bottleneck_feature = down_features[-1]
temp_feature = bottleneck_feature
for index, up_sample in enumerate(self.up_samples):
skip_tensor = down_features[-(index + 2)]
temp_feature = up_sample(temp_feature, skip_tensor)
up_features.append(temp_feature)
final_feature = nn.functional.interpolate(temp_feature, scale_factor = 2, mode = 'bilinear', align_corners = False)
return bottleneck_feature, *up_features, final_feature
class UpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.BatchNorm2d(output_channels),
nn.LeakyReLU(0.1)
)
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor)
output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1)
return output_tensor
class DownSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.sequences = self.create_sequences(input_channels, output_channels)
@staticmethod
def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
nn.BatchNorm2d(output_channels),
nn.LeakyReLU(0.1)
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.sequences(input_tensor)
return output_tensor
+299
View File
@@ -0,0 +1,299 @@
import os
import shutil
import warnings
from configparser import ConfigParser
from copy import deepcopy
from pathlib import Path
from typing import List, Tuple, cast
import torch
import torchvision
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import ConcatDataset, Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import DynamicDataset
from .helper import apply_noise, calculate_face_embedding, erode_mask, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
from .types import Batch, Embedding, Mask, OptimizerSet, TrainerPrecision, TrainerStrategy
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
class HyperSwapTrainer(LightningModule):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path')
self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path')
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_discriminator_ratio = config_parser.getfloat('training.trainer', 'discriminator_ratio')
self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor')
self.config_noise_factor = config_parser.getfloat('training.modifier', 'noise_factor')
self.config_generator_learning_rate = config_parser.getfloat('training.optimizer.generator', 'learning_rate')
self.config_generator_momentum = config_parser.getfloat('training.optimizer.generator', 'momentum')
self.config_generator_scheduler_factor = config_parser.getfloat('training.optimizer.generator', 'scheduler_factor')
self.config_generator_scheduler_patience = config_parser.getint('training.optimizer.generator', 'scheduler_patience')
self.config_discriminator_learning_rate = config_parser.getfloat('training.optimizer.discriminator', 'learning_rate')
self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum')
self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor')
self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience')
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval()
self.generator = Generator(config_parser)
self.discriminator = Discriminator(config_parser)
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss(config_parser)
self.cycle_loss = CycleLoss(config_parser)
self.feature_loss = FeatureLoss(config_parser)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
self.gaze_loss = GazeLoss(config_parser, self.gazer)
self.mask_loss = MaskLoss(config_parser, self.face_masker)
self.automatic_optimization = False
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
with torch.no_grad():
generator_target_features = self.generator.encode_features(target_tensor)
output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
if self.config_mask_factor > 0:
output_mask = erode_mask(output_mask, self.config_mask_factor)
return output_tensor, output_mask
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_generator_learning_rate, betas = (self.config_generator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_discriminator_learning_rate, betas = (self.config_discriminator_momentum, 0.999), weight_decay = 1e-4, eps = 1e-8)
generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(generator_optimizer, mode = 'min', factor = self.config_generator_scheduler_factor, patience = self.config_generator_scheduler_patience, min_lr = 1e-8)
discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, mode = 'min', factor = self.config_discriminator_scheduler_factor, patience = self.config_discriminator_scheduler_patience, min_lr = 1e-8)
generator_config =\
{
'optimizer': generator_optimizer,
'lr_scheduler':
{
'scheduler': generator_scheduler
}
}
discriminator_config =\
{
'optimizer': discriminator_optimizer,
'lr_scheduler':
{
'scheduler': discriminator_scheduler
}
}
return generator_config, discriminator_config
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
do_update = (batch_index + 1) % self.config_accumulate_size == 0
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
generator_scheduler, discriminator_scheduler = self.lr_schedulers() #type:ignore[attr-defined]
source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calculate_face_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
if self.config_noise_factor > 0:
source_embedding = apply_noise(source_embedding, self.config_noise_factor)
source_embedding = nn.functional.normalize(source_embedding, p = 2)
generator_target_features = self.generator.encode_features(target_tensor)
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
generator_output_features = self.generator.encode_features(generator_output_tensor)
cycle_output_tensor, cycle_output_mask = self.generator(target_embedding, generator_output_tensor, generator_output_features)
cycle_output_features = self.generator.encode_features(cycle_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
cycle_loss, weighted_cycle_loss = self.cycle_loss(target_tensor, cycle_output_tensor, generator_target_features, cycle_output_features)
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss
if torch.randn(1).item() < self.config_discriminator_ratio:
discriminator_real_tensors = self.discriminator(source_tensor)
else:
discriminator_real_tensors = self.discriminator(target_tensor)
discriminator_fake_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_real_tensors, discriminator_fake_tensors)
self.toggle_optimizer(generator_optimizer)
self.manual_backward(generator_loss)
if do_update:
if self.config_gradient_clip:
self.clip_gradients(
generator_optimizer,
gradient_clip_val = self.config_gradient_clip,
gradient_clip_algorithm = 'norm'
)
generator_optimizer.step()
generator_optimizer.zero_grad()
self.untoggle_optimizer(generator_optimizer)
self.toggle_optimizer(discriminator_optimizer)
self.manual_backward(discriminator_loss)
if do_update:
if self.config_gradient_clip:
self.clip_gradients(
discriminator_optimizer,
gradient_clip_val = self.config_gradient_clip,
gradient_clip_algorithm = 'norm'
)
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
self.untoggle_optimizer(discriminator_optimizer)
if self.global_step % self.config_preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)
self.log('generator_loss', generator_loss, prog_bar = True)
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
self.log('adversarial_loss', adversarial_loss)
self.log('cycle_loss', cycle_loss)
self.log('feature_loss', feature_loss)
self.log('reconstruction_loss', reconstruction_loss)
self.log('identity_loss', identity_loss)
self.log('gaze_loss', gaze_loss)
self.log('mask_loss', mask_loss)
if do_update:
generator_scheduler.step(generator_loss)
discriminator_scheduler.step(discriminator_loss)
return generator_loss
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
source_embedding = calculate_face_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
output_tensor, _ = self.forward(source_embedding, target_tensor)
output_embedding = calculate_face_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0))
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
return validation_score
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, output_mask : Mask) -> None:
preview_limit = 8
preview_cells = []
overlay_tensor = overlay_mask(output_tensor, output_mask)
for source_tensor, target_tensor, output_tensor, overlay_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], overlay_tensor[:preview_limit]):
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2)
preview_cells.append(preview_cell)
preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)
preview_grid = torchvision.utils.make_grid(preview_cells, normalize = True, scale_each = True)
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
class ModelWithConfigCheckpoint(ModelCheckpoint):
def _save_checkpoint(self, trainer : Trainer, checkpoint_path : str) -> None:
super()._save_checkpoint(trainer, checkpoint_path)
config_path = Path(checkpoint_path).with_suffix('.ini')
shutil.copy('config.ini', config_path)
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
dataset_size = len(dataset) # type:ignore[arg-type]
training_size = int(dataset_size * config_split_ratio)
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
return training_dataset, validate_dataset
def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
datasets = []
for config_section in config_parser.sections():
if config_section.startswith('training.dataset'):
config_multiplier = config_parser.getint(config_section, 'multiplier')
__config_parser__ = deepcopy(config_parser)
__config_parser__.remove_section(config_section)
__config_parser__.add_section('training.dataset.current')
for key, value in config_parser.items(config_section):
__config_parser__.set('training.dataset.current', key, value)
dynamic_dataset = DynamicDataset(__config_parser__)
datasets.extend([ dynamic_dataset ] * config_multiplier)
return datasets
def create_trainer() -> Trainer:
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_strategy = cast(TrainerStrategy, CONFIG_PARSER.get('training.trainer', 'strategy'))
config_precision = cast(TrainerPrecision, CONFIG_PARSER.get('training.trainer', 'precision'))
config_sync_batchnorm = CONFIG_PARSER.getboolean('training.trainer', 'sync_batchnorm')
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
logger = TensorBoardLogger(config_logger_path, config_logger_name)
return Trainer(
logger = logger,
log_every_n_steps = 10,
max_epochs = config_max_epochs,
strategy = config_strategy,
precision = config_precision,
sync_batchnorm = config_sync_batchnorm,
callbacks =
[
ModelWithConfigCheckpoint(
monitor = 'generator_loss',
dirpath = config_directory_path,
filename = config_file_pattern,
every_n_train_steps = 1000,
save_top_k = 5,
save_last = True
)
],
val_check_interval = 1000
)
def train() -> None:
config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = ConcatDataset(prepare_datasets(CONFIG_PARSER))
training_loader, validation_loader = create_loaders(dataset)
hyperswap_trainer = HyperSwapTrainer(CONFIG_PARSER)
trainer = create_trainer()
if os.path.isfile(config_resume_path):
trainer.fit(hyperswap_trainer, training_loader, validation_loader, ckpt_path = config_resume_path)
else:
trainer.fit(hyperswap_trainer, training_loader, validation_loader)
+28
View File
@@ -0,0 +1,28 @@
from typing import Any, Dict, Literal, Tuple, TypeAlias
from torch import Tensor
from torch.nn import Module
Batch : TypeAlias = Tuple[Tensor, Tensor]
BatchMode = Literal['equal', 'same', 'different']
UsageMode = Literal['source', 'target', 'both']
ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]
Feature : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
Mask : TypeAlias = Tensor
Loss : TypeAlias = Tensor
Padding : TypeAlias = Tuple[int, int, int, int]
GeneratorModule : TypeAlias = Module
EmbedderModule : TypeAlias = Module
GazerModule : TypeAlias = Module
FaceMaskerModule : TypeAlias = Module
OptimizerSet : TypeAlias = Any
TrainerStrategy = Literal['auto', 'ddp', 'ddp_spawn', 'ddp_find_unused_parameters_true']
TrainerPrecision = Literal['64-true', '32-true', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', 'transformer-engine', 'transformer-engine-float16']
+56
View File
@@ -0,0 +1,56 @@
from configparser import ConfigParser
import pytest
import torch
from hyperswap.src.networks.aad import AAD
from hyperswap.src.networks.masknet import MaskNet
from hyperswap.src.networks.unet import UNet
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
def test_aad_with_unet(output_size : int) -> None:
config_parser = ConfigParser()
config_parser.read_dict(
{
'training.model.generator':
{
'source_channels': '512',
'output_size': str(output_size),
'num_blocks': '2'
}
})
encoder = UNet(config_parser).eval()
generator = AAD(config_parser).eval()
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, output_size, output_size)
target_features = encoder(target_tensor)
output_tensor = generator(source_tensor, target_features)
assert output_tensor.shape == (1, 3, output_size, output_size)
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
def test_mask_net(output_size : int) -> None:
config_parser = ConfigParser()
config_parser.read_dict(
{
'training.model.masker':
{
'input_channels': '67',
'output_channels': '1',
'num_filters': '16'
}
})
masker = MaskNet(config_parser).eval()
target_tensor = torch.randn(1, 3, output_size, output_size)
target_feature = torch.randn(1, 64, output_size, output_size)
output_mask = masker(target_tensor, target_feature)
assert output_mask.shape == (1, 1, output_size, output_size)
+6
View File
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
from src.training import train
if __name__ == '__main__':
train()
+1
View File
@@ -5,3 +5,4 @@ disallow_untyped_calls = True
disallow_untyped_defs = True disallow_untyped_defs = True
ignore_missing_imports = True ignore_missing_imports = True
strict_optional = False strict_optional = False
explicit_package_bases = True
+10 -6
View File
@@ -1,6 +1,10 @@
lightning==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu128
numpy==1.26.4 albumentations==2.0.8
onnx==1.16.2 lightning==2.5.5
onnxruntime==1.19.0 onnx==1.18.0
opencv-python==4.10.0.84 onnxruntime==1.22.0
mxnet==1.9.1 pytorch-msssim==1.0.0
torch==2.8.0
torchdata==0.11.0
torchvision==0.23.0
tensorboard==2.20.0