diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 60c3f6b..d6fa9b6 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -16,18 +16,18 @@ class Generator(nn.Module): id_channels = CONFIG.getint('training.model.generator', 'id_channels') num_blocks = CONFIG.getint('training.model.generator', 'num_blocks') - self.attribute_encoder = UNet() - self.attribute_generator = AADGenerator(id_channels, num_blocks) - self.attribute_encoder.apply(init_weight) - self.attribute_generator.apply(init_weight) + self.unet = UNet() + self.aad_generator = AADGenerator(id_channels, num_blocks) + self.unet.apply(init_weight) + self.aad_generator.apply(init_weight) def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: target_attributes = self.get_attributes(target_tensor) - output_tensor = self.attribute_generator(target_attributes, source_embedding) + output_tensor = self.aad_generator(target_attributes, source_embedding) return output_tensor def get_attributes(self, input_tensor : Tensor) -> Attributes: - return self.attribute_encoder(input_tensor) + return self.unet(input_tensor) def init_weight(module : nn.Module) -> None: diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index d7c14fd..388a987 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -17,7 +17,7 @@ from .helper import calc_id_embedding from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import FaceSwapperLoss -from .types import Batch, Embedding, TargetAttributes, VisionTensor +from .types import Batch, Embedding, VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -31,9 +31,9 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.discriminator = Discriminator() self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization') - def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]: - output = self.generator(source_embedding, target_tensor) - return output + def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor: + output_tensor = self.generator(source_embedding, target_tensor) + return output_tensor def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]: learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')