mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Use sequential whenever possible
This commit is contained in:
@@ -11,8 +11,8 @@ class Discriminator(nn.Module):
|
||||
super().__init__()
|
||||
self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
|
||||
self.config_parser = config_parser
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
self.discriminators = self.create_discriminators()
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
|
||||
def create_discriminators(self) -> nn.ModuleList:
|
||||
discriminators = nn.ModuleList()
|
||||
|
||||
@@ -75,10 +75,10 @@ class ReconstructionLoss(nn.Module):
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
with torch.no_grad():
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
|
||||
|
||||
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
|
||||
|
||||
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
|
||||
@@ -141,6 +141,7 @@ class MotionLoss(nn.Module):
|
||||
|
||||
with torch.no_grad():
|
||||
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
|
||||
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
pose = translation, scale, rotation, motion_points
|
||||
return pose, expression
|
||||
@@ -174,6 +175,7 @@ class GazeLoss(nn.Module):
|
||||
|
||||
with torch.no_grad():
|
||||
pitch, yaw = self.gazer(crop_tensor)
|
||||
|
||||
return pitch, yaw
|
||||
|
||||
|
||||
@@ -201,4 +203,5 @@ class MaskLoss(nn.Module):
|
||||
output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype)
|
||||
output_tensor = output_tensor.view(-1, 1, 512, 512)
|
||||
output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')
|
||||
|
||||
return output_tensor
|
||||
|
||||
Reference in New Issue
Block a user