mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import numpy
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
|
||||
|
||||
@@ -42,8 +42,7 @@ def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor
|
||||
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
crop_vision_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
|
||||
with torch.no_grad():
|
||||
source_embedding = id_embedder(crop_vision_tensor)
|
||||
source_embedding = id_embedder(crop_vision_tensor)
|
||||
source_embedding = nn.functional.normalize(source_embedding, p = 2)
|
||||
return source_embedding
|
||||
|
||||
|
||||
@@ -95,14 +95,15 @@ class FaceSwapperLoss:
|
||||
return loss_attribute
|
||||
|
||||
def calc_reconstruction_loss(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0))
|
||||
with torch.no_grad():
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0))
|
||||
face_similarities = (torch.cosine_similarity(source_embedding, target_embedding) + 1) * 0.5
|
||||
loss_reconstructions = []
|
||||
|
||||
for index, face_similarity in enumerate(face_similarities):
|
||||
if face_similarity.item() > 0.9:
|
||||
loss_mse = self.mse_loss(swap_tensor[index], target_tensor[index])
|
||||
loss_mse = self.mse_loss(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0))
|
||||
loss_ssim = calc_structural_similarity(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0))
|
||||
loss_reconstruction = (loss_mse + loss_ssim) * 0.5
|
||||
loss_reconstructions.append(loss_reconstruction)
|
||||
|
||||
@@ -94,14 +94,14 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Any]): # type:ignore[type-arg]
|
||||
def create_loaders(dataset : Dataset[Any]) -> Tuple[TorchDataLoader[Any], TorchDataLoader[Any]]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
|
||||
training_dataset, validate_dataset = split_dataset(dataset)
|
||||
training_loader = TorchDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = TorchDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True)
|
||||
return training_loader, validation_loader # type:ignore[return-value]
|
||||
return training_loader, validation_loader
|
||||
|
||||
|
||||
def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]:
|
||||
|
||||
Reference in New Issue
Block a user