mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
More adjustments
This commit is contained in:
@@ -95,9 +95,9 @@ def create_trainer() -> Trainer:
|
||||
config =\
|
||||
{
|
||||
'max_epochs': CONFIG.getint('training.trainer', 'max_epochs'),
|
||||
'precision': CONFIG.get('training.trainer', 'precision'),
|
||||
'directory_path': CONFIG.get('training.output', 'directory_path'),
|
||||
'file_pattern': CONFIG.get('training.output', 'file_pattern'),
|
||||
'precision': CONFIG.get('training.trainer', 'precision')
|
||||
'file_pattern': CONFIG.get('training.output', 'file_pattern')
|
||||
}
|
||||
logger = TensorBoardLogger('.logs', name = 'embedding_converter')
|
||||
|
||||
@@ -105,7 +105,7 @@ def create_trainer() -> Trainer:
|
||||
logger = logger,
|
||||
log_every_n_steps = 10,
|
||||
max_epochs = config.get('max_epochs'),
|
||||
precision = config.get('precision'), # type:ignore[arg-type]
|
||||
precision = config.get('precision'),
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
@@ -121,30 +121,34 @@ def create_trainer() -> Trainer:
|
||||
|
||||
|
||||
def train() -> None:
|
||||
config_dataset =\
|
||||
config =\
|
||||
{
|
||||
'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
|
||||
}
|
||||
config_trainer =\
|
||||
{
|
||||
'source_path': CONFIG.get('training.model', 'source_path'),
|
||||
'target_path': CONFIG.get('training.model', 'target_path'),
|
||||
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
}
|
||||
config_common =\
|
||||
{
|
||||
'resume_path': CONFIG.get('training.output', 'resume_path')
|
||||
'dataset':
|
||||
{
|
||||
'file_pattern': CONFIG.get('training.dataset', 'file_pattern')
|
||||
},
|
||||
'trainer':
|
||||
{
|
||||
'source_path': CONFIG.get('training.model', 'source_path'),
|
||||
'target_path': CONFIG.get('training.model', 'target_path'),
|
||||
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
},
|
||||
'common':
|
||||
{
|
||||
'resume_path': CONFIG.get('training.output', 'resume_path')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
dataset = StaticDataset(config_dataset)
|
||||
dataset = StaticDataset(config.get('dataset'))
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer)
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer(config.get('trainer'))
|
||||
trainer = create_trainer()
|
||||
|
||||
if os.path.exists(config_common.get('resume_path')):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_common.get('resume_path'))
|
||||
if os.path.exists(config.get('common').get('resume_path')):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('common').get('resume_path'))
|
||||
else:
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
|
||||
|
||||
Reference in New Issue
Block a user