mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Add config to dataloader alignment
This commit is contained in:
@@ -28,6 +28,7 @@ This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model.
|
||||
```
|
||||
[training.dataset]
|
||||
file_pattern = .datasets/vggface2/**/*.jpg
|
||||
warp_matrix = vgg_face_hq_to_arcface_128_v2
|
||||
batch_ratio = 0.2
|
||||
```
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
warp_matrix =
|
||||
batch_ratio =
|
||||
|
||||
[training.loader]
|
||||
|
||||
@@ -7,13 +7,14 @@ from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import warp_tensor
|
||||
from .types import Batch
|
||||
from .types import Batch, WarpMatrix
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
def __init__(self, file_pattern : str, batch_ratio : float) -> None:
|
||||
def __init__(self, file_pattern : str, warp_matrix : WarpMatrix, batch_ratio : float) -> None:
|
||||
self.file_paths = glob.glob(file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
self.warp_matrix = warp_matrix
|
||||
self.batch_ratio = batch_ratio
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
@@ -27,8 +28,7 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
def __len__(self) -> int:
|
||||
return len(self.file_paths)
|
||||
|
||||
@staticmethod
|
||||
def compose_transforms() -> transforms:
|
||||
def compose_transforms(self) -> transforms:
|
||||
return transforms.Compose(
|
||||
[
|
||||
transforms.ToPILImage(),
|
||||
@@ -36,10 +36,13 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
|
||||
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda temp_tensor: warp_tensor(temp_tensor.unsqueeze(0), 'vgg_face_hq_to_arcface_128_v2').squeeze(0)),
|
||||
transforms.Lambda(self.warp_tensor),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
def warp_tensor(self, temp_tensor : Tensor):
|
||||
return warp_tensor(temp_tensor.unsqueeze(0), self.warp_matrix).squeeze(0)
|
||||
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
|
||||
target_path = random.choice(self.file_paths)
|
||||
source_tensor = io.read_image(source_path)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import EmbedderModule, Embedding, Padding, WarpMatrixSet
|
||||
from .types import EmbedderModule, Embedding, Padding, WarpMatrix, WarpMatrixSet
|
||||
|
||||
WARP_MATRIX_SET : WarpMatrixSet =\
|
||||
{
|
||||
@@ -18,9 +18,9 @@ WARP_MATRIX_SET : WarpMatrixSet =\
|
||||
}
|
||||
|
||||
|
||||
def warp_tensor(input_tensor : Tensor, alignment_matrix : str) -> Tensor:
|
||||
matrix = WARP_MATRIX_SET.get(alignment_matrix).repeat(input_tensor.shape[0], 1, 1)
|
||||
grid = nn.functional.affine_grid(matrix.to(input_tensor.device), list(input_tensor.shape))
|
||||
def warp_tensor(input_tensor : Tensor, warp_matrix : WarpMatrix) -> Tensor:
|
||||
warp_matrix = WARP_MATRIX_SET.get(warp_matrix).repeat(input_tensor.shape[0], 1, 1)
|
||||
grid = nn.functional.affine_grid(warp_matrix.to(input_tensor.device), list(input_tensor.shape))
|
||||
output_tensor = nn.functional.grid_sample(input_tensor, grid, align_corners = False, padding_mode = 'reflection')
|
||||
return output_tensor
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': generator_scheduler,
|
||||
'interval': 'step',
|
||||
'interval': 'step'
|
||||
}
|
||||
}
|
||||
discriminator_config =\
|
||||
@@ -71,7 +71,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': discriminator_scheduler,
|
||||
'interval': 'step',
|
||||
'interval': 'step'
|
||||
}
|
||||
}
|
||||
return generator_config, discriminator_config
|
||||
@@ -194,13 +194,14 @@ def create_trainer() -> Trainer:
|
||||
|
||||
def train() -> None:
|
||||
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
|
||||
dataset_warp_matrix = CONFIG.get('training.dataset', 'warp_matrix')
|
||||
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
|
||||
output_resume_path = CONFIG.get('training.output', 'resume_path')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
dataset = DynamicDataset(dataset_file_pattern, dataset_batch_ratio)
|
||||
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_matrix, dataset_batch_ratio)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
face_swapper_trainer = FaceSwapperTrainer()
|
||||
trainer = create_trainer()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, Tuple, TypeAlias
|
||||
|
||||
from jinja2.nodes import Literal
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
@@ -18,4 +19,5 @@ MotionExtractorModule : TypeAlias = Module
|
||||
|
||||
OptimizerConfig : TypeAlias = Any
|
||||
|
||||
WarpMatrixSet : TypeAlias = Dict[str, Tensor]
|
||||
WarpMatrix = Literal['vgg_face_hq_to_arcface_128_v2', 'arcface_128_v2_to_arcface_112_v2']
|
||||
WarpMatrixSet : TypeAlias = Dict[WarpMatrix, Tensor]
|
||||
|
||||
Reference in New Issue
Block a user