diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index f2936df..2614f99 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -6,7 +6,7 @@ from pytorch_msssim import ssim from torch import Tensor, nn from ..helper import calc_embedding -from ..types import Attributes, FaceLandmark203 +from ..types import Attributes, EmbedderModule, FaceLandmark203 CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -70,7 +70,7 @@ class AttributeLoss(nn.Module): class ReconstructionLoss(nn.Module): - def __init__(self, embedder : nn.Module) -> None: + def __init__(self, embedder : EmbedderModule) -> None: super().__init__() self.embedder = embedder self.mse_loss = nn.MSELoss() @@ -90,7 +90,7 @@ class ReconstructionLoss(nn.Module): class IdentityLoss(nn.Module): - def __init__(self, embedder : nn.Module) -> None: + def __init__(self, embedder : EmbedderModule) -> None: super().__init__() self.embedder = embedder