Follow the lightning naming and call this dataset, Improve config and types

This commit is contained in:
henryruhs
2025-02-21 09:07:58 +01:00
parent da51c5336d
commit 251e610f0e
5 changed files with 17 additions and 14 deletions
+1 -1
View File
@@ -27,7 +27,7 @@ This `config.ini` utilizes the MegaFace dataset to train the Embedding Converter
```
[training.dataset]
dataset_file_pattern = .datasets/images/{}/*.*g
file_pattern = .datasets/images/{}/*.*g
```
```
+1 -1
View File
@@ -1,5 +1,5 @@
[training.dataset]
dataset_file_pattern =
file_pattern =
[training.loader]
split_ratio =
@@ -2,14 +2,15 @@ import glob
import random
import cv2
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from .types import Batch
class DataLoaderRecognition(Dataset[torch.Tensor]):
class DynamicDataset(Dataset[Tensor]):
def __init__(self, dataset_file_pattern : str) -> None:
self.image_paths = glob.glob(dataset_file_pattern)
self.transforms = self.compose_transforms()
+9 -9
View File
@@ -1,6 +1,6 @@
import configparser
import os
from typing import Any, Tuple
from typing import Tuple
import lightning
import torch
@@ -11,9 +11,9 @@ from lightning.pytorch.tuner import Tuner
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset, random_split
from .data_loader import DataLoaderRecognition
from .dataset import DynamicDataset
from .models.embedding_converter import EmbeddingConverter
from .types import Batch, Embedding
from .types import Batch, Embedding, OptimizerConfig
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -53,7 +53,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
self.log('validation', validation, prog_bar = True)
return validation
def configure_optimizers(self) -> Any:
def configure_optimizers(self) -> OptimizerConfig:
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
optimizer = torch.optim.Adam(self.parameters(), lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
@@ -71,17 +71,17 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
}
def create_loaders(dataset : Dataset[Any]) -> Tuple[DataLoader[Any], DataLoader[Any]]:
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]:
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
training_size = int(loader_split_ratio * len(dataset)) # type:ignore[operator, arg-type]
validation_size = len(dataset) - training_size # type:ignore[arg-type]
@@ -115,10 +115,10 @@ def create_trainer() -> Trainer:
def train() -> None:
dataset_file_pattern = CONFIG.get('training.dataset', 'image_pattern')
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
dataset = DataLoaderRecognition(dataset_file_pattern)
dataset = DynamicDataset(dataset_file_pattern)
training_loader, validation_loader = create_loaders(dataset)
embedding_converter_trainer = EmbeddingConverterTrainer()
trainer = create_trainer()
+3 -1
View File
@@ -1,4 +1,4 @@
from typing import Any, TypeAlias
from typing import Any, Dict, TypeAlias
from numpy.typing import NDArray
from torch import Tensor
@@ -6,3 +6,5 @@ from torch import Tensor
Batch : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
VisionFrame : TypeAlias = NDArray[Any]
OptimizerConfig : TypeAlias = Dict[str, Any]