From d87f6c0b156715a00183ca233aff13095ffe3e16 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Thu, 27 Feb 2025 15:46:50 +0530 Subject: [PATCH] changes --- face_swapper/src/models/loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index f2936df..2614f99 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -6,7 +6,7 @@ from pytorch_msssim import ssim from torch import Tensor, nn from ..helper import calc_embedding -from ..types import Attributes, FaceLandmark203 +from ..types import Attributes, EmbedderModule, FaceLandmark203 CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -70,7 +70,7 @@ class AttributeLoss(nn.Module): class ReconstructionLoss(nn.Module): - def __init__(self, embedder : nn.Module) -> None: + def __init__(self, embedder : EmbedderModule) -> None: super().__init__() self.embedder = embedder self.mse_loss = nn.MSELoss() @@ -90,7 +90,7 @@ class ReconstructionLoss(nn.Module): class IdentityLoss(nn.Module): - def __init__(self, embedder : nn.Module) -> None: + def __init__(self, embedder : EmbedderModule) -> None: super().__init__() self.embedder = embedder