mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Cache the usage of glob.glob (#80)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from configparser import ConfigParser
|
||||
@@ -9,7 +8,7 @@ from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import convert_tensor
|
||||
from .helper import convert_tensor, resolve_static_file_pattern
|
||||
from .types import Batch, BatchMode, ConvertTemplate, UsageMode
|
||||
|
||||
|
||||
@@ -23,10 +22,9 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
self.config_batch_ratio = config_parser.getfloat('training.dataset.current', 'batch_ratio')
|
||||
self.config_parser = config_parser
|
||||
self.transforms = self.compose_transforms()
|
||||
self.file_paths = glob.glob(self.config_file_pattern)
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
file_path = self.file_paths[index]
|
||||
file_path = resolve_static_file_pattern(self.config_file_pattern)[index]
|
||||
|
||||
if random.random() < self.config_batch_ratio:
|
||||
if self.config_batch_mode == 'equal':
|
||||
@@ -43,7 +41,7 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
return self.prepare_different_batch(file_path)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.file_paths)
|
||||
return len(resolve_static_file_pattern(self.config_file_pattern))
|
||||
|
||||
def compose_transforms(self) -> transforms:
|
||||
return transforms.Compose(
|
||||
@@ -79,7 +77,7 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
config_file_pattern = config_parser.get(config_section, 'file_pattern')
|
||||
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
|
||||
|
||||
target_path = random.choice(glob.glob(config_file_pattern))
|
||||
target_path = random.choice(resolve_static_file_pattern(config_file_pattern))
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
|
||||
@@ -94,7 +92,7 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
config_file_pattern = config_parser.get(config_section, 'file_pattern')
|
||||
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
|
||||
|
||||
source_path = random.choice(glob.glob(config_file_pattern))
|
||||
source_path = random.choice(resolve_static_file_pattern(config_file_pattern))
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template)
|
||||
@@ -104,7 +102,7 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
|
||||
target_path = random.choice(self.file_paths)
|
||||
target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern))
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
import glob
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
@@ -55,3 +59,8 @@ def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
|
||||
noise_tensor = torch.randn_like(input_tensor) * factor
|
||||
output_tensor = input_tensor + noise_tensor
|
||||
return output_tensor
|
||||
|
||||
|
||||
@lru_cache(maxsize = None)
|
||||
def resolve_static_file_pattern(file_pattern : str) -> List[str]:
|
||||
return sorted(glob.glob(file_pattern))
|
||||
|
||||
Reference in New Issue
Block a user