mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Dataset Usage Mode (#79)
* Introduce FilePool to support usage modes * Fix lint * Add to README and config * Enforce equal and same swaps * Different approach to forward convert template * Changes * Changes * Changes * Changes * Introduce V3 of the usage mode feature * Fix lint * Proper use of config parser * fix filter to filter config
This commit is contained in:
+2
-1
@@ -30,7 +30,8 @@ This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model.
|
||||
file_pattern = .datasets/vggface2/**/*.jpg
|
||||
convert_template = vggfacehq_512_to_arcface_128
|
||||
transform_size = 256
|
||||
batch_mode = equal
|
||||
usage_mode = both
|
||||
batch_mode = same
|
||||
batch_ratio = 0.2
|
||||
```
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
file_pattern =
|
||||
convert_template =
|
||||
transform_size =
|
||||
usage_mode =
|
||||
batch_mode =
|
||||
batch_ratio =
|
||||
|
||||
|
||||
+83
-33
@@ -10,18 +10,20 @@ from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import convert_tensor
|
||||
from .types import Batch, BatchMode, ConvertTemplate
|
||||
from .types import Batch, BatchMode, ConvertTemplate, UsageMode
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
|
||||
self.config_transform_size = config_parser.getint('training.dataset', 'transform_size')
|
||||
self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset', 'batch_mode'))
|
||||
self.config_batch_ratio = config_parser.getfloat('training.dataset', 'batch_ratio')
|
||||
self.config_file_pattern = config_parser.get('training.dataset.current', 'file_pattern')
|
||||
self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset.current', 'convert_template'))
|
||||
self.config_transform_size = config_parser.getint('training.dataset.current', 'transform_size')
|
||||
self.config_usage_mode = cast(UsageMode, config_parser.get('training.dataset.current', 'usage_mode'))
|
||||
self.config_batch_mode = cast(BatchMode, config_parser.get('training.dataset.current', 'batch_mode'))
|
||||
self.config_batch_ratio = config_parser.getfloat('training.dataset.current', 'batch_ratio')
|
||||
self.config_parser = config_parser
|
||||
self.file_paths = glob.glob(self.config_file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
self.file_paths = glob.glob(self.config_file_pattern)
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
file_path = self.file_paths[index]
|
||||
@@ -32,38 +34,31 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
if self.config_batch_mode == 'same':
|
||||
return self.prepare_same_batch(file_path)
|
||||
|
||||
if self.config_usage_mode == 'source':
|
||||
return self.prepare_source_batch(file_path)
|
||||
|
||||
if self.config_usage_mode == 'target':
|
||||
return self.prepare_target_batch(file_path)
|
||||
|
||||
return self.prepare_different_batch(file_path)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.file_paths)
|
||||
|
||||
def compose_transforms(self) -> transforms:
|
||||
__transforms__ =\
|
||||
return transforms.Compose(
|
||||
[
|
||||
AugmentTransform(),
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor()
|
||||
]
|
||||
|
||||
if self.config_parser.get('training.dataset', 'convert_template'):
|
||||
__transforms__.append(ConvertTensorTransform(self.config_parser))
|
||||
|
||||
__transforms__.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
|
||||
|
||||
return transforms.Compose(__transforms__)
|
||||
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
|
||||
target_path = random.choice(self.file_paths)
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
return source_tensor, target_tensor
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
def prepare_equal_batch(self, source_path : str) -> Batch:
    """
    Build a batch in which the very same tensor serves as source and target.

    :param source_path: path of the image to load
    :return: tuple holding the prepared tensor twice
    """
    # read, run the composed transforms, then apply the dataset's convert template (if any)
    image_tensor = self.transforms(io.read_image(source_path))
    image_tensor = self.conditional_convert_tensor(image_tensor, self.config_convert_template)
    return image_tensor, image_tensor
|
||||
|
||||
def prepare_same_batch(self, source_path : str) -> Batch:
|
||||
@@ -72,10 +67,74 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
target_path = os.path.join(target_directory_path, target_file_name_and_extension)
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
|
||||
target_tensor = io.read_image(target_path)
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_source_batch(self, source_path : str) -> 'Batch':
    """
    Build a batch whose source comes from this dataset and whose target is a random file of a random dataset section declared with usage mode 'both'.

    :param source_path: path of the source image taken from this dataset
    :return: tuple of (source_tensor, target_tensor); the source is converted with this dataset's template, the target with the template of the section it was drawn from
    :raises ValueError: when no section with usage mode 'both' exists or the chosen section's file pattern matches no files
    """
    config_parser = self.filter_config_by_usage_mode('both')
    config_sections = config_parser.sections()

    # random.choice would raise a bare IndexError on an empty list; report the config problem instead
    if not config_sections:
        raise ValueError("usage_mode 'source' needs at least one training.dataset section with usage_mode 'both' to draw targets from")

    config_section = random.choice(config_sections)
    config_file_pattern = config_parser.get(config_section, 'file_pattern')

    # remember the scan result per pattern so the partner dataset is not re-globbed for every sample
    file_path_cache = self.__dict__.setdefault('_file_path_cache', {})

    if config_file_pattern not in file_path_cache:
        file_path_cache[config_file_pattern] = glob.glob(config_file_pattern)
    target_paths = file_path_cache[config_file_pattern]

    if not target_paths:
        raise ValueError(f"file_pattern '{config_file_pattern}' of section '{config_section}' does not match any file")

    config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
    target_path = random.choice(target_paths)

    source_tensor = io.read_image(source_path)
    source_tensor = self.transforms(source_tensor)
    source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
    target_tensor = io.read_image(target_path)
    target_tensor = self.transforms(target_tensor)
    target_tensor = self.conditional_convert_tensor(target_tensor, config_convert_template)
    return source_tensor, target_tensor
|
||||
|
||||
def prepare_target_batch(self, target_path : str) -> 'Batch':
    """
    Build a batch whose target comes from this dataset and whose source is a random file of a random dataset section declared with usage mode 'both'.

    :param target_path: path of the target image taken from this dataset
    :return: tuple of (source_tensor, target_tensor); the source is converted with the template of the section it was drawn from, the target with this dataset's template
    :raises ValueError: when no section with usage mode 'both' exists or the chosen section's file pattern matches no files
    """
    config_parser = self.filter_config_by_usage_mode('both')
    config_sections = config_parser.sections()

    # random.choice would raise a bare IndexError on an empty list; report the config problem instead
    if not config_sections:
        raise ValueError("usage_mode 'target' needs at least one training.dataset section with usage_mode 'both' to draw sources from")

    config_section = random.choice(config_sections)
    config_file_pattern = config_parser.get(config_section, 'file_pattern')

    # remember the scan result per pattern so the partner dataset is not re-globbed for every sample
    file_path_cache = self.__dict__.setdefault('_file_path_cache', {})

    if config_file_pattern not in file_path_cache:
        file_path_cache[config_file_pattern] = glob.glob(config_file_pattern)
    source_paths = file_path_cache[config_file_pattern]

    if not source_paths:
        raise ValueError(f"file_pattern '{config_file_pattern}' of section '{config_section}' does not match any file")

    config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
    source_path = random.choice(source_paths)

    source_tensor = io.read_image(source_path)
    source_tensor = self.transforms(source_tensor)
    source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template)
    target_tensor = io.read_image(target_path)
    target_tensor = self.transforms(target_tensor)
    target_tensor = self.conditional_convert_tensor(target_tensor, self.config_convert_template)
    return source_tensor, target_tensor
|
||||
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
    """
    Pair the given source image with a randomly chosen file of this dataset.

    :param source_path: path of the source image
    :return: tuple of (source_tensor, target_tensor), both prepared with this dataset's transforms and convert template
    """
    random_target_path = random.choice(self.file_paths)

    # source is processed first, then the target, each going through read -> transforms -> conditional convert
    source_tensor, target_tensor = [
        self.conditional_convert_tensor(self.transforms(io.read_image(file_path)), self.config_convert_template)
        for file_path in (source_path, random_target_path)
    ]
    return source_tensor, target_tensor
|
||||
|
||||
def filter_config_by_usage_mode(self, usage_mode : 'UsageMode') -> ConfigParser:
    """
    Collect every section whose name starts with 'training.dataset' and whose usage_mode equals the requested one.

    :param usage_mode: usage mode a section has to declare in order to be kept
    :return: new ConfigParser holding a copy of each matching section
    """
    config_parser = ConfigParser()

    for config_section in self.config_parser.sections():
        if not config_section.startswith('training.dataset'):
            continue

        if self.config_parser.get(config_section, 'usage_mode') != usage_mode:
            continue

        config_parser.add_section(config_section)

        # copy the raw values: an interpolated value may contain a bare '%' (from an escaped '%%'), which set() rejects
        for key, value in self.config_parser.items(config_section, raw = True):
            config_parser.set(config_section, key, value)

    return config_parser
|
||||
|
||||
@staticmethod
def conditional_convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
    """
    Run the tensor through convert_tensor only when a convert template is set.

    :param input_tensor: tensor to convert
    :param convert_template: template name; a falsy value (e.g. an empty config entry) disables the conversion
    :return: converted tensor, or the untouched input when no template is given
    """
    if not convert_template:
        return input_tensor

    # add a leading dimension for the call and strip it again afterwards
    # NOTE(review): presumably convert_tensor expects batched input - confirm against helper
    batched_tensor = input_tensor.unsqueeze(0)
    converted_tensor = convert_tensor(batched_tensor, convert_template)
    return converted_tensor.squeeze(0)
|
||||
|
||||
|
||||
class AugmentTransform:
|
||||
def __init__(self) -> None:
|
||||
@@ -101,12 +160,3 @@ class AugmentTransform:
|
||||
albumentations.Illumination(p = 0.2),
|
||||
albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.3)
|
||||
])
|
||||
|
||||
|
||||
class ConvertTensorTransform:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset', 'convert_template'))
|
||||
|
||||
def __call__(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.unsqueeze(0)
|
||||
return convert_tensor(temp_tensor, self.config_convert_template).squeeze(0)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import warnings
|
||||
from configparser import ConfigParser
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
@@ -207,13 +208,14 @@ def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
|
||||
for config_section in config_parser.sections():
|
||||
|
||||
if config_section.startswith('training.dataset'):
|
||||
current_config_parser = ConfigParser()
|
||||
current_config_parser.add_section('training.dataset')
|
||||
__config_parser__ = deepcopy(config_parser)
|
||||
__config_parser__.remove_section(config_section)
|
||||
__config_parser__.add_section('training.dataset.current')
|
||||
|
||||
for key, value in config_parser.items(config_section):
|
||||
current_config_parser.set('training.dataset', key, value)
|
||||
__config_parser__.set('training.dataset.current', key, value)
|
||||
|
||||
datasets.append(DynamicDataset(current_config_parser))
|
||||
datasets.append(DynamicDataset(__config_parser__))
|
||||
|
||||
return datasets
|
||||
|
||||
|
||||
@@ -5,6 +5,10 @@ from torch.nn import Module
|
||||
|
||||
Batch : TypeAlias = Tuple[Tensor, Tensor]
|
||||
BatchMode = Literal['equal', 'same', 'different']
|
||||
UsageMode = Literal['source', 'target', 'both']
|
||||
|
||||
ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
|
||||
ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]
|
||||
|
||||
Feature : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
@@ -19,6 +23,3 @@ GazerModule : TypeAlias = Module
|
||||
FaceMaskerModule : TypeAlias = Module
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
|
||||
ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]
|
||||
|
||||
Reference in New Issue
Block a user