diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index e4592f6..68b702c 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -16,21 +16,21 @@ class DiscriminatorLoss(nn.Module): def __init__(self) -> None: super().__init__() - def calc(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor: - temp1_tensors = [] - temp2_tensors = [] + def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor: + positive_tensors = [] + negative_tensors = [] for discriminator_output_tensor in discriminator_output_tensors: - temp1_tensor = torch.relu(discriminator_output_tensor[0] + 1).mean(dim = [ 1, 2, 3 ]) - temp1_tensors.append(temp1_tensor) + positive_tensor = torch.relu(discriminator_output_tensor[0] + 1).mean(dim = [ 1, 2, 3 ]) + positive_tensors.append(positive_tensor) for discriminator_source_tensor in discriminator_source_tensors: - temp2_tensor = torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ]) - temp2_tensors.append(temp2_tensor) + negative_tensor = torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ]) + negative_tensors.append(negative_tensor) - discriminator1_loss = torch.stack(temp1_tensors).mean() - discriminator2_loss = torch.stack(temp2_tensors).mean() - discriminator_loss = (discriminator1_loss + discriminator2_loss) * 0.5 + discriminator_positive_loss = torch.stack(positive_tensors).mean() + discriminator_negative_loss = torch.stack(negative_tensors).mean() + discriminator_loss = (discriminator_positive_loss + discriminator_negative_loss) * 0.5 return discriminator_loss @@ -38,7 +38,7 @@ class AdversarialLoss(nn.Module): def __init__(self) -> None: super().__init__() - def calc(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]: + def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]: adversarial_weight = CONFIG.getfloat('training.losses', 'adversarial_weight') temp_tensors = [] @@ -55,7 +55,7 @@ class AttributeLoss(nn.Module): def __init__(self) -> None: super().__init__() - def calc(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]: + def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]: batch_size = CONFIG.getint('training.loader', 'batch_size') attribute_weight = CONFIG.getfloat('training.losses', 'attribute_weight') temp_tensors = [] @@ -74,7 +74,7 @@ class ReconstructionLoss(nn.Module): super().__init__() self.mse_loss = nn.MSELoss() - def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight') temp_tensors = [] @@ -85,9 +85,11 @@ class ReconstructionLoss(nn.Module): temp_tensors.append(temp_tensor) else: temp_tensors.append(temp_tensor * 0) + reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5 data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() + reconstruction_loss = (reconstruction_loss + similarity) * 0.5 weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight return reconstruction_loss, weighted_reconstruction_loss @@ -99,7 +101,7 @@ class IdentityLoss(nn.Module): embedder_path = CONFIG.get('training.model', 'embedder_path') self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] - def calc(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: identity_weight = CONFIG.getfloat('training.losses', 'identity_weight') output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10)) source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) @@ -115,7 +117,7 @@ class PoseLoss(nn.Module): self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.mse_loss = nn.MSELoss() - def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: + def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: pose_weight = CONFIG.getfloat('training.losses', 'pose_weight') output_motion_features = self.get_motion_features(output_tensor) target_motion_features = self.get_motion_features(target_tensor) @@ -143,7 +145,7 @@ class GazeLoss(nn.Module): self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.mse_loss = nn.MSELoss() - def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: + def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight') output_face_landmark = self.detect_face_landmark(output_tensor) target_face_landmark = self.detect_face_landmark(target_tensor) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index fb52b12..f3b132a 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -62,12 +62,12 @@ class FaceSwapperTrainer(lightning.LightningModule): generator_output_attributes = self.generator.get_attributes(generator_output_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) - adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors) - attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes) - reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor) - identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor) - pose_loss, weighted_pose_loss = self.pose_loss.calc(target_tensor, generator_output_tensor) - gaze_loss, weighted_gaze_loss = self.gaze_loss.calc(target_tensor, generator_output_tensor) + adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) + attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes) + reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor) + identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor) + pose_loss, weighted_pose_loss = self.pose_loss(target_tensor, generator_output_tensor) + gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor) generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss generator_optimizer.zero_grad() @@ -76,7 +76,7 @@ class FaceSwapperTrainer(lightning.LightningModule): discriminator_source_tensors = self.discriminator(source_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) - discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors) + discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) discriminator_optimizer.zero_grad() self.manual_backward(discriminator_loss)