This commit is contained in:
harisreedhar
2025-02-27 15:46:50 +05:30
committed by henryruhs
parent 5d1b90ff19
commit d87f6c0b15
+3 -3
View File
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
from torch import Tensor, nn
from ..helper import calc_embedding
from ..types import Attributes, FaceLandmark203
from ..types import Attributes, EmbedderModule, FaceLandmark203
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -70,7 +70,7 @@ class AttributeLoss(nn.Module):
class ReconstructionLoss(nn.Module):
def __init__(self, embedder : nn.Module) -> None:
def __init__(self, embedder : EmbedderModule) -> None:
super().__init__()
self.embedder = embedder
self.mse_loss = nn.MSELoss()
@@ -90,7 +90,7 @@ class ReconstructionLoss(nn.Module):
class IdentityLoss(nn.Module):
def __init__(self, embedder : nn.Module) -> None:
def __init__(self, embedder : EmbedderModule) -> None:
super().__init__()
self.embedder = embedder