Add config to dataloader alignment

This commit is contained in:
henryruhs
2025-03-03 09:23:40 +01:00
parent 34a7f3ef55
commit 589568bfb5
6 changed files with 21 additions and 13 deletions
+1
View File
@@ -28,6 +28,7 @@ This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model.
```
[training.dataset]
file_pattern = .datasets/vggface2/**/*.jpg
warp_matrix = vgg_face_hq_to_arcface_128_v2
batch_ratio = 0.2
```
+1
View File
@@ -1,5 +1,6 @@
[training.dataset]
file_pattern =
warp_matrix =
batch_ratio =
[training.loader]
+8 -5
View File
@@ -7,13 +7,14 @@ from torch.utils.data import Dataset
from torchvision import io, transforms
from .helper import warp_tensor
from .types import Batch
from .types import Batch, WarpMatrix
class DynamicDataset(Dataset[Tensor]):
def __init__(self, file_pattern : str, batch_ratio : float) -> None:
def __init__(self, file_pattern : str, warp_matrix : WarpMatrix, batch_ratio : float) -> None:
self.file_paths = glob.glob(file_pattern)
self.transforms = self.compose_transforms()
self.warp_matrix = warp_matrix
self.batch_ratio = batch_ratio
def __getitem__(self, index : int) -> Batch:
@@ -27,8 +28,7 @@ class DynamicDataset(Dataset[Tensor]):
def __len__(self) -> int:
return len(self.file_paths)
@staticmethod
def compose_transforms() -> transforms:
def compose_transforms(self) -> transforms:
return transforms.Compose(
[
transforms.ToPILImage(),
@@ -36,10 +36,13 @@ class DynamicDataset(Dataset[Tensor]):
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
transforms.ToTensor(),
transforms.Lambda(lambda temp_tensor: warp_tensor(temp_tensor.unsqueeze(0), 'vgg_face_hq_to_arcface_128_v2').squeeze(0)),
transforms.Lambda(self.warp_tensor),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def warp_tensor(self, temp_tensor : Tensor):
return warp_tensor(temp_tensor.unsqueeze(0), self.warp_matrix).squeeze(0)
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(self.file_paths)
source_tensor = io.read_image(source_path)
+4 -4
View File
@@ -1,7 +1,7 @@
import torch
from torch import Tensor, nn
from .types import EmbedderModule, Embedding, Padding, WarpMatrixSet
from .types import EmbedderModule, Embedding, Padding, WarpMatrix, WarpMatrixSet
WARP_MATRIX_SET : WarpMatrixSet =\
{
@@ -18,9 +18,9 @@ WARP_MATRIX_SET : WarpMatrixSet =\
}
def warp_tensor(input_tensor : Tensor, alignment_matrix : str) -> Tensor:
matrix = WARP_MATRIX_SET.get(alignment_matrix).repeat(input_tensor.shape[0], 1, 1)
grid = nn.functional.affine_grid(matrix.to(input_tensor.device), list(input_tensor.shape))
def warp_tensor(input_tensor : Tensor, warp_matrix : WarpMatrix) -> Tensor:
warp_matrix = WARP_MATRIX_SET.get(warp_matrix).repeat(input_tensor.shape[0], 1, 1)
grid = nn.functional.affine_grid(warp_matrix.to(input_tensor.device), list(input_tensor.shape))
output_tensor = nn.functional.grid_sample(input_tensor, grid, align_corners = False, padding_mode = 'reflection')
return output_tensor
+4 -3
View File
@@ -62,7 +62,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
'lr_scheduler':
{
'scheduler': generator_scheduler,
'interval': 'step',
'interval': 'step'
}
}
discriminator_config =\
@@ -71,7 +71,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
'lr_scheduler':
{
'scheduler': discriminator_scheduler,
'interval': 'step',
'interval': 'step'
}
}
return generator_config, discriminator_config
@@ -194,13 +194,14 @@ def create_trainer() -> Trainer:
def train() -> None:
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
dataset_warp_matrix = CONFIG.get('training.dataset', 'warp_matrix')
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
output_resume_path = CONFIG.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = DynamicDataset(dataset_file_pattern, dataset_batch_ratio)
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_matrix, dataset_batch_ratio)
training_loader, validation_loader = create_loaders(dataset)
face_swapper_trainer = FaceSwapperTrainer()
trainer = create_trainer()
+3 -1
View File
@@ -1,5 +1,6 @@
from typing import Any, Dict, Tuple, TypeAlias
from jinja2.nodes import Literal
from torch import Tensor
from torch.nn import Module
@@ -18,4 +19,5 @@ MotionExtractorModule : TypeAlias = Module
OptimizerConfig : TypeAlias = Any
WarpMatrixSet : TypeAlias = Dict[str, Tensor]
WarpMatrix = Literal['vgg_face_hq_to_arcface_128_v2', 'arcface_128_v2_to_arcface_112_v2']
WarpMatrixSet : TypeAlias = Dict[WarpMatrix, Tensor]