This commit is contained in:
harisreedhar
2025-02-21 20:12:48 +05:30
committed by henryruhs
parent db44c91dd8
commit 4d8433f54a
3 changed files with 8 additions and 8 deletions
+2 -3
View File
@@ -1,7 +1,7 @@
import numpy
import torch
from torch import Tensor, nn
from pytorch_msssim import ssim
from torch import Tensor, nn
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
@@ -42,8 +42,7 @@ def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
crop_vision_tensor[:, :, :, :padding[2]] = 0
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
with torch.no_grad():
source_embedding = id_embedder(crop_vision_tensor)
source_embedding = id_embedder(crop_vision_tensor)
source_embedding = nn.functional.normalize(source_embedding, p = 2)
return source_embedding
+4 -3
View File
@@ -95,14 +95,15 @@ class FaceSwapperLoss:
return loss_attribute
def calc_reconstruction_loss(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0))
with torch.no_grad():
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0))
face_similarities = (torch.cosine_similarity(source_embedding, target_embedding) + 1) * 0.5
loss_reconstructions = []
for index, face_similarity in enumerate(face_similarities):
if face_similarity.item() > 0.9:
loss_mse = self.mse_loss(swap_tensor[index], target_tensor[index])
loss_mse = self.mse_loss(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0))
loss_ssim = calc_structural_similarity(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0))
loss_reconstruction = (loss_mse + loss_ssim) * 0.5
loss_reconstructions.append(loss_reconstruction)
+2 -2
View File
@@ -94,14 +94,14 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
def create_loaders(dataset : Dataset[Any]): # type:ignore[type-arg]
def create_loaders(dataset : Dataset[Any]) -> Tuple[TorchDataLoader[Any], TorchDataLoader[Any]]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = TorchDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = TorchDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader # type:ignore[return-value]
return training_loader, validation_loader
def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]: