diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py index b766979..c018e96 100644 --- a/face_swapper/src/networks/aad.py +++ b/face_swapper/src/networks/aad.py @@ -51,7 +51,7 @@ class AADResBlock(nn.Module): intermediate_channels = input_channels if index < (num_blocks - 1) else output_channels primary_add_blocks.extend( [ - AADLayer(input_channels, attribute_channels, identity_channels), + FeatureModulation(input_channels, attribute_channels, identity_channels), nn.ReLU(inplace = True), nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, padding = 1, bias = False) ] @@ -61,7 +61,7 @@ class AADResBlock(nn.Module): def prepare_auxiliary_add_blocks(self, input_channels : int, attribute_channels : int, identity_channels : int, output_channels : int) -> None: if input_channels > output_channels: auxiliary_add_blocks = AADSequential( - AADLayer(input_channels, attribute_channels, identity_channels), + FeatureModulation(input_channels, attribute_channels, identity_channels), nn.ReLU(inplace = True), nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False) ) @@ -84,14 +84,14 @@ class AADSequential(nn.Module): def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: for layer in self.layers: - if isinstance(layer, AADLayer): + if isinstance(layer, FeatureModulation): feature_map = layer(feature_map, attribute_embedding, identity_embedding) else: feature_map = layer(feature_map) return feature_map -class AADLayer(nn.Module): +class FeatureModulation(nn.Module): def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None: super().__init__() self.input_channels = input_channels