mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
constructor injection
This commit is contained in:
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attributes, EmbedderModule, FaceLandmark203
|
||||
from ..types import Attributes, EmbedderModule, FaceLandmark203, LandmarkerModule, MotionExtractorModule
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -104,10 +104,9 @@ class IdentityLoss(nn.Module):
|
||||
|
||||
|
||||
class PoseLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, motion_extractor : MotionExtractorModule) -> None:
|
||||
super().__init__()
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = motion_extractor
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
@@ -132,10 +131,9 @@ class PoseLoss(nn.Module):
|
||||
|
||||
|
||||
class GazeLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, landmarker : LandmarkerModule) -> None:
|
||||
super().__init__()
|
||||
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.landmarker = landmarker
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
@@ -27,8 +27,13 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
|
||||
self.embedder = torch.jit.load(embedder_path, map_location='cpu') # type:ignore[no-untyped-call]
|
||||
self.generator = Generator()
|
||||
self.discriminator = Discriminator()
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
@@ -36,8 +41,8 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
self.attribute_loss = AttributeLoss()
|
||||
self.reconstruction_loss = ReconstructionLoss(self.embedder)
|
||||
self.identity_loss = IdentityLoss(self.embedder)
|
||||
self.pose_loss = PoseLoss()
|
||||
self.gaze_loss = GazeLoss()
|
||||
self.pose_loss = PoseLoss(self.motion_extractor)
|
||||
self.gaze_loss = GazeLoss(self.landmarker)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
|
||||
|
||||
@@ -13,5 +13,7 @@ Padding : TypeAlias = Tuple[int, int, int, int]
|
||||
|
||||
GeneratorModule : TypeAlias = Module
|
||||
EmbedderModule : TypeAlias = Module
|
||||
LandmarkerModule : TypeAlias = Module
|
||||
MotionExtractorModule : TypeAlias = Module
|
||||
|
||||
OptimizerConfig : TypeAlias = Any
|
||||
|
||||
Reference in New Issue
Block a user