Remove Attributes

This commit is contained in:
henryruhs
2025-03-11 19:16:17 +01:00
parent 31303c1c6c
commit 70ac772a34
4 changed files with 8 additions and 7 deletions
+3 -2
View File
@@ -1,10 +1,11 @@
from configparser import ConfigParser
from typing import Tuple
from torch import Tensor, nn
from ..networks.aad import AAD
from ..networks.unet import UNet
from ..types import Attributes, Embedding
from ..types import Embedding, Attribute
class Generator(nn.Module):
@@ -20,7 +21,7 @@ class Generator(nn.Module):
output_tensor = self.generator(source_embedding, target_attributes)
return output_tensor
def get_attributes(self, input_tensor : Tensor) -> Attributes:
def get_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]:
return self.encoder(input_tensor)
+2 -2
View File
@@ -7,7 +7,7 @@ from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..types import Attributes, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule
from ..types import Attribute, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule
class DiscriminatorLoss(nn.Module):
@@ -55,7 +55,7 @@ class AttributeLoss(nn.Module):
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight')
def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Tensor, Tensor]:
temp_tensors = []
for target_attribute, output_attribute in zip(target_attributes, output_attributes):
+3 -2
View File
@@ -1,9 +1,10 @@
from configparser import ConfigParser
from typing import Tuple
import torch
from torch import Tensor, nn
from ..types import Attributes, Embedding
from ..types import Attribute, Embedding
class AAD(nn.Module):
@@ -56,7 +57,7 @@ class AAD(nn.Module):
return layers
def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor:
def forward(self, source_embedding : Embedding, target_attributes : Tuple[Attribute, ...]) -> Tensor:
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
for index, layer in enumerate(self.layers[:-1]):
-1
View File
@@ -7,7 +7,6 @@ Batch : TypeAlias = Tuple[Tensor, Tensor]
BatchMode = Literal['equal', 'same']
Attribute : TypeAlias = Tensor
Attributes : TypeAlias = Tuple[Attribute, ...]
Embedding : TypeAlias = Tensor
Gaze : TypeAlias = Tuple[Tensor, Tensor]