mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Improve imports
This commit is contained in:
@@ -9,10 +9,10 @@ from .types import Batch
|
||||
|
||||
|
||||
class StaticDataset(Dataset[Tensor]):
|
||||
def __init__(self, config : ConfigParser) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config =\
|
||||
{
|
||||
'file_pattern': config.get('training.dataset', 'file_pattern')
|
||||
'file_pattern': config_parser.get('training.dataset', 'file_pattern')
|
||||
}
|
||||
self.file_paths = glob.glob(self.config.get('file_pattern'))
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import configparser
|
||||
from os import makedirs
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
|
||||
import torch
|
||||
|
||||
from .training import EmbeddingConverterTrainer
|
||||
|
||||
CONFIG_PARSER = configparser.ConfigParser()
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ def export() -> None:
|
||||
'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version')
|
||||
}
|
||||
|
||||
makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type]
|
||||
os.makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type]
|
||||
model = EmbeddingConverterTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu')
|
||||
model.eval()
|
||||
model.ir_version = torch.tensor(config.get('ir_version'))
|
||||
|
||||
Reference in New Issue
Block a user