constructor injection

This commit is contained in:
harisreedhar
2025-02-27 19:17:11 +05:30
committed by henryruhs
parent a5eb7d6aa1
commit ea1b0205f0
3 changed files with 15 additions and 10 deletions
+5 -7
View File
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
from torch import Tensor, nn
from ..helper import calc_embedding
from ..types import Attributes, EmbedderModule, FaceLandmark203
from ..types import Attributes, EmbedderModule, FaceLandmark203, LandmarkerModule, MotionExtractorModule
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -104,10 +104,9 @@ class IdentityLoss(nn.Module):
class PoseLoss(nn.Module):
def __init__(self) -> None:
def __init__(self, motion_extractor : MotionExtractorModule) -> None:
super().__init__()
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = motion_extractor
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
@@ -132,10 +131,9 @@ class PoseLoss(nn.Module):
class GazeLoss(nn.Module):
def __init__(self) -> None:
def __init__(self, landmarker : LandmarkerModule) -> None:
super().__init__()
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = landmarker
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
+8 -3
View File
@@ -27,8 +27,13 @@ class FaceSwapperTrainer(lightning.LightningModule):
def __init__(self) -> None:
super().__init__()
embedder_path = CONFIG.get('training.model', 'embedder_path')
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.embedder = torch.jit.load(embedder_path, map_location='cpu') # type:ignore[no-untyped-call]
self.generator = Generator()
self.discriminator = Discriminator()
self.discriminator_loss = DiscriminatorLoss()
@@ -36,8 +41,8 @@ class FaceSwapperTrainer(lightning.LightningModule):
self.attribute_loss = AttributeLoss()
self.reconstruction_loss = ReconstructionLoss(self.embedder)
self.identity_loss = IdentityLoss(self.embedder)
self.pose_loss = PoseLoss()
self.gaze_loss = GazeLoss()
self.pose_loss = PoseLoss(self.motion_extractor)
self.gaze_loss = GazeLoss(self.landmarker)
self.automatic_optimization = False
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
+2
View File
@@ -13,5 +13,7 @@ Padding : TypeAlias = Tuple[int, int, int, int]
GeneratorModule : TypeAlias = Module
EmbedderModule : TypeAlias = Module
LandmarkerModule : TypeAlias = Module
MotionExtractorModule : TypeAlias = Module
OptimizerConfig : TypeAlias = Any