From 0722db91f128981fd4904a548d41de2f22f05378 Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Mon, 26 May 2025 09:44:09 +0200 Subject: [PATCH] refactor/convert tensor (#76) * change to convert template * changes * changes * Conditional convert input tensor * Conditional convert input tensor --------- Co-authored-by: harisreedhar --- hyperswap/README.md | 2 +- hyperswap/config.ini | 2 +- hyperswap/src/dataset.py | 25 +++++++++++++++---------- hyperswap/src/helper.py | 12 ++++++------ hyperswap/src/training.py | 1 - hyperswap/src/types.py | 4 ++-- 6 files changed, 25 insertions(+), 21 deletions(-) diff --git a/hyperswap/README.md b/hyperswap/README.md index 9ce5bfd..5cc2baf 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -28,7 +28,7 @@ This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model. ``` [training.dataset] file_pattern = .datasets/vggface2/**/*.jpg -warp_template = vggfacehq_512_to_arcface_128 +convert_template = vggfacehq_512_to_arcface_128 transform_size = 256 batch_mode = equal batch_ratio = 0.2 diff --git a/hyperswap/config.ini b/hyperswap/config.ini index b7d50c0..2c626a5 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -1,6 +1,6 @@ [training.dataset] file_pattern = -warp_template = +convert_template = transform_size = batch_mode = batch_ratio = diff --git a/hyperswap/src/dataset.py b/hyperswap/src/dataset.py index db04932..2115571 100644 --- a/hyperswap/src/dataset.py +++ b/hyperswap/src/dataset.py @@ -9,8 +9,8 @@ from torch import Tensor from torch.utils.data import Dataset from torchvision import io, transforms -from .helper import warp_tensor -from .types import Batch, BatchMode, WarpTemplate +from .helper import convert_tensor +from .types import Batch, BatchMode, ConvertTemplate class DynamicDataset(Dataset[Tensor]): @@ -38,15 +38,20 @@ class DynamicDataset(Dataset[Tensor]): return len(self.file_paths) def compose_transforms(self) -> transforms: - return transforms.Compose( + __transforms__ =\ [ AugmentTransform(), transforms.ToPILImage(), transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC), - transforms.ToTensor(), - WarpTransform(self.config_parser), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - ]) + transforms.ToTensor() + ] + + if self.config_parser.get('training.dataset', 'convert_template'): + __transforms__.append(ConvertTensorTransform(self.config_parser)) + + __transforms__.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + + return transforms.Compose(__transforms__) def prepare_different_batch(self, source_path : str) -> Batch: target_path = random.choice(self.file_paths) @@ -98,10 +103,10 @@ class AugmentTransform: ]) -class WarpTransform: +class ConvertTensorTransform: def __init__(self, config_parser : ConfigParser) -> None: - self.config_warp_template = cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template')) + self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset', 'convert_template')) def __call__(self, input_tensor : Tensor) -> Tensor: temp_tensor = input_tensor.unsqueeze(0) - return warp_tensor(temp_tensor, self.config_warp_template).squeeze(0) + return convert_tensor(temp_tensor, self.config_convert_template).squeeze(0) diff --git a/hyperswap/src/helper.py b/hyperswap/src/helper.py index fee2c95..eb740c4 100644 --- a/hyperswap/src/helper.py +++ b/hyperswap/src/helper.py @@ -1,9 +1,9 @@ import torch from torch import Tensor, nn -from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpTemplateSet +from .types import ConvertTemplate, ConvertTemplateSet, EmbedderModule, Embedding, Mask, Padding -WARP_TEMPLATE_SET : WarpTemplateSet =\ +CONVERT_TEMPLATE_SET : ConvertTemplateSet =\ { 'arcface_128_to_arcface_112_v2': torch.tensor( [ @@ -23,15 +23,15 @@ WARP_TEMPLATE_SET : WarpTemplateSet =\ } -def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor: - normed_warp_template = WARP_TEMPLATE_SET.get(warp_template).repeat(input_tensor.shape[0], 1, 1) - affine_grid = nn.functional.affine_grid(normed_warp_template.to(input_tensor.device), list(input_tensor.shape)) +def convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor: + convert_matrix = CONVERT_TEMPLATE_SET.get(convert_template).repeat(input_tensor.shape[0], 1, 1) + affine_grid = nn.functional.affine_grid(convert_matrix.to(input_tensor.device), list(input_tensor.shape)) output_tensor = nn.functional.grid_sample(input_tensor, affine_grid, padding_mode = 'reflection') return output_tensor def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding: - crop_tensor = warp_tensor(input_tensor, 'arcface_128_to_arcface_112_v2') + crop_tensor = convert_tensor(input_tensor, 'arcface_128_to_arcface_112_v2') crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area') crop_tensor[:, :, :padding[0], :] = 0 crop_tensor[:, :, 112 - padding[1]:, :] = 0 diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 58275c1..c7c2c32 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -89,7 +89,6 @@ class HyperSwapTrainer(LightningModule): source_tensor, target_tensor = batch do_update = (batch_index + 1) % self.config_accumulate_size == 0 generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] - source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0)) generator_target_features = self.generator.encode_features(target_tensor) diff --git a/hyperswap/src/types.py b/hyperswap/src/types.py index a4f4f31..1389e4f 100644 --- a/hyperswap/src/types.py +++ b/hyperswap/src/types.py @@ -20,5 +20,5 @@ FaceMaskerModule : TypeAlias = Module OptimizerSet : TypeAlias = Any -WarpTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128'] -WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor] +ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128'] +ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]