mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Fix CI
This commit is contained in:
@@ -19,7 +19,7 @@ def export() -> None:
|
||||
'opset_version': CONFIG.getint('exporting', 'opset_version')
|
||||
}
|
||||
|
||||
makedirs(config.get('directory_path'), exist_ok = True)
|
||||
makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type]
|
||||
model = EmbeddingConverterTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu')
|
||||
model.eval()
|
||||
model.ir_version = torch.tensor(config.get('ir_version'))
|
||||
|
||||
@@ -18,7 +18,7 @@ from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MotionLoss, ReconstructionLoss
|
||||
from .types import Batch, BatchMode, Embedding, OptimizerConfig, WarpTemplate
|
||||
from .types import Batch, BatchMode, Embedding, OptimizerSet, WarpTemplate
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
|
||||
@@ -52,7 +52,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
output_tensor = self.generator(source_embedding, target_tensor)
|
||||
return output_tensor
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerConfig, OptimizerConfig]:
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
|
||||
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
|
||||
@@ -17,7 +17,8 @@ EmbedderModule : TypeAlias = Module
|
||||
GazerModule : TypeAlias = Module
|
||||
MotionExtractorModule : TypeAlias = Module
|
||||
|
||||
OptimizerConfig : TypeAlias = Any
|
||||
Config : TypeAlias = Dict[str, Any]
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
WarpTemplate = Literal['vgg_face_hq_to_arcface_128_v2', 'arcface_128_v2_to_arcface_112_v2']
|
||||
WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor]
|
||||
|
||||
Reference in New Issue
Block a user