Improve a lot of types, imports and names

This commit is contained in:
henryruhs
2025-02-12 21:25:43 +01:00
parent e33bc0d52a
commit b6b4f9f65b
10 changed files with 64 additions and 64 deletions
@@ -1,5 +1,5 @@
import torch
import torch.nn as nn
from torch import nn
from ..types import VisionTensor
+2 -2
View File
@@ -7,7 +7,7 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.tuner.tuning import Tuner
from torch import Tensor
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from .models.embedding_converter import EmbeddingConverter
@@ -21,7 +21,7 @@ class EmbeddingConverterTrainer(pytorch_lightning.LightningModule):
def __init__(self) -> None:
super(EmbeddingConverterTrainer, self).__init__()
self.embedding_converter = EmbeddingConverter()
self.mse_loss = torch.nn.MSELoss()
self.mse_loss = nn.MSELoss()
def forward(self, source_embedding : Embedding) -> Embedding:
return self.embedding_converter(source_embedding)
+2 -2
View File
@@ -12,9 +12,9 @@ from .types import Batch, ImagePathList, ImagePathSet
class DataLoaderVGG(TensorDataset):
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_folder_pattern : str, same_person_probability : float) -> None:
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_directory_pattern : str, same_person_probability : float) -> None:
self.same_person_probability = same_person_probability
self.directory_paths = glob.glob(dataset_folder_pattern.format(dataset_path))
self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path))
self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern)
self.dataset_total = len(self.image_paths)
self.transforms = self.compose_transforms()
+8 -7
View File
@@ -3,8 +3,9 @@ import platform
import cv2
import numpy
import torch
from torch import Tensor, nn
from .types import Embedder, Embedding, Padding, Tensor, VisionFrame, VisionTensor
from .types import Embedder, Embedding, Padding, VisionFrame, VisionTensor
def is_windows() -> bool:
@@ -35,25 +36,25 @@ def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame:
return vision_frame
def hinge_real_loss(tensor : Tensor) -> Tensor:
real_loss = torch.relu(1 - tensor)
def hinge_real_loss(input_tensor : Tensor) -> Tensor:
real_loss = torch.relu(1 - input_tensor)
real_loss = real_loss.mean(dim = [ 1, 2, 3 ])
return real_loss
def hinge_fake_loss(tensor : Tensor) -> Tensor:
fake_loss = torch.relu(tensor + 1)
def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
fake_loss = torch.relu(input_tensor + 1)
fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ])
return fake_loss
def calc_id_embedding(id_embedder : Embedder, vision_tensor : VisionTensor, padding : Padding) -> Embedding:
crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241]
crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
crop_vision_tensor = nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
crop_vision_tensor[:, :, :padding[0], :] = 0
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
crop_vision_tensor[:, :, :, :padding[2]] = 0
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
source_embedding = id_embedder(crop_vision_tensor)
source_embedding = torch.nn.functional.normalize(source_embedding, p = 2)
source_embedding = nn.functional.normalize(source_embedding, p = 2)
return source_embedding
+1 -1
View File
@@ -2,7 +2,7 @@ import configparser
from typing import List
import numpy
import torch.nn as nn
from torch import nn
from face_swapper.src.types import VisionTensor
+1 -1
View File
@@ -1,7 +1,7 @@
import configparser
from typing import Tuple
import torch.nn as nn
from torch import nn
from face_swapper.src.networks.attribute_modulator import AADGenerator
from face_swapper.src.networks.encoder import UNet
+2 -2
View File
@@ -18,7 +18,7 @@ class FaceSwapperLoss:
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
self.mse_loss = torch.nn.MSELoss()
self.mse_loss = nn.MSELoss()
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
@@ -127,7 +127,7 @@ class FaceSwapperLoss:
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
vision_tensor_norm = (vision_tensor + 1) * 0.5
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
vision_tensor_norm = nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
return landmarks
@@ -1,5 +1,5 @@
import torch
from torch import Tensor, nn as nn
from torch import Tensor, nn
from face_swapper.src.types import Embedding, TargetAttributes
@@ -19,13 +19,13 @@ class AADGenerator(nn.Module):
def forward(self, target_attributes : TargetAttributes, source_embedding : Embedding) -> Tensor:
feature_map = self.upsample(source_embedding)
feature_map_1 = torch.nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_2 = torch.nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_3 = torch.nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_4 = torch.nn.functional.interpolate(self.res_block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_5 = torch.nn.functional.interpolate(self.res_block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_6 = torch.nn.functional.interpolate(self.res_block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_7 = torch.nn.functional.interpolate(self.res_block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_1 = nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_2 = nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_3 = nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_4 = nn.functional.interpolate(self.res_block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_5 = nn.functional.interpolate(self.res_block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_6 = nn.functional.interpolate(self.res_block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_7 = nn.functional.interpolate(self.res_block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
output = self.res_block_8(feature_map_7, target_attributes[7], source_embedding)
return torch.tanh(output)
@@ -41,7 +41,7 @@ class AADLayer(nn.Module):
self.instance_norm = nn.InstanceNorm2d(input_channels)
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1)
def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : Embedding) -> Tensor:
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
feature_map = self.instance_norm(feature_map)
gamma_attribute = self.conv_gamma(attribute_embedding)
beta_attribute = self.conv_beta(attribute_embedding)
@@ -59,7 +59,7 @@ class AADSequential(nn.Module):
super(AADSequential, self).__init__()
self.layers = nn.ModuleList(args)
def forward(self, feature_map: Tensor, attribute_embedding: Tensor, id_embedding: Embedding) -> Tensor:
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
for layer in self.layers:
if isinstance(layer, AADLayer):
feature_map = layer(feature_map, attribute_embedding, id_embedding)
@@ -99,7 +99,7 @@ class AADResBlock(nn.Module):
)
self.auxiliary_add_blocks = auxiliary_add_blocks
def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : Embedding) -> Tensor:
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
primary_feature = self.primary_add_blocks(feature_map, attribute_embedding, id_embedding)
if self.input_channels > self.output_channels:
@@ -115,7 +115,7 @@ class PixelShuffleUpsample(nn.Module):
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, padding = 1)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
def forward(self, temp : Tensor) -> Tensor:
temp = self.conv(temp.view(temp.shape[0], -1, 1, 1))
temp = self.pixel_shuffle(temp)
return temp
def forward(self, input_tensor : Tensor) -> Tensor:
temp_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1))
temp_tensor = self.pixel_shuffle(temp_tensor)
return temp_tensor
+12 -12
View File
@@ -1,5 +1,5 @@
import torch
from torch import Tensor, nn as nn
from torch import Tensor, nn
from face_swapper.src.types import TargetAttributes, VisionTensor
@@ -11,11 +11,11 @@ class Upsample(nn.Module):
self.batch_norm = nn.BatchNorm2d(output_channels)
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
temp = self.deconv(temp)
temp = self.batch_norm(temp)
temp = self.leaky_relu(temp)
return torch.cat((temp, skip_tensor), dim = 1)
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
temp_tensor = self.deconv(input_tensor)
temp_tensor = self.batch_norm(temp_tensor)
temp_tensor = self.leaky_relu(temp_tensor)
return torch.cat((temp_tensor, skip_tensor), dim = 1)
class DownSample(nn.Module):
@@ -25,11 +25,11 @@ class DownSample(nn.Module):
self.batch_norm = nn.BatchNorm2d(output_channels)
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
def forward(self, temp : Tensor) -> Tensor:
temp = self.conv(temp)
temp = self.batch_norm(temp)
temp = self.leaky_relu(temp)
return temp
def forward(self, input_tensor : Tensor) -> Tensor:
temp_tensor = self.conv(input_tensor)
temp_tensor = self.batch_norm(temp_tensor)
temp_tensor = self.leaky_relu(temp_tensor)
return temp_tensor
class UNet(nn.Module):
@@ -63,5 +63,5 @@ class UNet(nn.Module):
upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
output = nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
+20 -21
View File
@@ -1,32 +1,31 @@
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, TypeAlias
import torch.nn
from numpy.typing import NDArray
from torch import Tensor
from torch.utils.data import DataLoader
from torch.nn import Module
Batch = Tuple[Any, Any, Any]
Loader = DataLoader[Tuple[Tensor, ...]]
ImagePathList = List[str]
ImagePathSet = Dict[str, ImagePathList]
ImagePathList : TypeAlias = List[str]
ImagePathSet : TypeAlias = Dict[str, ImagePathList]
SwapAttributes = Tuple[Tensor, ...]
TargetAttributes = Tuple[Tensor, ...]
DiscriminatorOutputs = List[List[Tensor]]
SwapAttributes : TypeAlias = Tuple[Tensor, ...]
TargetAttributes : TypeAlias = Tuple[Tensor, ...]
DiscriminatorOutputs : TypeAlias = List[List[Tensor]]
Embedding = Tensor
FaceLandmark203 = Tensor
Embedding : TypeAlias = Tensor
FaceLandmark203 : TypeAlias = Tensor
StateSet = OrderedDict[str, Any]
Padding = Tuple[int, int, int, int]
StateSet : TypeAlias = OrderedDict[str, Any]
Padding : TypeAlias = Tuple[int, int, int, int]
LossTensor = Tensor
VisionTensor = Tensor
VisionFrame = NDArray[Any]
VisionFrame : TypeAlias = NDArray[Any]
LossTensor : TypeAlias = Tensor
VisionTensor : TypeAlias = Tensor
GeneratorLossSet = Dict[str, Tensor]
DiscriminatorLossSet = Dict[str, Tensor]
Batch : TypeAlias = Tuple[VisionTensor, VisionTensor, Tensor]
Generator = torch.nn.Module
Embedder = torch.nn.Module
GeneratorLossSet : TypeAlias = Dict[str, Tensor]
DiscriminatorLossSet : TypeAlias = Dict[str, Tensor]
Generator : TypeAlias = Module
Embedder : TypeAlias = Module