diff --git a/embedding_converter/src/models/embedding_converter.py b/embedding_converter/src/models/embedding_converter.py index 5490aa3..95a3581 100644 --- a/embedding_converter/src/models/embedding_converter.py +++ b/embedding_converter/src/models/embedding_converter.py @@ -1,5 +1,5 @@ import torch -import torch.nn as nn +from torch import nn from ..types import VisionTensor diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 25e2217..13b554b 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -7,7 +7,7 @@ import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.tuner.tuning import Tuner -from torch import Tensor +from torch import Tensor, nn from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split from .models.embedding_converter import EmbeddingConverter @@ -21,7 +21,7 @@ class EmbeddingConverterTrainer(pytorch_lightning.LightningModule): def __init__(self) -> None: super(EmbeddingConverterTrainer, self).__init__() self.embedding_converter = EmbeddingConverter() - self.mse_loss = torch.nn.MSELoss() + self.mse_loss = nn.MSELoss() def forward(self, source_embedding : Embedding) -> Embedding: return self.embedding_converter(source_embedding) diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py index 5324509..6d7919b 100644 --- a/face_swapper/src/data_loader.py +++ b/face_swapper/src/data_loader.py @@ -12,9 +12,9 @@ from .types import Batch, ImagePathList, ImagePathSet class DataLoaderVGG(TensorDataset): - def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_folder_pattern : str, same_person_probability : float) -> None: + def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_directory_pattern : str, same_person_probability : float) -> None: self.same_person_probability = same_person_probability - self.directory_paths = glob.glob(dataset_folder_pattern.format(dataset_path)) + self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path)) self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern) self.dataset_total = len(self.image_paths) self.transforms = self.compose_transforms() diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 05ee770..00faa66 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -3,8 +3,9 @@ import platform import cv2 import numpy import torch +from torch import Tensor, nn -from .types import Embedder, Embedding, Padding, Tensor, VisionFrame, VisionTensor +from .types import Embedder, Embedding, Padding, VisionFrame, VisionTensor def is_windows() -> bool: @@ -35,25 +36,25 @@ def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame: return vision_frame -def hinge_real_loss(tensor : Tensor) -> Tensor: - real_loss = torch.relu(1 - tensor) +def hinge_real_loss(input_tensor : Tensor) -> Tensor: + real_loss = torch.relu(1 - input_tensor) real_loss = real_loss.mean(dim = [ 1, 2, 3 ]) return real_loss -def hinge_fake_loss(tensor : Tensor) -> Tensor: - fake_loss = torch.relu(tensor + 1) +def hinge_fake_loss(input_tensor : Tensor) -> Tensor: + fake_loss = torch.relu(input_tensor + 1) fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ]) return fake_loss def calc_id_embedding(id_embedder : Embedder, vision_tensor : VisionTensor, padding : Padding) -> Embedding: crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241] - crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area') + crop_vision_tensor = nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area') crop_vision_tensor[:, :, :padding[0], :] = 0 crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0 crop_vision_tensor[:, :, :, :padding[2]] = 0 crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 source_embedding = id_embedder(crop_vision_tensor) - source_embedding = torch.nn.functional.normalize(source_embedding, p = 2) + source_embedding = nn.functional.normalize(source_embedding, p = 2) return source_embedding diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py index 3755c30..60eb8e1 100644 --- a/face_swapper/src/models/discriminator.py +++ b/face_swapper/src/models/discriminator.py @@ -2,7 +2,7 @@ import configparser from typing import List import numpy -import torch.nn as nn +from torch import nn from face_swapper.src.types import VisionTensor diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 1b38dac..3fc26a4 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -1,7 +1,7 @@ import configparser from typing import Tuple -import torch.nn as nn +from torch import nn from face_swapper.src.networks.attribute_modulator import AADGenerator from face_swapper.src.networks.encoder import UNet diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index a209d49..55d074c 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -18,7 +18,7 @@ class FaceSwapperLoss: landmarker_path = CONFIG.get('training.model', 'landmarker_path') motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') self.batch_size = CONFIG.getint('training.loader', 'batch_size') - self.mse_loss = torch.nn.MSELoss() + self.mse_loss = nn.MSELoss() self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] @@ -127,7 +127,7 @@ class FaceSwapperLoss: def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203: vision_tensor_norm = (vision_tensor + 1) * 0.5 - vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear') + vision_tensor_norm = nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear') landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2) return landmarks diff --git a/face_swapper/src/networks/attribute_modulator.py b/face_swapper/src/networks/attribute_modulator.py index 69d71e1..1dae96c 100644 --- a/face_swapper/src/networks/attribute_modulator.py +++ b/face_swapper/src/networks/attribute_modulator.py @@ -1,5 +1,5 @@ import torch -from torch import Tensor, nn as nn +from torch import Tensor, nn from face_swapper.src.types import Embedding, TargetAttributes @@ -19,13 +19,13 @@ class AADGenerator(nn.Module): def forward(self, target_attributes : TargetAttributes, source_embedding : Embedding) -> Tensor: feature_map = self.upsample(source_embedding) - feature_map_1 = torch.nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_2 = torch.nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_3 = torch.nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_4 = torch.nn.functional.interpolate(self.res_block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_5 = torch.nn.functional.interpolate(self.res_block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_6 = torch.nn.functional.interpolate(self.res_block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_7 = torch.nn.functional.interpolate(self.res_block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_1 = nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_2 = nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_3 = nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_4 = nn.functional.interpolate(self.res_block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_5 = nn.functional.interpolate(self.res_block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_6 = nn.functional.interpolate(self.res_block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_7 = nn.functional.interpolate(self.res_block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) output = self.res_block_8(feature_map_7, target_attributes[7], source_embedding) return torch.tanh(output) @@ -41,7 +41,7 @@ class AADLayer(nn.Module): self.instance_norm = nn.InstanceNorm2d(input_channels) self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1) - def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : Embedding) -> Tensor: + def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor: feature_map = self.instance_norm(feature_map) gamma_attribute = self.conv_gamma(attribute_embedding) beta_attribute = self.conv_beta(attribute_embedding) @@ -59,7 +59,7 @@ class AADSequential(nn.Module): super(AADSequential, self).__init__() self.layers = nn.ModuleList(args) - def forward(self, feature_map: Tensor, attribute_embedding: Tensor, id_embedding: Embedding) -> Tensor: + def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor: for layer in self.layers: if isinstance(layer, AADLayer): feature_map = layer(feature_map, attribute_embedding, id_embedding) @@ -99,7 +99,7 @@ class AADResBlock(nn.Module): ) self.auxiliary_add_blocks = auxiliary_add_blocks - def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : Embedding) -> Tensor: + def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor: primary_feature = self.primary_add_blocks(feature_map, attribute_embedding, id_embedding) if self.input_channels > self.output_channels: @@ -115,7 +115,7 @@ class PixelShuffleUpsample(nn.Module): self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, padding = 1) self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2) - def forward(self, temp : Tensor) -> Tensor: - temp = self.conv(temp.view(temp.shape[0], -1, 1, 1)) - temp = self.pixel_shuffle(temp) - return temp + def forward(self, input_tensor : Tensor) -> Tensor: + temp_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1)) + temp_tensor = self.pixel_shuffle(temp_tensor) + return temp_tensor diff --git a/face_swapper/src/networks/encoder.py b/face_swapper/src/networks/encoder.py index e380ab5..d4e270e 100644 --- a/face_swapper/src/networks/encoder.py +++ b/face_swapper/src/networks/encoder.py @@ -1,5 +1,5 @@ import torch -from torch import Tensor, nn as nn +from torch import Tensor, nn from face_swapper.src.types import TargetAttributes, VisionTensor @@ -11,11 +11,11 @@ class Upsample(nn.Module): self.batch_norm = nn.BatchNorm2d(output_channels) self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) - def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor: - temp = self.deconv(temp) - temp = self.batch_norm(temp) - temp = self.leaky_relu(temp) - return torch.cat((temp, skip_tensor), dim = 1) + def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor: + temp_tensor = self.deconv(input_tensor) + temp_tensor = self.batch_norm(temp_tensor) + temp_tensor = self.leaky_relu(temp_tensor) + return torch.cat((temp_tensor, skip_tensor), dim = 1) class DownSample(nn.Module): @@ -25,11 +25,11 @@ class DownSample(nn.Module): self.batch_norm = nn.BatchNorm2d(output_channels) self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) - def forward(self, temp : Tensor) -> Tensor: - temp = self.conv(temp) - temp = self.batch_norm(temp) - temp = self.leaky_relu(temp) - return temp + def forward(self, input_tensor : Tensor) -> Tensor: + temp_tensor = self.conv(input_tensor) + temp_tensor = self.batch_norm(temp_tensor) + temp_tensor = self.leaky_relu(temp_tensor) + return temp_tensor class UNet(nn.Module): @@ -63,5 +63,5 @@ class UNet(nn.Module): upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3) upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2) upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1) - output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False) + output = nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False) return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index e37001d..b5551d4 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -1,32 +1,31 @@ from collections import OrderedDict -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, TypeAlias -import torch.nn from numpy.typing import NDArray from torch import Tensor -from torch.utils.data import DataLoader +from torch.nn import Module -Batch = Tuple[Any, Any, Any] -Loader = DataLoader[Tuple[Tensor, ...]] -ImagePathList = List[str] -ImagePathSet = Dict[str, ImagePathList] +ImagePathList : TypeAlias = List[str] +ImagePathSet : TypeAlias = Dict[str, ImagePathList] -SwapAttributes = Tuple[Tensor, ...] -TargetAttributes = Tuple[Tensor, ...] -DiscriminatorOutputs = List[List[Tensor]] +SwapAttributes : TypeAlias = Tuple[Tensor, ...] +TargetAttributes : TypeAlias = Tuple[Tensor, ...] +DiscriminatorOutputs : TypeAlias = List[List[Tensor]] -Embedding = Tensor -FaceLandmark203 = Tensor +Embedding : TypeAlias = Tensor +FaceLandmark203 : TypeAlias = Tensor -StateSet = OrderedDict[str, Any] -Padding = Tuple[int, int, int, int] +StateSet : TypeAlias = OrderedDict[str, Any] +Padding : TypeAlias = Tuple[int, int, int, int] -LossTensor = Tensor -VisionTensor = Tensor -VisionFrame = NDArray[Any] +VisionFrame : TypeAlias = NDArray[Any] +LossTensor : TypeAlias = Tensor +VisionTensor : TypeAlias = Tensor -GeneratorLossSet = Dict[str, Tensor] -DiscriminatorLossSet = Dict[str, Tensor] +Batch : TypeAlias = Tuple[VisionTensor, VisionTensor, Tensor] -Generator = torch.nn.Module -Embedder = torch.nn.Module +GeneratorLossSet : TypeAlias = Dict[str, Tensor] +DiscriminatorLossSet : TypeAlias = Dict[str, Tensor] + +Generator : TypeAlias = Module +Embedder : TypeAlias = Module