diff --git a/face_swapper/README.md b/face_swapper/README.md index 5215a5d..86731ef 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -29,6 +29,7 @@ This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model. [training.dataset] file_pattern = .datasets/vggface2/**/*.jpg warp_template = vgg_face_hq_to_arcface_128_v2 +batch_mode = same batch_ratio = 0.2 ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 4cf9f80..0b585f4 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -1,6 +1,7 @@ [training.dataset] file_pattern = warp_template = +batch_mode = batch_ratio = [training.loader] diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py index e289ac9..ee9a224 100644 --- a/face_swapper/src/dataset.py +++ b/face_swapper/src/dataset.py @@ -7,21 +7,25 @@ from torch.utils.data import Dataset from torchvision import io, transforms from .helper import warp_tensor -from .types import Batch, WarpTemplate +from .types import Batch, BatchMode, WarpTemplate class DynamicDataset(Dataset[Tensor]): - def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_ratio : float) -> None: + def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode : BatchMode, batch_ratio : float) -> None: self.file_paths = glob.glob(file_pattern) - self.transforms = self.compose_transforms() self.warp_template = warp_template + self.batch_mode = batch_mode self.batch_ratio = batch_ratio + self.transforms = self.compose_transforms() def __getitem__(self, index : int) -> Batch: file_path = self.file_paths[index] if random.random() < self.batch_ratio: - return self.prepare_equal_batch(file_path) + if self.batch_mode == 'equal': + return self.prepare_equal_batch(file_path) + if self.batch_mode == 'same': + return self.prepare_same_batch(file_path) return self.prepare_different_batch(file_path) @@ -51,6 +55,11 @@ class DynamicDataset(Dataset[Tensor]): target_tensor = 
self.transforms(target_tensor) return source_tensor, target_tensor + def prepare_same_batch(self, source_path : str) -> Batch: + source_tensor = io.read_image(source_path) + source_tensor = self.transforms(source_tensor) + return source_tensor, source_tensor + def prepare_equal_batch(self, source_path : str) -> Batch: target_directory_path = os.path.dirname(source_path) target_file_name_and_extension = random.choice(os.listdir(target_directory_path)) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 91f20bf..90dab65 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -18,7 +18,7 @@ from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss -from .types import Batch, Embedding, OptimizerConfig, WarpTemplate +from .types import Batch, BatchMode, Embedding, OptimizerConfig, WarpTemplate warnings.filterwarnings('ignore', category = UserWarning, module = 'torch') @@ -198,13 +198,14 @@ def create_trainer() -> Trainer: def train() -> None: dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern') dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template')) + dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode')) dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio') output_resume_path = CONFIG.get('training.output', 'resume_path') if torch.cuda.is_available(): torch.set_float32_matmul_precision('high') - dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_ratio) + dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_mode, dataset_batch_ratio) training_loader, validation_loader = create_loaders(dataset) face_swapper_trainer = FaceSwapperTrainer() trainer = create_trainer() diff 
--git a/face_swapper/src/types.py b/face_swapper/src/types.py index 44a1e8a..5aed9bd 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -4,6 +4,7 @@ from torch import Tensor from torch.nn import Module Batch : TypeAlias = Tuple[Tensor, Tensor] +BatchMode : TypeAlias = Literal['equal', 'same'] Attributes : TypeAlias = Tuple[Tensor, ...] Embedding : TypeAlias = Tensor