diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 6f30ba3..ef3a889 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -49,7 +49,7 @@ class AdversarialLoss(nn.Module): return adversarial_loss, weighted_adversarial_loss -class FeautureLoss(nn.Module): +class FeatureLoss(nn.Module): def __init__(self, config_parser : ConfigParser) -> None: super().__init__() self.config_batch_size = config_parser.getint('training.loader', 'batch_size') diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 84a16d4..344fd4d 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,7 +16,7 @@ from .dataset import DynamicDataset from .helper import calc_embedding, overlay_mask from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import AdversarialLoss, DiscriminatorLoss, FeautureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss +from .models.loss import AdversarialLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss from .types import Batch, Embedding, Mask, OptimizerSet warnings.filterwarnings('ignore', category = UserWarning, module = 'torch') @@ -45,7 +45,7 @@ class FaceSwapperTrainer(LightningModule): self.discriminator = Discriminator(config_parser) self.discriminator_loss = DiscriminatorLoss() self.adversarial_loss = AdversarialLoss(config_parser) - self.feature_loss = FeautureLoss(config_parser) + self.feature_loss = FeatureLoss(config_parser) self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder) self.identity_loss = IdentityLoss(config_parser, self.loss_embedder) self.motion_loss = MotionLoss(config_parser, self.motion_extractor)