mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
changes
This commit is contained in:
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attributes, FaceLandmark203
|
||||
from ..types import Attributes, EmbedderModule, FaceLandmark203
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -70,7 +70,7 @@ class AttributeLoss(nn.Module):
|
||||
|
||||
|
||||
class ReconstructionLoss(nn.Module):
|
||||
def __init__(self, embedder : nn.Module) -> None:
|
||||
def __init__(self, embedder : EmbedderModule) -> None:
|
||||
super().__init__()
|
||||
self.embedder = embedder
|
||||
self.mse_loss = nn.MSELoss()
|
||||
@@ -90,7 +90,7 @@ class ReconstructionLoss(nn.Module):
|
||||
|
||||
|
||||
class IdentityLoss(nn.Module):
|
||||
def __init__(self, embedder : nn.Module) -> None:
|
||||
def __init__(self, embedder : EmbedderModule) -> None:
|
||||
super().__init__()
|
||||
self.embedder = embedder
|
||||
|
||||
|
||||
Reference in New Issue
Block a user