mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Uniform validation, Add cosine_similarity validation to face swapper
This commit is contained in:
@@ -38,9 +38,9 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
output_tensor = self(source_tensor)
|
||||
loss_validation = self.mse_loss(output_tensor, target_tensor)
|
||||
self.log('loss_validation', loss_validation, prog_bar = True)
|
||||
return loss_validation
|
||||
validation = self.mse_loss(output_tensor, target_tensor)
|
||||
self.log('validation', validation, prog_bar = True)
|
||||
return validation
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
@@ -94,6 +94,7 @@ def create_trainer() -> Trainer:
|
||||
|
||||
return Trainer(
|
||||
logger = logger,
|
||||
log_every_n_steps = 10,
|
||||
max_epochs = trainer_max_epochs,
|
||||
callbacks =
|
||||
[
|
||||
@@ -106,8 +107,7 @@ def create_trainer() -> Trainer:
|
||||
save_last = True
|
||||
)
|
||||
],
|
||||
enable_progress_bar = True,
|
||||
log_every_n_steps = 2
|
||||
val_check_interval = 1000
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@ import torchvision
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor
|
||||
from torch import Tensor, nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from .data_loader import DataLoaderVGG
|
||||
from .helper import calc_id_embedding
|
||||
@@ -74,6 +74,15 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'))
|
||||
return generator_losses.get('loss_generator')
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor, _ = batch
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor, target_attributes = self.generator(target_tensor, source_embedding)
|
||||
output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
|
||||
validation = nn.functional.cosine_similarity(source_embedding, output_embedding).mean()
|
||||
self.log('validation', validation)
|
||||
return validation
|
||||
|
||||
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None:
|
||||
preview_limit = 8
|
||||
preview_items = []
|
||||
@@ -95,6 +104,7 @@ def create_trainer() -> Trainer:
|
||||
os.makedirs(output_directory_path, exist_ok = True)
|
||||
return Trainer(
|
||||
logger = logger,
|
||||
log_every_n_steps = 10,
|
||||
max_epochs = trainer_max_epochs,
|
||||
precision = trainer_precision,
|
||||
callbacks =
|
||||
@@ -108,7 +118,7 @@ def create_trainer() -> Trainer:
|
||||
save_last = True
|
||||
)
|
||||
],
|
||||
log_every_n_steps = 10
|
||||
val_check_interval = 1000,
|
||||
)
|
||||
|
||||
|
||||
@@ -122,11 +132,12 @@ def train() -> None:
|
||||
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
|
||||
|
||||
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
|
||||
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
training_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = DataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
face_swapper_trainer = FaceSwapperTrainer()
|
||||
trainer = create_trainer()
|
||||
|
||||
if os.path.isfile(resume_file_path):
|
||||
trainer.fit(face_swapper_trainer, data_loader, ckpt_path = resume_file_path)
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = resume_file_path)
|
||||
else:
|
||||
trainer.fit(face_swapper_trainer, data_loader)
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader)
|
||||
|
||||
Reference in New Issue
Block a user