From 578b07a7f46acb51454d06d0e98e3afd4ec3e038 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 26 Feb 2025 00:43:42 +0100 Subject: [PATCH] Add StatefulDataloader, Manual trigger scheduler --- embedding_converter/src/training.py | 9 +++++---- face_swapper/src/training.py | 17 +++++++++++------ requirements.txt | 1 + 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 331ce39..76a534f 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -8,7 +8,8 @@ from lightning import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import Dataset, random_split +from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import StaticDataset from .models.embedding_converter import EmbeddingConverter @@ -68,13 +69,13 @@ class EmbeddingConverterTrainer(lightning.LightningModule): return config -def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]: +def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]: batch_size = CONFIG.getint('training.loader', 'batch_size') num_workers = CONFIG.getint('training.loader', 'num_workers') training_dataset, validate_dataset = split_dataset(dataset) - training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) - validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True) + training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) + validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True) return training_loader, validation_loader diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index a8cc6f6..d732d37 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -9,7 +9,8 @@ from lightning import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import Dataset, random_split +from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import DynamicDataset from .helper import calc_embedding @@ -56,7 +57,6 @@ class FaceSwapperTrainer(lightning.LightningModule): 'lr_scheduler': { 'scheduler': generator_scheduler, - 'monitor': 'generator_loss', 'interval': 'step', 'frequency': 1000 } @@ -67,7 +67,6 @@ class FaceSwapperTrainer(lightning.LightningModule): 'lr_scheduler': { 'scheduler': discriminator_scheduler, - 'monitor': 'discriminator_loss', 'interval': 'step', 'frequency': 1000 } @@ -97,6 +96,9 @@ class FaceSwapperTrainer(lightning.LightningModule): self.manual_backward(generator_loss) generator_optimizer.step() + generator_scheduler = self.lr_schedulers()[0] + generator_scheduler.step(generator_loss) + discriminator_source_tensors = self.discriminator(source_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) @@ -105,6 +107,9 @@ class FaceSwapperTrainer(lightning.LightningModule): self.manual_backward(discriminator_loss) discriminator_optimizer.step() + discriminator_scheduler = self.lr_schedulers()[1] + discriminator_scheduler.step(discriminator_loss) + if self.global_step % preview_frequency == 0: self.generate_preview(source_tensor, target_tensor, generator_output_tensor) @@ -140,13 +145,13 @@ class FaceSwapperTrainer(lightning.LightningModule): self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined] -def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]: +def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]: batch_size = CONFIG.getint('training.loader', 'batch_size') num_workers = CONFIG.getint('training.loader', 'num_workers') training_dataset, validate_dataset = split_dataset(dataset) - training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) - validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True) + training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) + validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True) return training_loader, validation_loader diff --git a/requirements.txt b/requirements.txt index 1620651..d635148 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,6 @@ onnx==1.17.0 onnxruntime==1.20.1 pytorch-msssim==1.0.0 torch==2.6.0 +torchdata==0.11.0 torchvision==0.21.0 tensorboard==2.19.0