mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import warnings
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
@@ -179,21 +179,21 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T
|
||||
return training_dataset, validate_dataset
|
||||
|
||||
|
||||
def combine_datasets() -> Dataset[Tensor]:
|
||||
def prepare_datasets(config_parser : ConfigParser) -> List[Dataset[Tensor]]:
|
||||
datasets = []
|
||||
|
||||
for config_section in CONFIG_PARSER.sections():
|
||||
for config_section in config_parser.sections():
|
||||
|
||||
if config_section.startswith('training.dataset'):
|
||||
dataset_config = ConfigParser()
|
||||
dataset_config.add_section('training.dataset')
|
||||
current_config_parser = ConfigParser()
|
||||
current_config_parser.add_section('training.dataset')
|
||||
|
||||
for key, value in CONFIG_PARSER.items(config_section):
|
||||
dataset_config.set('training.dataset', key, value)
|
||||
datasets.append(DynamicDataset(dataset_config))
|
||||
for key, value in config_parser.items(config_section):
|
||||
current_config_parser.set('training.dataset', key, value)
|
||||
|
||||
combine_dataset = ConcatDataset(datasets)
|
||||
return combine_dataset
|
||||
datasets.append(DynamicDataset(current_config_parser))
|
||||
|
||||
return datasets
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
@@ -233,8 +233,8 @@ def train() -> None:
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
combine_dataset = combine_datasets()
|
||||
training_loader, validation_loader = create_loaders(combine_dataset)
|
||||
dataset = ConcatDataset(prepare_datasets(CONFIG_PARSER))
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER)
|
||||
trainer = create_trainer()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user