From 31303c1c6c5c7cc3356e06af23fb0459299f51be Mon Sep 17 00:00:00 2001 From: henryruhs Date: Tue, 11 Mar 2025 19:12:26 +0100 Subject: [PATCH 1/3] Add Attributes, polish names --- face_swapper/src/networks/masknet.py | 4 +++- face_swapper/src/training.py | 8 ++++---- face_swapper/src/types.py | 3 ++- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py index 0a207aa..b80655e 100644 --- a/face_swapper/src/networks/masknet.py +++ b/face_swapper/src/networks/masknet.py @@ -3,6 +3,8 @@ from configparser import ConfigParser import torch from torch import Tensor, nn +from face_swapper.src.types import Attribute + class MaskNet(nn.Module): def __init__(self, config_parser : ConfigParser) -> None: @@ -32,7 +34,7 @@ class MaskNet(nn.Module): UpSample(num_filters, num_filters) ]) - def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor: + def forward(self, target_tensor : Tensor, target_attribute : Attribute) -> Tensor: output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1) for down_sample in self.down_samples: diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 1c8f1ee..0baea47 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -54,11 +54,10 @@ class FaceSwapperTrainer(LightningModule): self.automatic_optimization = False def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tensor]: - with torch.no_grad(): output_tensor = self.generator(source_embedding, target_tensor) - target_attributes = self.generator.get_attributes(target_tensor) - mask_tensor = self.masker(target_tensor, target_attributes[-1]) + target_attribute = self.generator.get_attributes(target_tensor)[-1] + mask_tensor = self.masker(target_tensor, target_attribute) return output_tensor, mask_tensor @@ -127,7 +126,8 @@ class FaceSwapperTrainer(LightningModule): self.untoggle_optimizer(generator_optimizer) self.toggle_optimizer(masker_optimizer) - mask_tensor = self.masker(target_tensor, target_attributes[-1].detach()) + target_attribute = target_attributes[-1].detach() + mask_tensor = self.masker(target_tensor, target_attribute) mask_loss = self.mask_loss(target_tensor, mask_tensor) self.manual_backward(mask_loss) diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index 27708b7..b3304a3 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -6,7 +6,8 @@ from torch.nn import Module Batch : TypeAlias = Tuple[Tensor, Tensor] BatchMode = Literal['equal', 'same'] -Attributes : TypeAlias = Tuple[Tensor, ...] +Attribute : TypeAlias = Tensor +Attributes : TypeAlias = Tuple[Attribute, ...] Embedding : TypeAlias = Tensor Gaze : TypeAlias = Tuple[Tensor, Tensor] From 70ac772a34fba663fb60ef7313606d8ce5a9785b Mon Sep 17 00:00:00 2001 From: henryruhs Date: Tue, 11 Mar 2025 19:16:17 +0100 Subject: [PATCH 2/3] Remove Attributes --- face_swapper/src/models/generator.py | 5 +++-- face_swapper/src/models/loss.py | 4 ++-- face_swapper/src/networks/aad.py | 5 +++-- face_swapper/src/types.py | 1 - 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index d02be0b..3ff125f 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -1,10 +1,11 @@ from configparser import ConfigParser +from typing import Tuple from torch import Tensor, nn from ..networks.aad import AAD from ..networks.unet import UNet -from ..types import Attributes, Embedding +from ..types import Embedding, Attribute class Generator(nn.Module): @@ -20,7 +21,7 @@ class Generator(nn.Module): output_tensor = self.generator(source_embedding, target_attributes) return output_tensor - def get_attributes(self, input_tensor : Tensor) -> Attributes: + def get_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]: return self.encoder(input_tensor) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 5c6a33a..4584535 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from torchvision import transforms from ..helper import calc_embedding -from ..types import Attributes, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule +from ..types import Attribute, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule class DiscriminatorLoss(nn.Module): @@ -55,7 +55,7 @@ class AttributeLoss(nn.Module): self.config_batch_size = config_parser.getint('training.loader', 'batch_size') self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight') - def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]: + def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Tensor, Tensor]: temp_tensors = [] for target_attribute, output_attribute in zip(target_attributes, output_attributes): diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py index aeb1969..671ff14 100644 --- a/face_swapper/src/networks/aad.py +++ b/face_swapper/src/networks/aad.py @@ -1,9 +1,10 @@ from configparser import ConfigParser +from typing import Tuple import torch from torch import Tensor, nn -from ..types import Attributes, Embedding +from ..types import Attribute, Embedding class AAD(nn.Module): @@ -56,7 +57,7 @@ class AAD(nn.Module): return layers - def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor: + def forward(self, source_embedding : Embedding, target_attributes : Tuple[Attribute, ...]) -> Tensor: temp_tensors = self.pixel_shuffle_up_sample(source_embedding) for index, layer in enumerate(self.layers[:-1]): diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index b3304a3..2b39c79 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -7,7 +7,6 @@ Batch : TypeAlias = Tuple[Tensor, Tensor] BatchMode = Literal['equal', 'same'] Attribute : TypeAlias = Tensor -Attributes : TypeAlias = Tuple[Attribute, ...] Embedding : TypeAlias = Tensor Gaze : TypeAlias = Tuple[Tensor, Tensor] From 0991745753e85a0852ba982cce33e82d0090aeee Mon Sep 17 00:00:00 2001 From: henryruhs Date: Tue, 11 Mar 2025 19:17:26 +0100 Subject: [PATCH 3/3] Remove Attributes --- face_swapper/src/models/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 3ff125f..8953354 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -5,7 +5,7 @@ from torch import Tensor, nn from ..networks.aad import AAD from ..networks.unet import UNet -from ..types import Embedding, Attribute +from ..types import Attribute, Embedding class Generator(nn.Module):