mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Replace calc() with forward(), Rename temp1 with positive and temp2 with negative
This commit is contained in:
@@ -16,21 +16,21 @@ class DiscriminatorLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def calc(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor:
|
||||
temp1_tensors = []
|
||||
temp2_tensors = []
|
||||
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor:
|
||||
positive_tensors = []
|
||||
negative_tensors = []
|
||||
|
||||
for discriminator_output_tensor in discriminator_output_tensors:
|
||||
temp1_tensor = torch.relu(discriminator_output_tensor[0] + 1).mean(dim = [ 1, 2, 3 ])
|
||||
temp1_tensors.append(temp1_tensor)
|
||||
positive_tensor = torch.relu(discriminator_output_tensor[0] + 1).mean(dim = [ 1, 2, 3 ])
|
||||
positive_tensors.append(positive_tensor)
|
||||
|
||||
for discriminator_source_tensor in discriminator_source_tensors:
|
||||
temp2_tensor = torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ])
|
||||
temp2_tensors.append(temp2_tensor)
|
||||
negative_tensor = torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ])
|
||||
negative_tensors.append(negative_tensor)
|
||||
|
||||
discriminator1_loss = torch.stack(temp1_tensors).mean()
|
||||
discriminator2_loss = torch.stack(temp2_tensors).mean()
|
||||
discriminator_loss = (discriminator1_loss + discriminator2_loss) * 0.5
|
||||
discriminator_positive_loss = torch.stack(positive_tensors).mean()
|
||||
discriminator_negative_loss = torch.stack(negative_tensors).mean()
|
||||
discriminator_loss = (discriminator_positive_loss + discriminator_negative_loss) * 0.5
|
||||
return discriminator_loss
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class AdversarialLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def calc(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
|
||||
adversarial_weight = CONFIG.getfloat('training.losses', 'adversarial_weight')
|
||||
temp_tensors = []
|
||||
|
||||
@@ -55,7 +55,7 @@ class AttributeLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def calc(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
attribute_weight = CONFIG.getfloat('training.losses', 'attribute_weight')
|
||||
temp_tensors = []
|
||||
@@ -74,7 +74,7 @@ class ReconstructionLoss(nn.Module):
|
||||
super().__init__()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
|
||||
temp_tensors = []
|
||||
|
||||
@@ -85,9 +85,11 @@ class ReconstructionLoss(nn.Module):
|
||||
temp_tensors.append(temp_tensor)
|
||||
else:
|
||||
temp_tensors.append(temp_tensor * 0)
|
||||
|
||||
reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
|
||||
|
||||
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
|
||||
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
|
||||
return reconstruction_loss, weighted_reconstruction_loss
|
||||
@@ -99,7 +101,7 @@ class IdentityLoss(nn.Module):
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
|
||||
def calc(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
@@ -115,7 +117,7 @@ class PoseLoss(nn.Module):
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
|
||||
output_motion_features = self.get_motion_features(output_tensor)
|
||||
target_motion_features = self.get_motion_features(target_tensor)
|
||||
@@ -143,7 +145,7 @@ class GazeLoss(nn.Module):
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
|
||||
output_face_landmark = self.detect_face_landmark(output_tensor)
|
||||
target_face_landmark = self.detect_face_landmark(target_tensor)
|
||||
|
||||
@@ -62,12 +62,12 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors)
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
|
||||
pose_loss, weighted_pose_loss = self.pose_loss.calc(target_tensor, generator_output_tensor)
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss.calc(target_tensor, generator_output_tensor)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
|
||||
pose_loss, weighted_pose_loss = self.pose_loss(target_tensor, generator_output_tensor)
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
|
||||
|
||||
generator_optimizer.zero_grad()
|
||||
@@ -76,7 +76,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors)
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_loss)
|
||||
|
||||
Reference in New Issue
Block a user