mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -70,32 +70,29 @@ class AttributeLoss(nn.Module):
|
||||
|
||||
|
||||
class ReconstructionLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, embedder : nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.embedder = embedder
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
|
||||
temp_tensors = []
|
||||
|
||||
for __source_tensor__, __target_tensor__ in zip(source_tensor, target_tensor):
|
||||
temp_tensor = self.mse_loss(__source_tensor__, __target_tensor__) * torch.equal(__source_tensor__, __target_tensor__)
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
|
||||
same_person = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
|
||||
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
|
||||
reconstruction_loss = (reconstruction_loss * same_person).mean() * 0.5
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
|
||||
|
||||
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
|
||||
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
|
||||
return reconstruction_loss, weighted_reconstruction_loss
|
||||
|
||||
|
||||
class IdentityLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, embedder : nn.Module) -> None:
|
||||
super().__init__()
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.embedder = embedder
|
||||
|
||||
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
|
||||
|
||||
@@ -28,16 +28,16 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
super().__init__()
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
|
||||
self.embedder = torch.jit.load(embedder_path, map_location='cpu') # type:ignore[no-untyped-call]
|
||||
self.generator = Generator()
|
||||
self.discriminator = Discriminator()
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
self.adversarial_loss = AdversarialLoss()
|
||||
self.attribute_loss = AttributeLoss()
|
||||
self.reconstruction_loss = ReconstructionLoss()
|
||||
self.identity_loss = IdentityLoss()
|
||||
self.reconstruction_loss = ReconstructionLoss(self.embedder)
|
||||
self.identity_loss = IdentityLoss(self.embedder)
|
||||
self.pose_loss = PoseLoss()
|
||||
self.gaze_loss = GazeLoss()
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.automatic_optimization = False
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user