Apply new config approach for embedding converter
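For reference, a minimal sketch of the config.ini sections and keys this commit reads through CONFIG. The section and option names come straight from the diff below; every value is a hypothetical placeholder.

; hedged sketch of config.ini — values are illustrative only
[exporting]
directory_path = .exports
source_path = .outputs/last.ckpt
target_path = .exports/embedding_converter.onnx
ir_version = 10
opset_version = 15

[training.dataset]
file_pattern = .datasets/faces/*.png

[training.loader]
batch_size = 32
num_workers = 8
split_ratio = 0.8

[training.model]
source_path = .models/source_embedder.pt
target_path = .models/target_embedder.pt

[training.trainer]
learning_rate = 0.0001
max_epochs = 100
precision = 16-mixed

[training.output]
directory_path = .outputs
file_pattern = embedding_converter-{epoch}
resume_path = .outputs/last.ckpt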
@@ -4,12 +4,13 @@ from torch import Tensor
 from torch.utils.data import Dataset
 from torchvision import io, transforms
 
-from .types import Batch
+from .types import Batch, Config
 
 
 class StaticDataset(Dataset[Tensor]):
-	def __init__(self, file_pattern : str) -> None:
-		self.file_paths = glob.glob(file_pattern)
+	def __init__(self, config : Config) -> None:
+		self.config = config
+		self.file_paths = glob.glob(self.config.get('file_pattern'))
 		self.transforms = self.compose_transforms()
 
 	def __getitem__(self, index : int) -> Batch:
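A minimal sketch of how the new constructor is called after this hunk; the glob pattern is a hypothetical value that would normally come from config.ini:

# hedged sketch — 'file_pattern' is the only key the constructor reads here
config = {'file_pattern': '.datasets/faces/*.png'}	# hypothetical pattern
dataset = StaticDataset(config)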
@@ -10,15 +10,18 @@ CONFIG.read('config.ini')
 
 
 def export() -> None:
-	directory_path = CONFIG.get('exporting', 'directory_path')
-	source_path = CONFIG.get('exporting', 'source_path')
-	target_path = CONFIG.get('exporting', 'target_path')
-	ir_version = CONFIG.getint('exporting', 'ir_version')
-	opset_version = CONFIG.getint('exporting', 'opset_version')
+	config =\
+	{
+		'directory_path': CONFIG.get('exporting', 'directory_path'),
+		'source_path': CONFIG.get('exporting', 'source_path'),
+		'target_path': CONFIG.get('exporting', 'target_path'),
+		'ir_version': CONFIG.getint('exporting', 'ir_version'),
+		'opset_version': CONFIG.getint('exporting', 'opset_version')
+	}
 
-	makedirs(directory_path, exist_ok = True)
-	model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
+	makedirs(config.get('directory_path'), exist_ok = True)
+	model = EmbeddingConverterTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu')
 	model.eval()
-	model.ir_version = torch.tensor(ir_version)
+	model.ir_version = torch.tensor(config.get('ir_version'))
 	input_tensor = torch.randn(1, 512)
-	torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
+	torch.onnx.export(model, input_tensor, config.get('target_path'), input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config.get('opset_version'))
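The refactor swaps per-option configparser reads for a single dict built up front. A small sketch of the difference in lookup behaviour, assuming a config.ini with the [exporting] section shown earlier:

from configparser import ConfigParser

CONFIG = ConfigParser()
CONFIG.read('config.ini')

# configparser lookup: a (section, option) pair, raising NoSectionError /
# NoOptionError when the entry is missing.
opset_version = CONFIG.getint('exporting', 'opset_version')

# dict lookup after the refactor: a flat key, where dict.get() returns None
# instead of raising when the key is absent.
config = {'opset_version': CONFIG.getint('exporting', 'opset_version')}
assert config.get('opset_version') == opset_version
assert config.get('missing_key') is None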
@@ -22,12 +22,16 @@ CONFIG.read('config.ini')
 class EmbeddingConverterTrainer(lightning.LightningModule):
 	def __init__(self) -> None:
 		super(EmbeddingConverterTrainer, self).__init__()
-		source_path = CONFIG.get('training.model', 'source_path')
-		target_path = CONFIG.get('training.model', 'target_path')
+		self.config =\
+		{
+			'source_path': CONFIG.get('training.model', 'source_path'),
+			'target_path': CONFIG.get('training.model', 'target_path'),
+			'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
+		}
+
 		self.embedding_converter = EmbeddingConverter()
-		self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call]
-		self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.source_embedder = torch.jit.load(self.config.get('source_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.target_embedder = torch.jit.load(self.config.get('target_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
 		self.mse_loss = nn.MSELoss()
 
 	def forward(self, source_embedding : Embedding) -> Embedding:
@@ -51,8 +55,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
 		return validation_score
 
 	def configure_optimizers(self) -> OptimizerConfig:
-		learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
-		optimizer = torch.optim.Adam(self.parameters(), lr = learning_rate)
+		optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate'))
 		scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
 		config =\
 		{
@@ -70,42 +73,52 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
 
 
 def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
-	batch_size = CONFIG.getint('training.loader', 'batch_size')
-	num_workers = CONFIG.getint('training.loader', 'num_workers')
+	config =\
+	{
+		'batch_size': CONFIG.getint('training.loader', 'batch_size'),
+		'num_workers': CONFIG.getint('training.loader', 'num_workers')
+	}
+
 	training_dataset, validate_dataset = split_dataset(dataset)
-	training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
-	validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
+	training_loader = StatefulDataLoader(training_dataset, batch_size = config.get('batch_size'), shuffle = True, num_workers = config.get('num_workers'), drop_last = True, pin_memory = True, persistent_workers = True)
+	validation_loader = StatefulDataLoader(validate_dataset, batch_size = config.get('batch_size'), shuffle = False, num_workers = config.get('num_workers'), pin_memory = True, persistent_workers = True)
 	return training_loader, validation_loader
 
 
 def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
-	split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
+	config =\
+	{
+		'split_ratio': CONFIG.getfloat('training.loader', 'split_ratio')
+	}
+
 	dataset_size = len(dataset) # type:ignore[arg-type]
-	training_size = int(dataset_size * split_ratio)
+	training_size = int(dataset_size * config.get('split_ratio'))
 	validation_size = int(dataset_size - training_size)
 	training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
 	return training_dataset, validate_dataset
 
 
 def create_trainer() -> Trainer:
-	trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
-	output_directory_path = CONFIG.get('training.output', 'directory_path')
-	output_file_pattern = CONFIG.get('training.output', 'file_pattern')
-	trainer_precision = CONFIG.get('training.trainer', 'precision')
+	config =\
+	{
+		'max_epochs': CONFIG.getint('training.trainer', 'max_epochs'),
+		'directory_path': CONFIG.get('training.output', 'directory_path'),
+		'file_pattern': CONFIG.get('training.output', 'file_pattern'),
+		'precision': CONFIG.get('training.trainer', 'precision')
+	}
 	logger = TensorBoardLogger('.logs', name = 'embedding_converter')
 
 	return Trainer(
 		logger = logger,
 		log_every_n_steps = 10,
-		max_epochs = trainer_max_epochs,
-		precision = trainer_precision, # type:ignore[arg-type]
+		max_epochs = config.get('max_epochs'),
+		precision = config.get('precision'), # type:ignore[arg-type]
 		callbacks =
 		[
 			ModelCheckpoint(
 				monitor = 'training_loss',
-				dirpath = output_directory_path,
-				filename = output_file_pattern,
+				dirpath = config.get('directory_path'),
+				filename = config.get('file_pattern'),
 				every_n_epochs = 1,
 				save_top_k = 3,
 				save_last = True
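A worked example of the split arithmetic in split_dataset above; the dataset size and split_ratio are hypothetical values:

# hedged sketch — dataset_size and split_ratio are hypothetical
dataset_size = 1000
config = {'split_ratio': 0.8}
training_size = int(dataset_size * config.get('split_ratio'))	# 800
validation_size = int(dataset_size - training_size)	# 200
assert (training_size, validation_size) == (800, 200)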
@@ -115,18 +128,21 @@ def create_trainer() -> Trainer:
 
 
 def train() -> None:
-	dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
-	output_resume_path = CONFIG.get('training.output', 'resume_path')
+	config =\
+	{
+		'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
+		'resume_path': CONFIG.get('training.output', 'resume_path')
+	}
 
 	if torch.cuda.is_available():
 		torch.set_float32_matmul_precision('high')
 
-	dataset = StaticDataset(dataset_file_pattern)
+	dataset = StaticDataset(config)
 	training_loader, validation_loader = create_loaders(dataset)
 	embedding_converter_trainer = EmbeddingConverterTrainer()
 	trainer = create_trainer()
 
-	if os.path.exists(output_resume_path):
-		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = output_resume_path)
+	if os.path.exists(config.get('resume_path')):
+		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
 	else:
 		trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
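The two fit calls above differ only in ckpt_path; an equivalent hedged sketch, relying on Lightning's ckpt_path defaulting to None (the resume path value is hypothetical):

import os.path

config = {'resume_path': '.outputs/last.ckpt'}	# hypothetical value
resume_path = config.get('resume_path')
ckpt_path = resume_path if os.path.exists(resume_path) else None
# trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = ckpt_path)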
@@ -1,8 +1,9 @@
-from typing import Any, TypeAlias
+from typing import Any, TypeAlias, Dict
 
 from torch import Tensor
 
 Batch : TypeAlias = Tensor
 Embedding : TypeAlias = Tensor
+Config : TypeAlias = Dict[str, Any]
 OptimizerConfig : TypeAlias = Any
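A self-contained sketch of the new Config alias in use; the key and value are hypothetical:

from typing import Any, TypeAlias, Dict

Config : TypeAlias = Dict[str, Any]

# A Config is a plain string-keyed dict, so .get() returns None for absent
# keys rather than raising — the behaviour the refactored call sites rely on.
config : Config = {'file_pattern': '.datasets/faces/*.png'}	# hypothetical value
assert config.get('missing_key') is None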