This should be a StaticDataset, Fix learning rate finder

This commit is contained in:
henryruhs
2025-02-24 10:25:19 +01:00
parent 257e5e56a4
commit f5cd6b6336
2 changed files with 6 additions and 4 deletions
+2 -2
View File
@@ -9,13 +9,13 @@ from torchvision import transforms
from .types import Batch
class DynamicDataset(Dataset[Tensor]):
class StaticDataset(Dataset[Tensor]):
def __init__(self, file_pattern : str) -> None:
self.file_paths = glob.glob(file_pattern)
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
file_path = random.choice(self.file_paths)
file_path = self.file_paths[index]
vision_frame = cv2.imread(file_path)
return self.transforms(vision_frame)
+4 -2
View File
@@ -11,7 +11,7 @@ from lightning.pytorch.tuner import Tuner
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset, random_split
from .dataset import DynamicDataset
from .dataset import StaticDataset
from .models.embedding_converter import EmbeddingConverter
from .types import Batch, Embedding, OptimizerConfig
@@ -24,11 +24,13 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
super(EmbeddingConverterTrainer, self).__init__()
source_path = CONFIG.get('training.model', 'source_path')
target_path = CONFIG.get('training.model', 'target_path')
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
self.embedding_converter = EmbeddingConverter()
self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.mse_loss = nn.MSELoss()
self.lr = learning_rate
def forward(self, source_embedding : Embedding) -> Embedding:
return self.embedding_converter(source_embedding)
@@ -115,7 +117,7 @@ def train() -> None:
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
output_resume_path = CONFIG.get('training.output', 'resume_path')
dataset = DynamicDataset(dataset_file_pattern)
dataset = StaticDataset(dataset_file_pattern)
training_loader, validation_loader = create_loaders(dataset)
embedding_converter_trainer = EmbeddingConverterTrainer()
trainer = create_trainer()