Dataset Usage Mode (#79)

* Introduce FilePool to support usage modes

* Fix lint

* Add to README and config

* Enforce equal and same swaps

* Different approach to forward convert tempalte

* Changes

* Changes

* Changes

* Changes

* Introduce V3 of the usage mode feature

* Fix lint

* Proper use of config parser

* fix filter to filter config
This commit is contained in:
Henry Ruhs
2025-06-04 09:25:31 +02:00
committed by GitHub
parent 24f45877f5
commit a602bbd474
5 changed files with 96 additions and 41 deletions
+2 -1
View File
@@ -30,7 +30,8 @@ This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model.
file_pattern = .datasets/vggface2/**/*.jpg
convert_template = vggfacehq_512_to_arcface_128
transform_size = 256
batch_mode = equal
usage_mode = both
batch_mode = same
batch_ratio = 0.2
```
+1
View File
@@ -2,6 +2,7 @@
file_pattern =
convert_template =
transform_size =
usage_mode =
batch_mode =
batch_ratio =
+83 -33
View File
@@ -10,18 +10,20 @@ from torch.utils.data import Dataset
from torchvision import io, transforms
from .helper import convert_tensor
from .types import Batch, BatchMode, ConvertTemplate
from .types import Batch, BatchMode, ConvertTemplate, UsageMode
class DynamicDataset(Dataset[Tensor]):
def __init__(self, config_parser : ConfigParser) -> None:
self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
self.config_transform_size = config_parser.getint('training.dataset', 'transform_size')
self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset', 'batch_mode'))
self.config_batch_ratio = config_parser.getfloat('training.dataset', 'batch_ratio')
self.config_file_pattern = config_parser.get('training.dataset.current', 'file_pattern')
self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset.current', 'convert_template'))
self.config_transform_size = config_parser.getint('training.dataset.current', 'transform_size')
self.config_usage_mode = cast(UsageMode, config_parser.get('training.dataset.current', 'usage_mode'))
self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset.current', 'batch_mode'))
self.config_batch_ratio = config_parser.getfloat('training.dataset.current', 'batch_ratio')
self.config_parser = config_parser
self.file_paths = glob.glob(self.config_file_pattern)
self.transforms = self.compose_transforms()
self.file_paths = glob.glob(self.config_file_pattern)
def __getitem__(self, index : int) -> Batch:
file_path = self.file_paths[index]
@@ -32,38 +34,31 @@ class DynamicDataset(Dataset[Tensor]):
if self.config_batch_mode == 'same':
return self.prepare_same_batch(file_path)
if self.config_usage_mode == 'source':
return self.prepare_source_batch(file_path)
if self.config_usage_mode == 'target':
return self.prepare_target_batch(file_path)
return self.prepare_different_batch(file_path)
def __len__(self) -> int:
return len(self.file_paths)
def compose_transforms(self) -> transforms:
__transforms__ =\
return transforms.Compose(
[
AugmentTransform(),
transforms.ToPILImage(),
transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ToTensor()
]
if self.config_parser.get('training.dataset', 'convert_template'):
__transforms__.append(ConvertTensorTransform(self.config_parser))
__transforms__.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
return transforms.Compose(__transforms__)
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(self.file_paths)
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
return source_tensor, target_tensor
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def prepare_equal_batch(self, source_path : str) -> Batch:
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
return source_tensor, source_tensor
def prepare_same_batch(self, source_path : str) -> Batch:
@@ -72,10 +67,74 @@ class DynamicDataset(Dataset[Tensor]):
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
return source_tensor, target_tensor
def prepare_source_batch(self, source_path : str) -> Batch:
config_parser = self.filter_config_by_usage_mode('both')
config_section = random.choice(config_parser.sections())
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
target_path = random.choice(glob.glob(config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, config_convert_template)
return source_tensor, target_tensor
def prepare_target_batch(self, target_path : str) -> Batch:
config_parser = self.filter_config_by_usage_mode('both')
config_section = random.choice(config_parser.sections())
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
source_path = random.choice(glob.glob(config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
return source_tensor, target_tensor
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(self.file_paths)
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
return source_tensor, target_tensor
def filter_config_by_usage_mode(self, usage_mode : UsageMode) -> ConfigParser:
config_parser = ConfigParser()
for config_section in self.config_parser.sections():
if config_section.startswith('training.dataset'):
current_usage_mode = cast(UsageMode, self.config_parser.get(config_section, 'usage_mode'))
if current_usage_mode == usage_mode:
config_parser.add_section(config_section)
for key, value in self.config_parser.items(config_section):
config_parser.set(config_section, key, value)
return config_parser
@staticmethod
def conditional_convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
if convert_template:
temp_tensor = input_tensor.unsqueeze(0)
return convert_tensor(temp_tensor, convert_template).squeeze(0)
return input_tensor
class AugmentTransform:
def __init__(self) -> None:
@@ -101,12 +160,3 @@ class AugmentTransform:
albumentations.Illumination(p = 0.2),
albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.3)
])
class ConvertTensorTransform:
def __init__(self, config_parser : ConfigParser) -> None:
self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset', 'convert_template'))
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.unsqueeze(0)
return convert_tensor(temp_tensor, self.config_convert_template).squeeze(0)
+6 -4
View File
@@ -1,6 +1,7 @@
import os
import warnings
from configparser import ConfigParser
from copy import deepcopy
from typing import List, Tuple
import torch
@@ -207,13 +208,14 @@ def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
for config_section in config_parser.sections():
if config_section.startswith('training.dataset'):
current_config_parser = ConfigParser()
current_config_parser.add_section('training.dataset')
__config_parser__ = deepcopy(config_parser)
__config_parser__.remove_section(config_section)
__config_parser__.add_section('training.dataset.current')
for key, value in config_parser.items(config_section):
current_config_parser.set('training.dataset', key, value)
__config_parser__.set('training.dataset.current', key, value)
datasets.append(DynamicDataset(current_config_parser))
datasets.append(DynamicDataset(__config_parser__))
return datasets
+4 -3
View File
@@ -5,6 +5,10 @@ from torch.nn import Module
Batch : TypeAlias = Tuple[Tensor, Tensor]
BatchMode = Literal['equal', 'same', 'different']
UsageMode = Literal['source', 'target', 'both']
ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]
Feature : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
@@ -19,6 +23,3 @@ GazerModule : TypeAlias = Module
FaceMaskerModule : TypeAlias = Module
OptimizerSet : TypeAlias = Any
ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]