mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Add cosine_similarity
This commit is contained in:
@@ -49,9 +49,8 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
with torch.no_grad():
|
||||
source_embedding = self.source_embedder(batch)
|
||||
target_embedding = self.target_embedder(batch)
|
||||
output_embedding = self(source_embedding)
|
||||
validation = self.mse_loss(output_embedding, target_embedding)
|
||||
validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
|
||||
self.log('validation', validation, prog_bar = True)
|
||||
return validation
|
||||
|
||||
@@ -113,7 +112,7 @@ def create_trainer() -> Trainer:
|
||||
save_last = True
|
||||
)
|
||||
],
|
||||
val_check_interval = 10
|
||||
val_check_interval = 100
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user