Add validation to the embedding converter and face swapper trainers.

embedding_converter/src/training.py
  * log the validation MSE under 'validation' instead of 'loss_validation'
  * log every 10 steps (was 2), drop enable_progress_bar and set
    val_check_interval = 1000

face_swapper/src/training.py
  * add validation_step: mean cosine similarity between the source embedding
    and the embedding of the generated output, logged as 'validation'
  * log every 10 steps and set val_check_interval = 1000
  * build a validation loader over the leading samples of the training
    dataset and pass it to trainer.fit

NOTE(review): the validation Subset is clamped to len(dataset); a fixed
range(1000) indexes out of range when the dataset holds fewer than 1000
samples.
NOTE(review): this patch was recovered from a whitespace-collapsed copy;
indentation (tabs) and blank context lines are reconstructed from the hunk
counts - confirm against the target files, or apply with
`git apply --ignore-whitespace`. The index lines are kept from the original
patch.

diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py
index a87848c..3c248a9 100644
--- a/embedding_converter/src/training.py
+++ b/embedding_converter/src/training.py
@@ -38,9 +38,9 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
 	def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
 		source_tensor, target_tensor = batch
 		output_tensor = self(source_tensor)
-		loss_validation = self.mse_loss(output_tensor, target_tensor)
-		self.log('loss_validation', loss_validation, prog_bar = True)
-		return loss_validation
+		validation = self.mse_loss(output_tensor, target_tensor)
+		self.log('validation', validation, prog_bar = True)
+		return validation
 
 	def configure_optimizers(self) -> Any:
 		learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
@@ -94,6 +94,7 @@ def create_trainer() -> Trainer:
 
 	return Trainer(
 		logger = logger,
+		log_every_n_steps = 10,
 		max_epochs = trainer_max_epochs,
 		callbacks =
 		[
@@ -106,8 +107,7 @@ def create_trainer() -> Trainer:
 				save_last = True
 			)
 		],
-		enable_progress_bar = True,
-		log_every_n_steps = 2
+		val_check_interval = 1000
 	)
 
 
diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py
index 388a987..89fa662 100644
--- a/face_swapper/src/training.py
+++ b/face_swapper/src/training.py
@@ -8,9 +8,9 @@ import torchvision
 from lightning import Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import TensorBoardLogger
-from torch import Tensor
+from torch import Tensor, nn
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
 
 from .data_loader import DataLoaderVGG
 from .helper import calc_id_embedding
@@ -74,6 +74,15 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
 		self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'))
 		return generator_losses.get('loss_generator')
 
+	def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
+		source_tensor, target_tensor, _ = batch
+		source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
+		output_tensor, target_attributes = self.generator(target_tensor, source_embedding)
+		output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
+		validation = nn.functional.cosine_similarity(source_embedding, output_embedding).mean()
+		self.log('validation', validation)
+		return validation
+
 	def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None:
 		preview_limit = 8
 		preview_items = []
@@ -95,6 +104,7 @@ def create_trainer() -> Trainer:
 	os.makedirs(output_directory_path, exist_ok = True)
 	return Trainer(
 		logger = logger,
+		log_every_n_steps = 10,
 		max_epochs = trainer_max_epochs,
 		precision = trainer_precision,
 		callbacks =
@@ -108,7 +118,7 @@ def create_trainer() -> Trainer:
 				save_last = True
 			)
 		],
-		log_every_n_steps = 10
+		val_check_interval = 1000,
 	)
 
 
@@ -122,11 +132,12 @@ def train() -> None:
 	resume_file_path = CONFIG.get('training.output', 'resume_file_path')
 
 	dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
-	data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
+	training_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
+	validation_loader = DataLoader(Subset(dataset, range(min(1000, len(dataset)))), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
 	face_swapper_trainer = FaceSwapperTrainer()
 	trainer = create_trainer()
 
 	if os.path.isfile(resume_file_path):
-		trainer.fit(face_swapper_trainer, data_loader, ckpt_path = resume_file_path)
+		trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = resume_file_path)
 	else:
-		trainer.fit(face_swapper_trainer, data_loader)
+		trainer.fit(face_swapper_trainer, training_loader, validation_loader)