diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index d02be0b..3ff125f 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -1,10 +1,11 @@ from configparser import ConfigParser +from typing import Tuple from torch import Tensor, nn from ..networks.aad import AAD from ..networks.unet import UNet -from ..types import Attributes, Embedding +from ..types import Embedding, Attribute class Generator(nn.Module): @@ -20,7 +21,7 @@ class Generator(nn.Module): output_tensor = self.generator(source_embedding, target_attributes) return output_tensor - def get_attributes(self, input_tensor : Tensor) -> Attributes: + def get_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]: return self.encoder(input_tensor) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 5c6a33a..4584535 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from torchvision import transforms from ..helper import calc_embedding -from ..types import Attributes, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule +from ..types import Attribute, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule class DiscriminatorLoss(nn.Module): @@ -55,7 +55,7 @@ class AttributeLoss(nn.Module): self.config_batch_size = config_parser.getint('training.loader', 'batch_size') self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight') - def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]: + def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Tensor, Tensor]: temp_tensors = [] for target_attribute, output_attribute in zip(target_attributes, output_attributes): diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py index aeb1969..671ff14 100644 --- a/face_swapper/src/networks/aad.py +++ b/face_swapper/src/networks/aad.py @@ -1,9 +1,10 @@ from configparser import ConfigParser +from typing import Tuple import torch from torch import Tensor, nn -from ..types import Attributes, Embedding +from ..types import Attribute, Embedding class AAD(nn.Module): @@ -56,7 +57,7 @@ class AAD(nn.Module): return layers - def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor: + def forward(self, source_embedding : Embedding, target_attributes : Tuple[Attribute, ...]) -> Tensor: temp_tensors = self.pixel_shuffle_up_sample(source_embedding) for index, layer in enumerate(self.layers[:-1]): diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index b3304a3..2b39c79 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -7,7 +7,6 @@ Batch : TypeAlias = Tuple[Tensor, Tensor] BatchMode = Literal['equal', 'same'] Attribute : TypeAlias = Tensor -Attributes : TypeAlias = Tuple[Attribute, ...] Embedding : TypeAlias = Tensor Gaze : TypeAlias = Tuple[Tensor, Tensor]