From 65ab796835ef074eeb578ae3ef0bc98287bfc2f6 Mon Sep 17 00:00:00 2001
From: harisreedhar
Date: Thu, 27 Feb 2025 23:17:06 +0530
Subject: [PATCH] changes

---
 face_swapper/src/dataset.py |  2 ++
 face_swapper/src/helper.py  | 26 ++++++++++++++++++++++++--
 face_swapper/src/types.py   |  4 +++-
 3 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py
index 5dc73b0..8fa31d0 100644
--- a/face_swapper/src/dataset.py
+++ b/face_swapper/src/dataset.py
@@ -6,6 +6,7 @@ from torch import Tensor
 from torch.utils.data import Dataset
 from torchvision import io, transforms
 
+from .helper import warp_tensor
 from .types import Batch
 
 
@@ -35,6 +36,7 @@ class DynamicDataset(Dataset[Tensor]):
 			transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
 			transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
 			transforms.ToTensor(),
+			lambda temp: warp_tensor(temp.unsqueeze(0), '__vgg_face_hq__to__arcface_128_v2__').squeeze(0),
 			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
 		])
 
diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py
index ebc11f8..4b5d692 100644
--- a/face_swapper/src/helper.py
+++ b/face_swapper/src/helper.py
@@ -1,10 +1,32 @@
+import torch
 from torch import Tensor, nn
 
-from .types import EmbedderModule, Embedding, Padding
+from .types import AlignmentMatrices, EmbedderModule, Embedding, Padding
+
+ALIGNMENT_MATRICES : AlignmentMatrices =\
+{
+	'__vgg_face_hq__to__arcface_128_v2__': torch.tensor(
+	[
+		[ 1.01305414, -0.00140513, -0.00585911 ],
+		[ 0.00140513, 1.01305414, 0.11169602 ]
+	], dtype = torch.float32),
+	'__arcface_128_v2__to__arcface_112_v2__': torch.tensor(
+	[
+		[ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ],
+		[ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ]
+	], dtype = torch.float32)
+}
+
+
+def warp_tensor(input_tensor : Tensor, alignment_matrix : str) -> Tensor:
+	matrix = ALIGNMENT_MATRICES.get(alignment_matrix).repeat(input_tensor.shape[0], 1, 1)
+	grid = nn.functional.affine_grid(matrix.to(input_tensor.device), list(input_tensor.shape))
+	output_tensor = nn.functional.grid_sample(input_tensor, grid, align_corners = False, padding_mode = 'reflection')
+	return output_tensor
 
 
 def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
-	crop_tensor = input_tensor[:, :, 15: 241, 15: 241]
+	crop_tensor = warp_tensor(input_tensor, '__arcface_128_v2__to__arcface_112_v2__')
 	crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
 	crop_tensor[:, :, :padding[0], :] = 0
 	crop_tensor[:, :, 112 - padding[1]:, :] = 0
diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py
index dd3ed87..7040abe 100644
--- a/face_swapper/src/types.py
+++ b/face_swapper/src/types.py
@@ -1,4 +1,4 @@
-from typing import Any, Tuple, TypeAlias
+from typing import Any, Dict, Tuple, TypeAlias
 
 from torch import Tensor
 from torch.nn import Module
@@ -17,3 +17,5 @@ LandmarkerModule : TypeAlias = Module
 MotionExtractorModule : TypeAlias = Module
 
 OptimizerConfig : TypeAlias = Any
+
+AlignmentMatrices : TypeAlias = Dict[str, Tensor]