mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Uniform resume checkpoint approach
This commit is contained in:
@@ -61,7 +61,8 @@ max_epochs = 4096
|
||||
```
|
||||
[training.output]
|
||||
directory_path = .outputs
|
||||
file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f}
|
||||
file_pattern = arcface_converter_simswap_{epoch}_{step}
|
||||
resume_file_path = .outputs/last.ckpt
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -24,6 +24,7 @@ max_epochs =
|
||||
[training.output]
|
||||
directory_path =
|
||||
file_pattern =
|
||||
resume_file_path =
|
||||
|
||||
[exporting]
|
||||
directory_path =
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import configparser
|
||||
import os
|
||||
from typing import Any, Tuple
|
||||
|
||||
import lightning
|
||||
@@ -111,9 +112,15 @@ def create_trainer() -> Trainer:
|
||||
|
||||
|
||||
def train() -> None:
|
||||
trainer = create_trainer()
|
||||
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
|
||||
|
||||
training_loader, validation_loader = create_loaders()
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer()
|
||||
trainer = create_trainer()
|
||||
tuner = Tuner(trainer)
|
||||
tuner.lr_find(embedding_converter_trainer, training_loader, validation_loader)
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
|
||||
|
||||
if os.path.exists(resume_file_path):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = resume_file_path)
|
||||
else:
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
|
||||
|
||||
@@ -77,15 +77,14 @@ learning_rate = 0.0004
|
||||
max_epochs = 50
|
||||
precision = 16-mixed
|
||||
automatic_optimization = false
|
||||
preview_frequency = 250
|
||||
```
|
||||
|
||||
```
|
||||
[training.output]
|
||||
directory_path = .outputs
|
||||
file_path = .outputs/last.ckpt
|
||||
file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
|
||||
preview_frequency = 250
|
||||
validation_frequency = 1000
|
||||
file_pattern = 'face-swapper_{epoch}_{step}'
|
||||
resume_file_path = .outputs/last.ckpt
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -37,13 +37,12 @@ learning_rate =
|
||||
max_epochs =
|
||||
precision =
|
||||
automatic_optimization =
|
||||
preview_frequency =
|
||||
|
||||
[training.output]
|
||||
directory_path =
|
||||
file_path =
|
||||
file_pattern =
|
||||
preview_frequency =
|
||||
validation_frequency =
|
||||
resume_file_path =
|
||||
|
||||
[exporting]
|
||||
directory_path =
|
||||
|
||||
@@ -23,7 +23,7 @@ CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss):
|
||||
class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
FaceSwapperLoss.__init__(self)
|
||||
@@ -62,7 +62,7 @@ class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.manual_backward(discriminator_losses.get('loss_discriminator'))
|
||||
discriminator_optimizer.step()
|
||||
|
||||
if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0:
|
||||
if self.global_step % CONFIG.getint('training.trainer', 'preview_frequency') == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, swap_tensor)
|
||||
|
||||
self.log('loss_generator', generator_losses.get('loss_generator'), prog_bar = True)
|
||||
@@ -103,7 +103,7 @@ def create_trainer() -> Trainer:
|
||||
dirpath = output_directory_path,
|
||||
filename = output_file_pattern,
|
||||
every_n_train_steps = 1000,
|
||||
save_top_k = 5,
|
||||
save_top_k = 3,
|
||||
save_last = True
|
||||
)
|
||||
],
|
||||
@@ -118,13 +118,14 @@ def train() -> None:
|
||||
same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability')
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
output_file_path = CONFIG.get('training.output', 'file_path')
|
||||
|
||||
if not os.path.isfile(output_file_path):
|
||||
output_file_path = None
|
||||
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
|
||||
|
||||
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
|
||||
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
face_swap_model = FaceSwapperTrain()
|
||||
face_swapper_trainer = FaceSwapperTrainer()
|
||||
trainer = create_trainer()
|
||||
trainer.fit(face_swap_model, data_loader, ckpt_path = output_file_path)
|
||||
|
||||
if os.path.isfile(resume_file_path):
|
||||
trainer.fit(face_swapper_trainer, data_loader, ckpt_path = resume_file_path)
|
||||
else:
|
||||
trainer.fit(face_swapper_trainer, data_loader)
|
||||
|
||||
Reference in New Issue
Block a user