From 6ca68f14085bac663d158ad0ca035f5b6e7804c2 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Mon, 3 Mar 2025 09:45:02 +0100 Subject: [PATCH] Add config to dataloader alignment --- face_swapper/src/training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 3e88f72..4324b68 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -1,6 +1,6 @@ import configparser import os -from typing import Tuple +from typing import cast, Tuple import lightning import torch @@ -17,7 +17,7 @@ from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss -from .types import Batch, Embedding, OptimizerConfig +from .types import Batch, Embedding, OptimizerConfig, WarpMatrix CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -194,7 +194,7 @@ def create_trainer() -> Trainer: def train() -> None: dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern') - dataset_warp_matrix = CONFIG.get('training.dataset', 'warp_matrix') + dataset_warp_matrix = cast(WarpMatrix, CONFIG.get('training.dataset', 'warp_matrix')) dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio') output_resume_path = CONFIG.get('training.output', 'resume_path')