Introduce batch mode via config for equal and same batches

This commit is contained in:
henryruhs
2025-03-03 12:13:54 +01:00
parent 83c20f8331
commit 3a61da8bab
5 changed files with 19 additions and 6 deletions
+1
View File
@@ -29,6 +29,7 @@ This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model.
[training.dataset]
file_pattern = .datasets/vggface2/**/*.jpg
warp_template = vgg_face_hq_to_arcface_128_v2
batch_mode = same
batch_ratio = 0.2
```
+1
View File
@@ -1,6 +1,7 @@
[training.dataset]
file_pattern =
warp_template =
batch_mode =
batch_ratio =
[training.loader]
+13 -4
View File
@@ -7,21 +7,25 @@ from torch.utils.data import Dataset
from torchvision import io, transforms
from .helper import warp_tensor
from .types import Batch, WarpTemplate
from .types import Batch, BatchMode, WarpTemplate
class DynamicDataset(Dataset[Tensor]):
def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_ratio : float) -> None:
def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode : BatchMode, batch_ratio : float) -> None:
self.file_paths = glob.glob(file_pattern)
self.transforms = self.compose_transforms()
self.warp_template = warp_template
self.batch_mode = batch_mode
self.batch_ratio = batch_ratio
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
file_path = self.file_paths[index]
if random.random() < self.batch_ratio:
return self.prepare_equal_batch(file_path)
if self.batch_mode == 'equal':
return self.prepare_equal_batch(file_path)
if self.batch_mode == 'same':
return self.prepare_same_batch(file_path)
return self.prepare_different_batch(file_path)
@@ -51,6 +55,11 @@ class DynamicDataset(Dataset[Tensor]):
target_tensor = self.transforms(target_tensor)
return source_tensor, target_tensor
def prepare_same_batch(self, source_path : str) -> Batch:
    # Return a (source, target) pair in which the target tensor IS the
    # source tensor: the image at source_path is read once, transformed
    # once, and used for both halves of the batch. Selected by
    # __getitem__ when batch_mode == 'same'.
    source_tensor = io.read_image(source_path)
    # NOTE(review): self.transforms is built in __init__ (not visible in
    # this hunk) — presumably normalization/augmentation; verify there.
    source_tensor = self.transforms(source_tensor)
    return source_tensor, source_tensor
def prepare_equal_batch(self, source_path : str) -> Batch:
target_directory_path = os.path.dirname(source_path)
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
+3 -2
View File
@@ -18,7 +18,7 @@ from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
from .types import Batch, Embedding, OptimizerConfig, WarpTemplate
from .types import Batch, BatchMode, Embedding, OptimizerConfig, WarpTemplate
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -198,13 +198,14 @@ def create_trainer() -> Trainer:
def train() -> None:
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template'))
dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode'))
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
output_resume_path = CONFIG.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_ratio)
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_mode, dataset_batch_ratio)
training_loader, validation_loader = create_loaders(dataset)
face_swapper_trainer = FaceSwapperTrainer()
trainer = create_trainer()
+1
View File
@@ -4,6 +4,7 @@ from torch import Tensor
from torch.nn import Module
Batch : TypeAlias = Tuple[Tensor, Tensor]
BatchMode = Literal['equal', 'same']
Attributes : TypeAlias = Tuple[Tensor, ...]
Embedding : TypeAlias = Tensor