From 5c855aae4ef78b4adc0e0052dcf3e1bd8bebc70d Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 24 Apr 2025 13:07:35 +0200 Subject: [PATCH] Sort out the warp template naming --- hyperswap/README.md | 2 +- hyperswap/src/helper.py | 8 ++++---- hyperswap/src/types.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hyperswap/README.md b/hyperswap/README.md index d8a710a..81d37e8 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -28,7 +28,7 @@ This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model. ``` [training.dataset] file_pattern = .datasets/vggface2/**/*.jpg -warp_template = vggfacehq_256_to_arcface_128_v2 +warp_template = vggfacehq_512_to_arcface_128 transform_size = 256 batch_mode = equal batch_ratio = 0.2 diff --git a/hyperswap/src/helper.py b/hyperswap/src/helper.py index d4a2371..fee2c95 100644 --- a/hyperswap/src/helper.py +++ b/hyperswap/src/helper.py @@ -5,17 +5,17 @@ from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpT WARP_TEMPLATE_SET : WarpTemplateSet =\ { - 'arcface_128_v2_to_arcface_112_v2': torch.tensor( + 'arcface_128_to_arcface_112_v2': torch.tensor( [ [ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ], [ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ] ]), - 'ffhq_512_to_arcface_128_v2': torch.tensor( + 'ffhq_512_to_arcface_128': torch.tensor( [ [ 8.50048894e-01, -1.29486822e-04, 1.90956388e-03 ], [ 1.29486822e-04, 8.50048894e-01, 9.56254653e-02 ] ]), - 'vggfacehq_256_to_arcface_128_v2': torch.tensor( + 'vggfacehq_512_to_arcface_128': torch.tensor( [ [ 1.01305414, -0.00140513, -0.00585911 ], [ 0.00140513, 1.01305414, 0.11169602 ] @@ -31,7 +31,7 @@ def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor: def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding: - crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2') + crop_tensor = warp_tensor(input_tensor, 'arcface_128_to_arcface_112_v2') crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area') crop_tensor[:, :, :padding[0], :] = 0 crop_tensor[:, :, 112 - padding[1]:, :] = 0 diff --git a/hyperswap/src/types.py b/hyperswap/src/types.py index 43e108f..a4f4f31 100644 --- a/hyperswap/src/types.py +++ b/hyperswap/src/types.py @@ -20,5 +20,5 @@ FaceMaskerModule : TypeAlias = Module OptimizerSet : TypeAlias = Any -WarpTemplate = Literal['arcface_128_v2_to_arcface_112_v2', 'ffhq_512_to_arcface_128_v2', 'vggfacehq_256_to_arcface_128_v2'] +WarpTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128'] WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor]