mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Add StatefulDataloader, Manual trigger scheduler
This commit is contained in:
@@ -8,7 +8,8 @@ from lightning import Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
from torch.utils.data import Dataset, random_split
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import StaticDataset
|
||||
from .models.embedding_converter import EmbeddingConverter
|
||||
@@ -68,13 +69,13 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
return config
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]:
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
|
||||
training_dataset, validate_dataset = split_dataset(dataset)
|
||||
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
|
||||
training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
|
||||
return training_loader, validation_loader
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,8 @@ from lightning import Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
from torch.utils.data import Dataset, random_split
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding
|
||||
@@ -56,7 +57,6 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': generator_scheduler,
|
||||
'monitor': 'generator_loss',
|
||||
'interval': 'step',
|
||||
'frequency': 1000
|
||||
}
|
||||
@@ -67,7 +67,6 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
'lr_scheduler':
|
||||
{
|
||||
'scheduler': discriminator_scheduler,
|
||||
'monitor': 'discriminator_loss',
|
||||
'interval': 'step',
|
||||
'frequency': 1000
|
||||
}
|
||||
@@ -97,6 +96,9 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
self.manual_backward(generator_loss)
|
||||
generator_optimizer.step()
|
||||
|
||||
generator_scheduler = self.lr_schedulers()[0]
|
||||
generator_scheduler.step(generator_loss)
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
@@ -105,6 +107,9 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
self.manual_backward(discriminator_loss)
|
||||
discriminator_optimizer.step()
|
||||
|
||||
discriminator_scheduler = self.lr_schedulers()[1]
|
||||
discriminator_scheduler.step(discriminator_loss)
|
||||
|
||||
if self.global_step % preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
@@ -140,13 +145,13 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]:
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
|
||||
training_dataset, validate_dataset = split_dataset(dataset)
|
||||
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
|
||||
training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
|
||||
return training_loader, validation_loader
|
||||
|
||||
|
||||
|
||||
@@ -4,5 +4,6 @@ onnx==1.17.0
|
||||
onnxruntime==1.20.1
|
||||
pytorch-msssim==1.0.0
|
||||
torch==2.6.0
|
||||
torchdata==0.11.0
|
||||
torchvision==0.21.0
|
||||
tensorboard==2.19.0
|
||||
|
||||
Reference in New Issue
Block a user