From 11bb9065ba6201bc198a9074f8fed6a10dde4675 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Mon, 17 Feb 2025 23:58:33 +0100 Subject: [PATCH] Uniform resume checkpoint approach --- embedding_converter/README.md | 3 ++- embedding_converter/config.ini | 1 + embedding_converter/src/training.py | 11 +++++++++-- face_swapper/README.md | 7 +++---- face_swapper/config.ini | 5 ++--- face_swapper/src/training.py | 19 ++++++++++--------- 6 files changed, 27 insertions(+), 19 deletions(-) diff --git a/embedding_converter/README.md b/embedding_converter/README.md index 1666bc1..e323b4a 100644 --- a/embedding_converter/README.md +++ b/embedding_converter/README.md @@ -61,7 +61,8 @@ max_epochs = 4096 ``` [training.output] directory_path = .outputs -file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f} +file_pattern = arcface_converter_simswap_{epoch}_{step} +resume_file_path = .outputs/last.ckpt ``` ``` diff --git a/embedding_converter/config.ini b/embedding_converter/config.ini index 71166fd..bbcf4ca 100644 --- a/embedding_converter/config.ini +++ b/embedding_converter/config.ini @@ -24,6 +24,7 @@ max_epochs = [training.output] directory_path = file_pattern = +resume_file_path = [exporting] directory_path = diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index d95035f..a87848c 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -1,4 +1,5 @@ import configparser +import os from typing import Any, Tuple import lightning @@ -111,9 +112,15 @@ def create_trainer() -> Trainer: def train() -> None: - trainer = create_trainer() + resume_file_path = CONFIG.get('training.output', 'resume_file_path') + training_loader, validation_loader = create_loaders() embedding_converter_trainer = EmbeddingConverterTrainer() + trainer = create_trainer() tuner = Tuner(trainer) tuner.lr_find(embedding_converter_trainer, training_loader, validation_loader) - trainer.fit(embedding_converter_trainer, training_loader, validation_loader) + + if os.path.exists(resume_file_path): + trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = resume_file_path) + else: + trainer.fit(embedding_converter_trainer, training_loader, validation_loader) diff --git a/face_swapper/README.md b/face_swapper/README.md index 5a64cce..b331e2a 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -77,15 +77,14 @@ learning_rate = 0.0004 max_epochs = 50 precision = 16-mixed automatic_optimization = false +preview_frequency = 250 ``` ``` [training.output] directory_path = .outputs -file_path = .outputs/last.ckpt -file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}' -preview_frequency = 250 -validation_frequency = 1000 +file_pattern = 'face-swapper_{epoch}_{step}' +resume_file_path = .outputs/last.ckpt ``` ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 92ee043..10f875f 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -37,13 +37,12 @@ learning_rate = max_epochs = precision = automatic_optimization = +preview_frequency = [training.output] directory_path = -file_path = file_pattern = -preview_frequency = -validation_frequency = +resume_file_path = [exporting] directory_path = diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 6a62ef1..1b2852f 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -23,7 +23,7 @@ CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss): +class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): def __init__(self) -> None: super().__init__() FaceSwapperLoss.__init__(self) @@ -62,7 +62,7 @@ class FaceSwapperTrain(lightning.LightningModule, FaceSwapperLoss): self.manual_backward(discriminator_losses.get('loss_discriminator')) discriminator_optimizer.step() - if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0: + if self.global_step % CONFIG.getint('training.trainer', 'preview_frequency') == 0: self.generate_preview(source_tensor, target_tensor, swap_tensor) self.log('loss_generator', generator_losses.get('loss_generator'), prog_bar = True) @@ -103,7 +103,7 @@ def create_trainer() -> Trainer: dirpath = output_directory_path, filename = output_file_pattern, every_n_train_steps = 1000, - save_top_k = 5, + save_top_k = 3, save_last = True ) ], @@ -118,13 +118,14 @@ def train() -> None: same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability') batch_size = CONFIG.getint('training.loader', 'batch_size') num_workers = CONFIG.getint('training.loader', 'num_workers') - output_file_path = CONFIG.get('training.output', 'file_path') - - if not os.path.isfile(output_file_path): - output_file_path = None + resume_file_path = CONFIG.get('training.output', 'resume_file_path') dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability) data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) - face_swap_model = FaceSwapperTrain() + face_swapper_trainer = FaceSwapperTrainer() trainer = create_trainer() - trainer.fit(face_swap_model, data_loader, ckpt_path = output_file_path) + + if os.path.isfile(resume_file_path): + trainer.fit(face_swapper_trainer, data_loader, ckpt_path = resume_file_path) + else: + trainer.fit(face_swapper_trainer, data_loader)