diff --git a/hyperswap/README.md b/hyperswap/README.md index babd7e9..50d43db 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -30,7 +30,8 @@ This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model. file_pattern = .datasets/vggface2/**/*.jpg convert_template = vggfacehq_512_to_arcface_128 transform_size = 256 -batch_mode = equal +usage_mode = both +batch_mode = same batch_ratio = 0.2 ``` diff --git a/hyperswap/config.ini b/hyperswap/config.ini index f969fe9..6e709b1 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -2,6 +2,7 @@ file_pattern = convert_template = transform_size = +usage_mode = batch_mode = batch_ratio = diff --git a/hyperswap/src/dataset.py b/hyperswap/src/dataset.py index 2115571..77df6cc 100644 --- a/hyperswap/src/dataset.py +++ b/hyperswap/src/dataset.py @@ -10,18 +10,20 @@ from torch.utils.data import Dataset from torchvision import io, transforms from .helper import convert_tensor -from .types import Batch, BatchMode, ConvertTemplate +from .types import Batch, BatchMode, ConvertTemplate, UsageMode class DynamicDataset(Dataset[Tensor]): def __init__(self, config_parser : ConfigParser) -> None: - self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern') - self.config_transform_size = config_parser.getint('training.dataset', 'transform_size') - self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset', 'batch_mode')) - self.config_batch_ratio = config_parser.getfloat('training.dataset', 'batch_ratio') + self.config_file_pattern = config_parser.get('training.dataset.current', 'file_pattern') + self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset.current', 'convert_template')) + self.config_transform_size = config_parser.getint('training.dataset.current', 'transform_size') + self.config_usage_mode = cast(UsageMode, config_parser.get('training.dataset.current', 'usage_mode')) + self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset.current', 'batch_mode')) + self.config_batch_ratio = config_parser.getfloat('training.dataset.current', 'batch_ratio') self.config_parser = config_parser - self.file_paths = glob.glob(self.config_file_pattern) self.transforms = self.compose_transforms() + self.file_paths = glob.glob(self.config_file_pattern) def __getitem__(self, index : int) -> Batch: file_path = self.file_paths[index] @@ -32,38 +34,31 @@ class DynamicDataset(Dataset[Tensor]): if self.config_batch_mode == 'same': return self.prepare_same_batch(file_path) + if self.config_usage_mode == 'source': + return self.prepare_source_batch(file_path) + + if self.config_usage_mode == 'target': + return self.prepare_target_batch(file_path) + return self.prepare_different_batch(file_path) def __len__(self) -> int: return len(self.file_paths) def compose_transforms(self) -> transforms: - __transforms__ =\ + return transforms.Compose( [ AugmentTransform(), transforms.ToPILImage(), transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC), - transforms.ToTensor() - ] - - if self.config_parser.get('training.dataset', 'convert_template'): - __transforms__.append(ConvertTensorTransform(self.config_parser)) - - __transforms__.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) - - return transforms.Compose(__transforms__) - - def prepare_different_batch(self, source_path : str) -> Batch: - target_path = random.choice(self.file_paths) - source_tensor = io.read_image(source_path) - source_tensor = self.transforms(source_tensor) - target_tensor = io.read_image(target_path) - target_tensor = self.transforms(target_tensor) - return source_tensor, target_tensor + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) def prepare_equal_batch(self, source_path : str) -> Batch: source_tensor = io.read_image(source_path) source_tensor = self.transforms(source_tensor) + source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) return source_tensor, source_tensor def prepare_same_batch(self, source_path : str) -> Batch: @@ -72,10 +67,74 @@ class DynamicDataset(Dataset[Tensor]): target_path = os.path.join(target_directory_path, target_file_name_and_extension) source_tensor = io.read_image(source_path) source_tensor = self.transforms(source_tensor) + source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) target_tensor = io.read_image(target_path) target_tensor = self.transforms(target_tensor) + target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template) return source_tensor, target_tensor + def prepare_source_batch(self, source_path : str) -> Batch: + config_parser = self.filter_config_by_usage_mode('both') + config_section = random.choice(config_parser.sections()) + config_file_pattern = config_parser.get(config_section, 'file_pattern') + config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template')) + + target_path = random.choice(glob.glob(config_file_pattern)) + source_tensor = io.read_image(source_path) + source_tensor = self.transforms(source_tensor) + source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) + target_tensor = io.read_image(target_path) + target_tensor = self.transforms(target_tensor) + target_tensor = self.conditional_convert_tensor(target_tensor, config_convert_template) + return source_tensor, target_tensor + + def prepare_target_batch(self, target_path : str) -> Batch: + config_parser = self.filter_config_by_usage_mode('both') + config_section = random.choice(config_parser.sections()) + config_file_pattern = config_parser.get(config_section, 'file_pattern') + config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template')) + + source_path = random.choice(glob.glob(config_file_pattern)) + source_tensor = io.read_image(source_path) + source_tensor = self.transforms(source_tensor) + source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template) + target_tensor = io.read_image(target_path) + target_tensor = self.transforms(target_tensor) + target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template) + return source_tensor, target_tensor + + def prepare_different_batch(self, source_path : str) -> Batch: + target_path = random.choice(self.file_paths) + source_tensor = io.read_image(source_path) + source_tensor = self.transforms(source_tensor) + source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) + target_tensor = io.read_image(target_path) + target_tensor = self.transforms(target_tensor) + target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template) + return source_tensor, target_tensor + + def filter_config_by_usage_mode(self, usage_mode : UsageMode) -> ConfigParser: + config_parser = ConfigParser() + + for config_section in self.config_parser.sections(): + + if config_section.startswith('training.dataset'): + current_usage_mode = cast(UsageMode, self.config_parser.get(config_section, 'usage_mode')) + if current_usage_mode == usage_mode: + config_parser.add_section(config_section) + + for key, value in self.config_parser.items(config_section): + config_parser.set(config_section, key, value) + + return config_parser + + @staticmethod + def conditional_convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor: + if convert_template: + temp_tensor = input_tensor.unsqueeze(0) + return convert_tensor(temp_tensor, convert_template).squeeze(0) + return input_tensor + class AugmentTransform: def __init__(self) -> None: @@ -101,12 +160,3 @@ class AugmentTransform: albumentations.Illumination(p = 0.2), albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.3) ]) - - -class ConvertTensorTransform: - def __init__(self, config_parser : ConfigParser) -> None: - self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset', 'convert_template')) - - def __call__(self, input_tensor : Tensor) -> Tensor: - temp_tensor = input_tensor.unsqueeze(0) - return convert_tensor(temp_tensor, self.config_convert_template).squeeze(0) diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index cab9977..f9f83d2 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -1,6 +1,7 @@ import os import warnings from configparser import ConfigParser +from copy import deepcopy from typing import List, Tuple import torch @@ -207,13 +208,14 @@ def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]: for config_section in config_parser.sections(): if config_section.startswith('training.dataset'): - current_config_parser = ConfigParser() - current_config_parser.add_section('training.dataset') + __config_parser__ = deepcopy(config_parser) + __config_parser__.remove_section(config_section) + __config_parser__.add_section('training.dataset.current') for key, value in config_parser.items(config_section): - current_config_parser.set('training.dataset', key, value) + __config_parser__.set('training.dataset.current', key, value) - datasets.append(DynamicDataset(current_config_parser)) + datasets.append(DynamicDataset(__config_parser__)) return datasets diff --git a/hyperswap/src/types.py b/hyperswap/src/types.py index 1389e4f..1ce949a 100644 --- a/hyperswap/src/types.py +++ b/hyperswap/src/types.py @@ -5,6 +5,10 @@ from torch.nn import Module Batch : TypeAlias = Tuple[Tensor, Tensor] BatchMode = Literal['equal', 'same', 'different'] +UsageMode = Literal['source', 'target', 'both'] + +ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128'] +ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor] Feature : TypeAlias = Tensor Embedding : TypeAlias = Tensor @@ -19,6 +23,3 @@ GazerModule : TypeAlias = Module FaceMaskerModule : TypeAlias = Module OptimizerSet : TypeAlias = Any - -ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128'] -ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]