From d9e10a9f7c0a0b7e56814565bf8f5ef6b762c443 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sun, 16 Feb 2025 11:56:01 +0100 Subject: [PATCH] Use lightning over pytorch_lightning import, Configure tensorboard --- .gitignore | 1 + embedding_converter/src/training.py | 12 +++++++----- face_swapper/README.md | 4 ++-- face_swapper/src/training.py | 14 +++++++++----- requirements.txt | 1 + 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 0fb04d8..454001b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__ .idea .inputs .exports +.logs .models .outputs .vscode diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 13b554b..5a198ba 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -1,12 +1,12 @@ import configparser from typing import Any, Tuple +import lightning import numpy -import pytorch_lightning import torch -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.tuner.tuning import Tuner +from lightning import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger from torch import Tensor, nn from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split @@ -17,7 +17,7 @@ CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -class EmbeddingConverterTrainer(pytorch_lightning.LightningModule): +class EmbeddingConverterTrainer(lightning.LightningModule): def __init__(self) -> None: super(EmbeddingConverterTrainer, self).__init__() self.embedding_converter = EmbeddingConverter() @@ -88,8 +88,10 @@ def create_trainer() -> Trainer: trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs') output_directory_path = CONFIG.get('training.output', 'directory_path') output_file_pattern = CONFIG.get('training.output', 'file_pattern') + logger = TensorBoardLogger('.logs', name = 'embedding_converter') return Trainer( + logger = logger, max_epochs = trainer_max_epochs, callbacks = [ diff --git a/face_swapper/README.md b/face_swapper/README.md index 23f6c30..fc11333 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -35,8 +35,8 @@ same_person_probability = 0.2 ``` [training.loader] -batch_size = 24 -num_workers = 12 +batch_size = 8 +num_workers = 8 ``` ``` diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 629048b..2911f72 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -2,13 +2,15 @@ import configparser import os from typing import Tuple -import pytorch_lightning +import lightning import torch import torchvision -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.utilities.types import Optimizer +from lightning import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger + from torch import Tensor +from torch.optim import Optimizer from torch.utils.data import DataLoader from .data_loader import DataLoaderVGG @@ -22,7 +24,7 @@ CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss): +class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss): def __init__(self) -> None: super().__init__() FaceSwapperLoss.__init__(self) @@ -86,9 +88,11 @@ def create_trainer() -> Trainer: output_directory_path = CONFIG.get('training.output', 'directory_path') output_file_pattern = CONFIG.get('training.output', 'file_pattern') trainer_precision = CONFIG.get('training.trainer', 'precision') + logger = TensorBoardLogger('.logs', name = 'face_swapper') os.makedirs(output_directory_path, exist_ok = True) return Trainer( + logger = logger, max_epochs = trainer_max_epochs, precision = trainer_precision, callbacks = diff --git a/requirements.txt b/requirements.txt index 722121a..3f6b720 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ mxnet==1.9.1 pytorch-msssim==1.0.0 torch==2.6.0 torchvision==0.21.0 +tensorboard==2.19.0