mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
This should be a StaticDataset, Fix learning rate finder
This commit is contained in:
@@ -9,13 +9,13 @@ from torchvision import transforms
|
||||
from .types import Batch
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
class StaticDataset(Dataset[Tensor]):
|
||||
def __init__(self, file_pattern : str) -> None:
|
||||
self.file_paths = glob.glob(file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
file_path = random.choice(self.file_paths)
|
||||
file_path = self.file_paths[index]
|
||||
vision_frame = cv2.imread(file_path)
|
||||
return self.transforms(vision_frame)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from lightning.pytorch.tuner import Tuner
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
|
||||
from .dataset import DynamicDataset
|
||||
from .dataset import StaticDataset
|
||||
from .models.embedding_converter import EmbeddingConverter
|
||||
from .types import Batch, Embedding, OptimizerConfig
|
||||
|
||||
@@ -24,11 +24,13 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
super(EmbeddingConverterTrainer, self).__init__()
|
||||
source_path = CONFIG.get('training.model', 'source_path')
|
||||
target_path = CONFIG.get('training.model', 'target_path')
|
||||
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
|
||||
self.embedding_converter = EmbeddingConverter()
|
||||
self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.mse_loss = nn.MSELoss()
|
||||
self.lr = learning_rate
|
||||
|
||||
def forward(self, source_embedding : Embedding) -> Embedding:
|
||||
return self.embedding_converter(source_embedding)
|
||||
@@ -115,7 +117,7 @@ def train() -> None:
|
||||
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
|
||||
output_resume_path = CONFIG.get('training.output', 'resume_path')
|
||||
|
||||
dataset = DynamicDataset(dataset_file_pattern)
|
||||
dataset = StaticDataset(dataset_file_pattern)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer()
|
||||
trainer = create_trainer()
|
||||
|
||||
Reference in New Issue
Block a user