refactor/convert tensor (#76)

* change to convert template

* changes

* changes

* Conditional convert input tensor

* Conditional convert input tensor

---------

Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com>
This commit is contained in:
Henry Ruhs
2025-05-26 09:44:09 +02:00
committed by GitHub
parent 475b8b1538
commit 0722db91f1
6 changed files with 25 additions and 21 deletions
+1 -1
View File
@@ -28,7 +28,7 @@ This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model.
```
[training.dataset]
file_pattern = .datasets/vggface2/**/*.jpg
warp_template = vggfacehq_512_to_arcface_128
convert_template = vggfacehq_512_to_arcface_128
transform_size = 256
batch_mode = equal
batch_ratio = 0.2
+1 -1
View File
@@ -1,6 +1,6 @@
[training.dataset]
file_pattern =
warp_template =
convert_template =
transform_size =
batch_mode =
batch_ratio =
+15 -10
View File
@@ -9,8 +9,8 @@ from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
from .helper import warp_tensor
from .types import Batch, BatchMode, WarpTemplate
from .helper import convert_tensor
from .types import Batch, BatchMode, ConvertTemplate
class DynamicDataset(Dataset[Tensor]):
@@ -38,15 +38,20 @@ class DynamicDataset(Dataset[Tensor]):
return len(self.file_paths)
def compose_transforms(self) -> transforms:
return transforms.Compose(
__transforms__ =\
[
AugmentTransform(),
transforms.ToPILImage(),
transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
WarpTransform(self.config_parser),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transforms.ToTensor()
]
if self.config_parser.get('training.dataset', 'convert_template'):
__transforms__.append(ConvertTensorTransform(self.config_parser))
__transforms__.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
return transforms.Compose(__transforms__)
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(self.file_paths)
@@ -98,10 +103,10 @@ class AugmentTransform:
])
class WarpTransform:
class ConvertTensorTransform:
def __init__(self, config_parser : ConfigParser) -> None:
self.config_warp_template = cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template'))
self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset', 'convert_template'))
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.unsqueeze(0)
return warp_tensor(temp_tensor, self.config_warp_template).squeeze(0)
return convert_tensor(temp_tensor, self.config_convert_template).squeeze(0)
+6 -6
View File
@@ -1,9 +1,9 @@
import torch
from torch import Tensor, nn
from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpTemplateSet
from .types import ConvertTemplate, ConvertTemplateSet, EmbedderModule, Embedding, Mask, Padding
WARP_TEMPLATE_SET : WarpTemplateSet =\
CONVERT_TEMPLATE_SET : ConvertTemplateSet =\
{
'arcface_128_to_arcface_112_v2': torch.tensor(
[
@@ -23,15 +23,15 @@ WARP_TEMPLATE_SET : WarpTemplateSet =\
}
def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor:
normed_warp_template = WARP_TEMPLATE_SET.get(warp_template).repeat(input_tensor.shape[0], 1, 1)
affine_grid = nn.functional.affine_grid(normed_warp_template.to(input_tensor.device), list(input_tensor.shape))
def convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
convert_matrix = CONVERT_TEMPLATE_SET.get(convert_template).repeat(input_tensor.shape[0], 1, 1)
affine_grid = nn.functional.affine_grid(convert_matrix.to(input_tensor.device), list(input_tensor.shape))
output_tensor = nn.functional.grid_sample(input_tensor, affine_grid, padding_mode = 'reflection')
return output_tensor
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = warp_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
crop_tensor = convert_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
crop_tensor[:, :, 112 - padding[1]:, :] = 0
-1
View File
@@ -89,7 +89,6 @@ class HyperSwapTrainer(LightningModule):
source_tensor, target_tensor = batch
do_update = (batch_index + 1) % self.config_accumulate_size == 0
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
generator_target_features = self.generator.encode_features(target_tensor)
+2 -2
View File
@@ -20,5 +20,5 @@ FaceMaskerModule : TypeAlias = Module
OptimizerSet : TypeAlias = Any
WarpTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor]
ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]