|
|
|
@@ -1,6 +1,6 @@
|
|
|
|
|
import configparser
|
|
|
|
|
import os
|
|
|
|
|
from typing import Any, Tuple
|
|
|
|
|
from typing import Tuple
|
|
|
|
|
|
|
|
|
|
import lightning
|
|
|
|
|
import torch
|
|
|
|
@@ -11,9 +11,9 @@ from lightning.pytorch.tuner import Tuner
|
|
|
|
|
from torch import Tensor, nn
|
|
|
|
|
from torch.utils.data import DataLoader, Dataset, random_split
|
|
|
|
|
|
|
|
|
|
from .data_loader import DataLoaderRecognition
|
|
|
|
|
from .dataset import DynamicDataset
|
|
|
|
|
from .models.embedding_converter import EmbeddingConverter
|
|
|
|
|
from .types import Batch, Embedding
|
|
|
|
|
from .types import Batch, Embedding, OptimizerConfig
|
|
|
|
|
|
|
|
|
|
CONFIG = configparser.ConfigParser()
|
|
|
|
|
CONFIG.read('config.ini')
|
|
|
|
@@ -53,7 +53,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
|
|
|
|
self.log('validation', validation, prog_bar = True)
|
|
|
|
|
return validation
|
|
|
|
|
|
|
|
|
|
def configure_optimizers(self) -> Any:
|
|
|
|
|
def configure_optimizers(self) -> OptimizerConfig:
|
|
|
|
|
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
|
|
|
|
|
optimizer = torch.optim.Adam(self.parameters(), lr = learning_rate)
|
|
|
|
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
|
|
|
@@ -71,17 +71,17 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_loaders(dataset : Dataset[Any]) -> Tuple[DataLoader[Any], DataLoader[Any]]:
|
|
|
|
|
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[DataLoader[Tensor], DataLoader[Tensor]]:
|
|
|
|
|
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
|
|
|
|
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
|
|
|
|
|
|
|
|
|
training_dataset, validate_dataset = split_dataset(dataset)
|
|
|
|
|
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
|
|
|
|
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True)
|
|
|
|
|
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
|
|
|
|
|
return training_loader, validation_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]:
|
|
|
|
|
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
|
|
|
|
|
loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
|
|
|
|
|
training_size = int(loader_split_ratio * len(dataset)) # type:ignore[operator, arg-type]
|
|
|
|
|
validation_size = len(dataset) - training_size # type:ignore[arg-type]
|
|
|
|
@@ -115,10 +115,10 @@ def create_trainer() -> Trainer:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train() -> None:
|
|
|
|
|
dataset_file_pattern = CONFIG.get('training.dataset', 'image_pattern')
|
|
|
|
|
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
|
|
|
|
|
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
|
|
|
|
|
|
|
|
|
|
dataset = DataLoaderRecognition(dataset_file_pattern)
|
|
|
|
|
dataset = DynamicDataset(dataset_file_pattern)
|
|
|
|
|
training_loader, validation_loader = create_loaders(dataset)
|
|
|
|
|
embedding_converter_trainer = EmbeddingConverterTrainer()
|
|
|
|
|
trainer = create_trainer()
|
|
|
|
|