Uniform resume checkpoint approach

This commit is contained in:
henryruhs
2025-02-17 23:58:33 +01:00
parent 7b2b8f0f85
commit 11bb9065ba
6 changed files with 27 additions and 19 deletions
+2 -1
View File
@@ -61,7 +61,8 @@ max_epochs = 4096
```
[training.output]
directory_path = .outputs
file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f}
file_pattern = arcface_converter_simswap_{epoch}_{step}
resume_file_path = .outputs/last.ckpt
```
```
+1
View File
@@ -24,6 +24,7 @@ max_epochs =
[training.output]
directory_path =
file_pattern =
resume_file_path =
[exporting]
directory_path =
+9 -2
View File
@@ -1,4 +1,5 @@
import configparser
import os
from typing import Any, Tuple
import lightning
@@ -111,9 +112,15 @@ def create_trainer() -> Trainer:
def train() -> None:
trainer = create_trainer()
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
training_loader, validation_loader = create_loaders()
embedding_converter_trainer = EmbeddingConverterTrainer()
trainer = create_trainer()
tuner = Tuner(trainer)
tuner.lr_find(embedding_converter_trainer, training_loader, validation_loader)
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
if os.path.exists(resume_file_path):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = resume_file_path)
else:
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
+3 -4
View File
@@ -77,15 +77,14 @@ learning_rate = 0.0004
max_epochs = 50
precision = 16-mixed
automatic_optimization = false
preview_frequency = 250
```
```
[training.output]
directory_path = .outputs
file_path = .outputs/last.ckpt
file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
preview_frequency = 250
validation_frequency = 1000
file_pattern = 'face-swapper_{epoch}_{step}'
resume_file_path = .outputs/last.ckpt
```
```
+2 -3
View File
@@ -37,13 +37,12 @@ learning_rate =
max_epochs =
precision =
automatic_optimization =
preview_frequency =
[training.output]
directory_path =
file_path =
file_pattern =
preview_frequency =
validation_frequency =
resume_file_path =
[exporting]
directory_path =
+10 -9
View File
@@ -23,7 +23,7 @@ CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss):
class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
def __init__(self) -> None:
super().__init__()
FaceSwapperLoss.__init__(self)
@@ -62,7 +62,7 @@ class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss):
self.manual_backward(discriminator_losses.get('loss_discriminator'))
discriminator_optimizer.step()
if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0:
if self.global_step % CONFIG.getint('training.trainer', 'preview_frequency') == 0:
self.generate_preview(source_tensor, target_tensor, swap_tensor)
self.log('loss_generator', generator_losses.get('loss_generator'), prog_bar = True)
@@ -103,7 +103,7 @@ def create_trainer() -> Trainer:
dirpath = output_directory_path,
filename = output_file_pattern,
every_n_train_steps = 1000,
save_top_k = 5,
save_top_k = 3,
save_last = True
)
],
@@ -118,13 +118,14 @@ def train() -> None:
same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability')
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
output_file_path = CONFIG.get('training.output', 'file_path')
if not os.path.isfile(output_file_path):
output_file_path = None
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swap_model = FaceSwapperTrain()
face_swapper_trainer = FaceSwapperTrainer()
trainer = create_trainer()
trainer.fit(face_swap_model, data_loader, ckpt_path = output_file_path)
if os.path.isfile(resume_file_path):
trainer.fit(face_swapper_trainer, data_loader, ckpt_path = resume_file_path)
else:
trainer.fit(face_swapper_trainer, data_loader)