Sort out the warp template naming

This commit is contained in:
henryruhs
2025-04-24 13:07:35 +02:00
parent 810df0f540
commit 5c855aae4e
3 changed files with 6 additions and 6 deletions
+1 -1
View File
@@ -28,7 +28,7 @@ This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model.
```
[training.dataset]
file_pattern = .datasets/vggface2/**/*.jpg
warp_template = vggfacehq_256_to_arcface_128_v2
warp_template = vggfacehq_512_to_arcface_128
transform_size = 256
batch_mode = equal
batch_ratio = 0.2
+4 -4
View File
@@ -5,17 +5,17 @@ from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpT
WARP_TEMPLATE_SET : WarpTemplateSet =\
{
'arcface_128_v2_to_arcface_112_v2': torch.tensor(
'arcface_128_to_arcface_112_v2': torch.tensor(
[
[ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ],
[ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ]
]),
'ffhq_512_to_arcface_128_v2': torch.tensor(
'ffhq_512_to_arcface_128': torch.tensor(
[
[ 8.50048894e-01, -1.29486822e-04, 1.90956388e-03 ],
[ 1.29486822e-04, 8.50048894e-01, 9.56254653e-02 ]
]),
'vggfacehq_256_to_arcface_128_v2': torch.tensor(
'vggfacehq_512_to_arcface_128': torch.tensor(
[
[ 1.01305414, -0.00140513, -0.00585911 ],
[ 0.00140513, 1.01305414, 0.11169602 ]
@@ -31,7 +31,7 @@ def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor:
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2')
crop_tensor = warp_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
crop_tensor[:, :, 112 - padding[1]:, :] = 0
+1 -1
View File
@@ -20,5 +20,5 @@ FaceMaskerModule : TypeAlias = Module
OptimizerSet : TypeAlias = Any
WarpTemplate = Literal['arcface_128_v2_to_arcface_112_v2', 'ffhq_512_to_arcface_128_v2', 'vggfacehq_256_to_arcface_128_v2']
WarpTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor]