Modernize to use ModuleList, Fix some types

This commit is contained in:
henryruhs
2025-02-12 17:04:20 +01:00
parent 34d0bc10ed
commit 3c6dfa4efe
3 changed files with 37 additions and 27 deletions
@@ -1,21 +1,31 @@
import torch
import torch.nn as nn
from torch import Tensor
from embedding_converter.src.types import VisionTensor
class EmbeddingConverter(nn.Module):
    """
    MLP that maps a 512-dim embedding to another 512-dim embedding space.

    Architecture: 512 -> 1024 -> 2048 -> 1024 -> 512 with LeakyReLU between
    hidden layers and a plain linear output layer. The input is normalized
    by its global L2 norm before projection.
    """

    def __init__(self) -> None:
        super().__init__()
        # A ModuleList (instead of fc1..fc4 attributes) registers the layers
        # with the module and lets forward() iterate them generically.
        self.layers = self.create_layers()
        self.activation = nn.LeakyReLU()

    @staticmethod
    def create_layers() -> nn.ModuleList:
        """Build the linear stack 512 -> 1024 -> 2048 -> 1024 -> 512."""
        return nn.ModuleList(
            [
                nn.Linear(512, 1024),
                nn.Linear(1024, 2048),
                nn.Linear(2048, 1024),
                nn.Linear(1024, 512)
            ])

    def forward(self, input_tensor : Tensor) -> Tensor:
        # NOTE(review): torch.norm with no dim normalizes by the norm of the
        # WHOLE tensor, i.e. across the batch dimension too for batched
        # input — confirm per-sample normalization is not intended.
        output_tensor = input_tensor / torch.norm(input_tensor)
        # LeakyReLU after every layer except the last; the final projection
        # stays linear so the output embedding is unbounded.
        for layer in self.layers[:-1]:
            output_tensor = self.activation(layer(output_tensor))
        output_tensor = self.layers[-1](output_tensor)
        return output_tensor
+12 -12
View File
@@ -11,7 +11,7 @@ from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from .models.embedding_converter import EmbeddingConverter
from .types import Batch, Loader
from .types import Batch, Embedding
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -23,20 +23,20 @@ class EmbeddingConverterTrainer(pytorch_lightning.LightningModule):
self.embedding_converter = EmbeddingConverter()
self.mse_loss = torch.nn.MSELoss()
def forward(self, source_embedding : Tensor) -> Tensor:
    """
    Convert a source embedding to the target embedding space.

    :param source_embedding: batch of source embeddings as a torch Tensor
    :return: converted embeddings from the wrapped EmbeddingConverter
    """
    # NOTE(review): the value received here during training/validation is a
    # torch Tensor (a batch element), so Tensor — not the NDArray-based
    # Embedding alias — is the accurate annotation; confirm with callers.
    return self.embedding_converter(source_embedding)
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source, target = batch
output = self(source)
loss_training = self.mse_loss(output, target)
source_tensor, target = batch
output_tensor = self(source_tensor)
loss_training = self.mse_loss(output_tensor, target)
self.log('loss_training', loss_training, prog_bar = True)
return loss_training
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source, target = batch
output = self(source)
loss_validation = self.mse_loss(output, target)
source_tensor, target_tensor = batch
output_tensor = self(source_tensor)
loss_validation = self.mse_loss(output_tensor, target_tensor)
self.log('loss_validation', loss_validation, prog_bar = True)
return loss_validation
@@ -58,7 +58,7 @@ class EmbeddingConverterTrainer(pytorch_lightning.LightningModule):
}
def create_loaders() -> Tuple[Loader, Loader]:
def create_loaders() -> Tuple[DataLoader, DataLoader]:
loader_batch_size = CONFIG.getint('training.loader', 'batch_size')
loader_num_workers = CONFIG.getint('training.loader', 'num_workers')
@@ -73,9 +73,9 @@ def split_dataset() -> Tuple[Dataset[Any], Dataset[Any]]:
input_target_path = CONFIG.get('preparing.input', 'target_path')
loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
source_input = torch.from_numpy(numpy.load(input_source_path)).float()
target_input = torch.from_numpy(numpy.load(input_target_path)).float()
dataset = TensorDataset(source_input, target_input)
source_tensor = torch.from_numpy(numpy.load(input_source_path)).float()
target_tensor = torch.from_numpy(numpy.load(input_target_path)).float()
dataset = TensorDataset(source_tensor, target_tensor)
dataset_size = len(dataset)
training_size = int(loader_split_ratio * len(dataset))
+3 -3
View File
@@ -4,11 +4,11 @@ from numpy.typing import NDArray
from torch import Tensor
from torch.utils.data import DataLoader
# Shared type aliases for the embedding_converter package.
# Aliases are ordered so each name is bound before it is referenced,
# since these assignments execute at import time.

# A single embedding vector as a numpy array.
Embedding = NDArray[Any]
# NOTE(review): NDArray expects a scalar dtype parameter; NDArray[Embedding]
# works at runtime but will not satisfy static type checkers such as mypy —
# confirm whether NDArray[Any] (array of arrays via dtype=object) is meant.
EmbeddingDataset = NDArray[Embedding]
# Five (x, y) facial landmark points.
FaceLandmark5 = NDArray[Any]
# A single image frame as a numpy array.
VisionFrame = NDArray[Any]
# Torch tensor carrying vision/embedding data.
VisionTensor = Tensor
# A (source, target) pair as yielded by the DataLoader.
Batch = Tuple[VisionTensor, VisionTensor]
# Kept for backward compatibility with existing imports of Loader.
Loader = DataLoader[Tuple[Tensor, ...]]