mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Always config injection
This commit is contained in:
@@ -2,9 +2,8 @@ import configparser
|
||||
import os
|
||||
from typing import Tuple
|
||||
|
||||
import lightning
|
||||
import torch
|
||||
from lightning import Trainer
|
||||
from lightning import Trainer, LightningModule
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
@@ -13,22 +12,16 @@ from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import StaticDataset
|
||||
from .models.embedding_converter import EmbeddingConverter
|
||||
from .types import Batch, Embedding, OptimizerSet
|
||||
from .types import Batch, Embedding, OptimizerSet, Config
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
def __init__(self) -> None:
|
||||
super(EmbeddingConverterTrainer, self).__init__()
|
||||
self.config =\
|
||||
{
|
||||
'source_path': CONFIG.get('training.model', 'source_path'),
|
||||
'target_path': CONFIG.get('training.model', 'target_path'),
|
||||
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
}
|
||||
|
||||
class EmbeddingConverterTrainer(LightningModule):
|
||||
def __init__(self, config : Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embedding_converter = EmbeddingConverter()
|
||||
self.source_embedder = torch.jit.load(self.config.get('source_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.target_embedder = torch.jit.load(self.config.get('target_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
@@ -128,21 +121,27 @@ def create_trainer() -> Trainer:
|
||||
|
||||
|
||||
def train() -> None:
|
||||
config =\
|
||||
config_dataset =\
|
||||
{
|
||||
'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
|
||||
'resume_path': CONFIG.get('training.output', 'resume_path')
|
||||
}
|
||||
config_trainer =\
|
||||
{
|
||||
'source_path': CONFIG.get('training.model', 'source_path'),
|
||||
'target_path': CONFIG.get('training.model', 'target_path'),
|
||||
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
dataset = StaticDataset(config)
|
||||
dataset = StaticDataset(config_dataset)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer()
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer)
|
||||
trainer = create_trainer()
|
||||
|
||||
if os.path.exists(config.get('resume_path')):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
|
||||
if os.path.exists(config_dataset.get('resume_path')):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_dataset.get('resume_path'))
|
||||
else:
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
|
||||
|
||||
Reference in New Issue
Block a user