Uniform validation, Add cosine_similarity validation to face swapper

This commit is contained in:
henryruhs
2025-02-18 14:14:23 +01:00
parent c161da2f25
commit 28977d37d6
2 changed files with 22 additions and 11 deletions
+5 -5
View File
@@ -38,9 +38,9 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
output_tensor = self(source_tensor)
loss_validation = self.mse_loss(output_tensor, target_tensor)
self.log('loss_validation', loss_validation, prog_bar = True)
return loss_validation
validation = self.mse_loss(output_tensor, target_tensor)
self.log('validation', validation, prog_bar = True)
return validation
def configure_optimizers(self) -> Any:
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
@@ -94,6 +94,7 @@ def create_trainer() -> Trainer:
return Trainer(
logger = logger,
log_every_n_steps = 10,
max_epochs = trainer_max_epochs,
callbacks =
[
@@ -106,8 +107,7 @@ def create_trainer() -> Trainer:
save_last = True
)
],
enable_progress_bar = True,
log_every_n_steps = 2
val_check_interval = 1000
)
+17 -6
View File
@@ -8,9 +8,9 @@ import torchvision
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Subset
from .data_loader import DataLoaderVGG
from .helper import calc_id_embedding
@@ -74,6 +74,15 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'))
return generator_losses.get('loss_generator')
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor, _ = batch
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
output_tensor, target_attributes = self.generator(target_tensor, source_embedding)
output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
validation = nn.functional.cosine_similarity(source_embedding, output_embedding).mean()
self.log('validation', validation)
return validation
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None:
preview_limit = 8
preview_items = []
@@ -95,6 +104,7 @@ def create_trainer() -> Trainer:
os.makedirs(output_directory_path, exist_ok = True)
return Trainer(
logger = logger,
log_every_n_steps = 10,
max_epochs = trainer_max_epochs,
precision = trainer_precision,
callbacks =
@@ -108,7 +118,7 @@ def create_trainer() -> Trainer:
save_last = True
)
],
log_every_n_steps = 10
val_check_interval = 1000,
)
@@ -122,11 +132,12 @@ def train() -> None:
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
training_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swapper_trainer = FaceSwapperTrainer()
trainer = create_trainer()
if os.path.isfile(resume_file_path):
trainer.fit(face_swapper_trainer, data_loader, ckpt_path = resume_file_path)
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = resume_file_path)
else:
trainer.fit(face_swapper_trainer, data_loader)
trainer.fit(face_swapper_trainer, training_loader, validation_loader)