Replace calc() with forward(), rename temp1 to positive and temp2 to negative

This commit is contained in:
henryruhs
2025-02-23 09:22:10 +01:00
parent 6fed877d33
commit e75a3c58f9
2 changed files with 25 additions and 23 deletions
+18 -16
View File
@@ -16,21 +16,21 @@ class DiscriminatorLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def calc(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor:
temp1_tensors = []
temp2_tensors = []
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor:
positive_tensors = []
negative_tensors = []
for discriminator_output_tensor in discriminator_output_tensors:
temp1_tensor = torch.relu(discriminator_output_tensor[0] + 1).mean(dim = [ 1, 2, 3 ])
temp1_tensors.append(temp1_tensor)
positive_tensor = torch.relu(discriminator_output_tensor[0] + 1).mean(dim = [ 1, 2, 3 ])
positive_tensors.append(positive_tensor)
for discriminator_source_tensor in discriminator_source_tensors:
temp2_tensor = torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ])
temp2_tensors.append(temp2_tensor)
negative_tensor = torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ])
negative_tensors.append(negative_tensor)
discriminator1_loss = torch.stack(temp1_tensors).mean()
discriminator2_loss = torch.stack(temp2_tensors).mean()
discriminator_loss = (discriminator1_loss + discriminator2_loss) * 0.5
discriminator_positive_loss = torch.stack(positive_tensors).mean()
discriminator_negative_loss = torch.stack(negative_tensors).mean()
discriminator_loss = (discriminator_positive_loss + discriminator_negative_loss) * 0.5
return discriminator_loss
@@ -38,7 +38,7 @@ class AdversarialLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def calc(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
adversarial_weight = CONFIG.getfloat('training.losses', 'adversarial_weight')
temp_tensors = []
@@ -55,7 +55,7 @@ class AttributeLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def calc(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
attribute_weight = CONFIG.getfloat('training.losses', 'attribute_weight')
temp_tensors = []
@@ -74,7 +74,7 @@ class ReconstructionLoss(nn.Module):
super().__init__()
self.mse_loss = nn.MSELoss()
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
temp_tensors = []
@@ -85,9 +85,11 @@ class ReconstructionLoss(nn.Module):
temp_tensors.append(temp_tensor)
else:
temp_tensors.append(temp_tensor * 0)
reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss
@@ -99,7 +101,7 @@ class IdentityLoss(nn.Module):
embedder_path = CONFIG.get('training.model', 'embedder_path')
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
def calc(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
@@ -115,7 +117,7 @@ class PoseLoss(nn.Module):
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.mse_loss = nn.MSELoss()
def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
output_motion_features = self.get_motion_features(output_tensor)
target_motion_features = self.get_motion_features(target_tensor)
@@ -143,7 +145,7 @@ class GazeLoss(nn.Module):
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.mse_loss = nn.MSELoss()
def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
output_face_landmark = self.detect_face_landmark(output_tensor)
target_face_landmark = self.detect_face_landmark(target_tensor)
+7 -7
View File
@@ -62,12 +62,12 @@ class FaceSwapperTrainer(lightning.LightningModule):
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
pose_loss, weighted_pose_loss = self.pose_loss.calc(target_tensor, generator_output_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss.calc(target_tensor, generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
pose_loss, weighted_pose_loss = self.pose_loss(target_tensor, generator_output_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
generator_optimizer.zero_grad()
@@ -76,7 +76,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors)
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_loss)