mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Introduce batch mode via config for equal and same batches
This commit is contained in:
@@ -29,6 +29,7 @@ This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model.
|
||||
[training.dataset]
|
||||
file_pattern = .datasets/vggface2/**/*.jpg
|
||||
warp_template = vgg_face_hq_to_arcface_128_v2
|
||||
batch_mode = same
|
||||
batch_ratio = 0.2
|
||||
```
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
warp_template =
|
||||
batch_mode =
|
||||
batch_ratio =
|
||||
|
||||
[training.loader]
|
||||
|
||||
@@ -7,21 +7,25 @@ from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import warp_tensor
|
||||
from .types import Batch, WarpTemplate
|
||||
from .types import Batch, BatchMode, WarpTemplate
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_ratio : float) -> None:
|
||||
def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode : BatchMode, batch_ratio : float) -> None:
|
||||
self.file_paths = glob.glob(file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
self.warp_template = warp_template
|
||||
self.batch_mode = batch_mode
|
||||
self.batch_ratio = batch_ratio
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
file_path = self.file_paths[index]
|
||||
|
||||
if random.random() < self.batch_ratio:
|
||||
return self.prepare_equal_batch(file_path)
|
||||
if self.batch_mode == 'equal':
|
||||
return self.prepare_equal_batch(file_path)
|
||||
if self.batch_mode == 'same':
|
||||
return self.prepare_same_batch(file_path)
|
||||
|
||||
return self.prepare_different_batch(file_path)
|
||||
|
||||
@@ -51,6 +55,11 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
target_tensor = self.transforms(target_tensor)
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_same_batch(self, source_path : str) -> Batch:
|
||||
source_tensor = io.read_image(source_path)
|
||||
source_tensor = self.transforms(source_tensor)
|
||||
return source_tensor, source_tensor
|
||||
|
||||
def prepare_equal_batch(self, source_path : str) -> Batch:
|
||||
target_directory_path = os.path.dirname(source_path)
|
||||
target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
|
||||
|
||||
@@ -18,7 +18,7 @@ from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, OptimizerConfig, WarpTemplate
|
||||
from .types import Batch, BatchMode, Embedding, OptimizerConfig, WarpTemplate
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
|
||||
@@ -198,13 +198,14 @@ def create_trainer() -> Trainer:
|
||||
def train() -> None:
|
||||
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
|
||||
dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template'))
|
||||
dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode'))
|
||||
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
|
||||
output_resume_path = CONFIG.get('training.output', 'resume_path')
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_ratio)
|
||||
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_mode, dataset_batch_ratio)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
face_swapper_trainer = FaceSwapperTrainer()
|
||||
trainer = create_trainer()
|
||||
|
||||
@@ -4,6 +4,7 @@ from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
Batch : TypeAlias = Tuple[Tensor, Tensor]
|
||||
BatchMode = Literal['equal', 'same']
|
||||
|
||||
Attributes : TypeAlias = Tuple[Tensor, ...]
|
||||
Embedding : TypeAlias = Tensor
|
||||
|
||||
Reference in New Issue
Block a user