mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Add config to dataloader alignment
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import configparser
|
||||
import os
|
||||
from typing import Tuple
|
||||
from typing import cast, Tuple
|
||||
|
||||
import lightning
|
||||
import torch
|
||||
@@ -17,7 +17,7 @@ from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, OptimizerConfig
|
||||
from .types import Batch, Embedding, OptimizerConfig, WarpMatrix
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -194,7 +194,7 @@ def create_trainer() -> Trainer:
|
||||
|
||||
def train() -> None:
|
||||
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
|
||||
dataset_warp_matrix = CONFIG.get('training.dataset', 'warp_matrix')
|
||||
dataset_warp_matrix = cast(WarpMatrix, CONFIG.get('training.dataset', 'warp_matrix'))
|
||||
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
|
||||
output_resume_path = CONFIG.get('training.output', 'resume_path')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user