From ea1b0205f00e8f53c8ab5cfc45344b7b4e12da74 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Thu, 27 Feb 2025 19:17:11 +0530 Subject: [PATCH] constructor injection --- face_swapper/src/models/loss.py | 12 +++++------- face_swapper/src/training.py | 11 ++++++++--- face_swapper/src/types.py | 2 ++ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 2614f99..c0699f6 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -6,7 +6,7 @@ from pytorch_msssim import ssim from torch import Tensor, nn from ..helper import calc_embedding -from ..types import Attributes, EmbedderModule, FaceLandmark203 +from ..types import Attributes, EmbedderModule, FaceLandmark203, LandmarkerModule, MotionExtractorModule CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -104,10 +104,9 @@ class IdentityLoss(nn.Module): class PoseLoss(nn.Module): - def __init__(self) -> None: + def __init__(self, motion_extractor : MotionExtractorModule) -> None: super().__init__() - motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') - self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.motion_extractor = motion_extractor self.mse_loss = nn.MSELoss() def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: @@ -132,10 +131,9 @@ class PoseLoss(nn.Module): class GazeLoss(nn.Module): - def __init__(self) -> None: + def __init__(self, landmarker : LandmarkerModule) -> None: super().__init__() - landmarker_path = CONFIG.get('training.model', 'landmarker_path') - self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.landmarker = landmarker self.mse_loss = nn.MSELoss() def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 704be44..2d38f53 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -27,8 +27,13 @@ class FaceSwapperTrainer(lightning.LightningModule): def __init__(self) -> None: super().__init__() embedder_path = CONFIG.get('training.model', 'embedder_path') + landmarker_path = CONFIG.get('training.model', 'landmarker_path') + motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') + + self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] - self.embedder = torch.jit.load(embedder_path, map_location='cpu') # type:ignore[no-untyped-call] self.generator = Generator() self.discriminator = Discriminator() self.discriminator_loss = DiscriminatorLoss() @@ -36,8 +41,8 @@ class FaceSwapperTrainer(lightning.LightningModule): self.attribute_loss = AttributeLoss() self.reconstruction_loss = ReconstructionLoss(self.embedder) self.identity_loss = IdentityLoss(self.embedder) - self.pose_loss = PoseLoss() - self.gaze_loss = GazeLoss() + self.pose_loss = PoseLoss(self.motion_extractor) + self.gaze_loss = GazeLoss(self.landmarker) self.automatic_optimization = False def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index cc2f46e..dd3ed87 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -13,5 +13,7 @@ Padding : TypeAlias = Tuple[int, int, int, int] GeneratorModule : TypeAlias = Module EmbedderModule : TypeAlias = Module +LandmarkerModule : TypeAlias = Module +MotionExtractorModule : TypeAlias = Module OptimizerConfig : TypeAlias = Any