mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Use lightning over pytorch_lightning import, Configure tensorboard
This commit is contained in:
@@ -4,6 +4,7 @@ __pycache__
|
||||
.idea
|
||||
.inputs
|
||||
.exports
|
||||
.logs
|
||||
.models
|
||||
.outputs
|
||||
.vscode
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import configparser
|
||||
from typing import Any, Tuple
|
||||
|
||||
import lightning
|
||||
import numpy
|
||||
import pytorch_lightning
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.tuner.tuning import Tuner
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
|
||||
|
||||
@@ -17,7 +17,7 @@ CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class EmbeddingConverterTrainer(pytorch_lightning.LightningModule):
|
||||
class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
def __init__(self) -> None:
|
||||
super(EmbeddingConverterTrainer, self).__init__()
|
||||
self.embedding_converter = EmbeddingConverter()
|
||||
@@ -88,8 +88,10 @@ def create_trainer() -> Trainer:
|
||||
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
|
||||
output_directory_path = CONFIG.get('training.output', 'directory_path')
|
||||
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
|
||||
logger = TensorBoardLogger('.logs', name = 'embedding_converter')
|
||||
|
||||
return Trainer(
|
||||
logger = logger,
|
||||
max_epochs = trainer_max_epochs,
|
||||
callbacks =
|
||||
[
|
||||
|
||||
@@ -35,8 +35,8 @@ same_person_probability = 0.2
|
||||
|
||||
```
|
||||
[training.loader]
|
||||
batch_size = 24
|
||||
num_workers = 12
|
||||
batch_size = 8
|
||||
num_workers = 8
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -2,13 +2,15 @@ import configparser
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import pytorch_lightning
|
||||
import lightning
|
||||
import torch
|
||||
import torchvision
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.utilities.types import Optimizer
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .data_loader import DataLoaderVGG
|
||||
@@ -22,7 +24,7 @@ CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
|
||||
class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
FaceSwapperLoss.__init__(self)
|
||||
@@ -86,9 +88,11 @@ def create_trainer() -> Trainer:
|
||||
output_directory_path = CONFIG.get('training.output', 'directory_path')
|
||||
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
|
||||
trainer_precision = CONFIG.get('training.trainer', 'precision')
|
||||
logger = TensorBoardLogger('.logs', name = 'face_swapper')
|
||||
|
||||
os.makedirs(output_directory_path, exist_ok = True)
|
||||
return Trainer(
|
||||
logger = logger,
|
||||
max_epochs = trainer_max_epochs,
|
||||
precision = trainer_precision,
|
||||
callbacks =
|
||||
|
||||
@@ -8,3 +8,4 @@ mxnet==1.9.1
|
||||
pytorch-msssim==1.0.0
|
||||
torch==2.6.0
|
||||
torchvision==0.21.0
|
||||
tensorboard==2.19.0
|
||||
|
||||
Reference in New Issue
Block a user