Add StatefulDataLoader, trigger schedulers manually

This commit is contained in:
henryruhs
2025-02-26 00:43:42 +01:00
parent 7ce9d27097
commit 578b07a7f4
3 changed files with 17 additions and 10 deletions
+5 -4
View File
@@ -8,7 +8,8 @@ from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data import Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import StaticDataset
from .models.embedding_converter import EmbeddingConverter
@@ -68,13 +69,13 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
return config
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]:
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
+11 -6
View File
@@ -9,7 +9,8 @@ from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data import Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import DynamicDataset
from .helper import calc_embedding
@@ -56,7 +57,6 @@ class FaceSwapperTrainer(lightning.LightningModule):
'lr_scheduler':
{
'scheduler': generator_scheduler,
'monitor': 'generator_loss',
'interval': 'step',
'frequency': 1000
}
@@ -67,7 +67,6 @@ class FaceSwapperTrainer(lightning.LightningModule):
'lr_scheduler':
{
'scheduler': discriminator_scheduler,
'monitor': 'discriminator_loss',
'interval': 'step',
'frequency': 1000
}
@@ -97,6 +96,9 @@ class FaceSwapperTrainer(lightning.LightningModule):
self.manual_backward(generator_loss)
generator_optimizer.step()
generator_scheduler = self.lr_schedulers()[0]
generator_scheduler.step(generator_loss)
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
@@ -105,6 +107,9 @@ class FaceSwapperTrainer(lightning.LightningModule):
self.manual_backward(discriminator_loss)
discriminator_optimizer.step()
discriminator_scheduler = self.lr_schedulers()[1]
discriminator_scheduler.step(discriminator_loss)
if self.global_step % preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
@@ -140,13 +145,13 @@ class FaceSwapperTrainer(lightning.LightningModule):
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]:
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
+1
View File
@@ -4,5 +4,6 @@ onnx==1.17.0
onnxruntime==1.20.1
pytorch-msssim==1.0.0
torch==2.6.0
torchdata==0.11.0
torchvision==0.21.0
tensorboard==2.19.0