Use lightning over pytorch_lightning import, Configure tensorboard

This commit is contained in:
henryruhs
2025-02-16 11:56:01 +01:00
parent cd4b10c832
commit d9e10a9f7c
5 changed files with 20 additions and 12 deletions
+1
View File
@@ -4,6 +4,7 @@ __pycache__
.idea
.inputs
.exports
.logs
.models
.outputs
.vscode
+7 -5
View File
@@ -1,12 +1,12 @@
import configparser
from typing import Any, Tuple
import lightning
import numpy
import pytorch_lightning
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.tuner.tuning import Tuner
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
@@ -17,7 +17,7 @@ CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class EmbeddingConverterTrainer(pytorch_lightning.LightningModule):
class EmbeddingConverterTrainer(lightning.LightningModule):
def __init__(self) -> None:
super(EmbeddingConverterTrainer, self).__init__()
self.embedding_converter = EmbeddingConverter()
@@ -88,8 +88,10 @@ def create_trainer() -> Trainer:
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
output_directory_path = CONFIG.get('training.output', 'directory_path')
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
logger = TensorBoardLogger('.logs', name = 'embedding_converter')
return Trainer(
logger = logger,
max_epochs = trainer_max_epochs,
callbacks =
[
+2 -2
View File
@@ -35,8 +35,8 @@ same_person_probability = 0.2
```
[training.loader]
batch_size = 24
num_workers = 12
batch_size = 8
num_workers = 8
```
```
+9 -5
View File
@@ -2,13 +2,15 @@ import configparser
import os
from typing import Tuple
import pytorch_lightning
import lightning
import torch
import torchvision
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.types import Optimizer
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from .data_loader import DataLoaderVGG
@@ -22,7 +24,7 @@ CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss):
def __init__(self) -> None:
super().__init__()
FaceSwapperLoss.__init__(self)
@@ -86,9 +88,11 @@ def create_trainer() -> Trainer:
output_directory_path = CONFIG.get('training.output', 'directory_path')
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
trainer_precision = CONFIG.get('training.trainer', 'precision')
logger = TensorBoardLogger('.logs', name = 'face_swapper')
os.makedirs(output_directory_path, exist_ok = True)
return Trainer(
logger = logger,
max_epochs = trainer_max_epochs,
precision = trainer_precision,
callbacks =
+1
View File
@@ -8,3 +8,4 @@ mxnet==1.9.1
pytorch-msssim==1.0.0
torch==2.6.0
torchvision==0.21.0
tensorboard==2.19.0