This commit is contained in:
harisreedhar
2025-02-27 12:18:09 +05:30
committed by henryruhs
parent c8801ececd
commit 2ddcf52b66
2 changed files with 12 additions and 15 deletions
+9 -12
View File
@@ -70,32 +70,29 @@ class AttributeLoss(nn.Module):
class ReconstructionLoss(nn.Module):
def __init__(self) -> None:
def __init__(self, embedder : nn.Module) -> None:
super().__init__()
self.embedder = embedder
self.mse_loss = nn.MSELoss()
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
temp_tensors = []
for __source_tensor__, __target_tensor__ in zip(source_tensor, target_tensor):
temp_tensor = self.mse_loss(__source_tensor__, __target_tensor__) * torch.equal(__source_tensor__, __target_tensor__)
temp_tensors.append(temp_tensor)
reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
same_person = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
reconstruction_loss = (reconstruction_loss * same_person).mean() * 0.5
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss
class IdentityLoss(nn.Module):
def __init__(self) -> None:
def __init__(self, embedder : nn.Module) -> None:
super().__init__()
embedder_path = CONFIG.get('training.model', 'embedder_path')
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.embedder = embedder
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
+3 -3
View File
@@ -28,16 +28,16 @@ class FaceSwapperTrainer(lightning.LightningModule):
super().__init__()
embedder_path = CONFIG.get('training.model', 'embedder_path')
self.embedder = torch.jit.load(embedder_path, map_location='cpu') # type:ignore[no-untyped-call]
self.generator = Generator()
self.discriminator = Discriminator()
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss()
self.attribute_loss = AttributeLoss()
self.reconstruction_loss = ReconstructionLoss()
self.identity_loss = IdentityLoss()
self.reconstruction_loss = ReconstructionLoss(self.embedder)
self.identity_loss = IdentityLoss(self.embedder)
self.pose_loss = PoseLoss()
self.gaze_loss = GazeLoss()
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.automatic_optimization = False
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: