mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Fix CI
This commit is contained in:
@@ -13,7 +13,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import StaticDataset
|
||||
from .models.embedding_converter import EmbeddingConverter
|
||||
from .types import Batch, Embedding, OptimizerConfig
|
||||
from .types import Batch, Embedding, OptimizerSet
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -54,7 +54,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
self.log('validation_score', validation_score, prog_bar = True)
|
||||
return validation_score
|
||||
|
||||
def configure_optimizers(self) -> OptimizerConfig:
|
||||
def configure_optimizers(self) -> OptimizerSet:
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate'))
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||
config =\
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, TypeAlias, Dict
|
||||
from typing import Any, Dict, TypeAlias
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
@@ -6,4 +6,4 @@ Batch : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
|
||||
Config : TypeAlias = Dict[str, Any]
|
||||
OptimizerConfig : TypeAlias = Any
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
Reference in New Issue
Block a user