This commit is contained in:
henryruhs
2025-03-06 08:47:59 +01:00
parent e61e470432
commit 368da824aa
3 changed files with 5 additions and 4 deletions
+1 -1
View File
@@ -19,7 +19,7 @@ def export() -> None:
'opset_version': CONFIG.getint('exporting', 'opset_version')
}
makedirs(config.get('directory_path'), exist_ok = True)
makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type]
model = EmbeddingConverterTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu')
model.eval()
model.ir_version = torch.tensor(config.get('ir_version'))
+2 -2
View File
@@ -18,7 +18,7 @@ from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MotionLoss, ReconstructionLoss
from .types import Batch, BatchMode, Embedding, OptimizerConfig, WarpTemplate
from .types import Batch, BatchMode, Embedding, OptimizerSet, WarpTemplate
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -52,7 +52,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
output_tensor = self.generator(source_embedding, target_tensor)
return output_tensor
def configure_optimizers(self) -> Tuple[OptimizerConfig, OptimizerConfig]:
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
+2 -1
View File
@@ -17,7 +17,8 @@ EmbedderModule : TypeAlias = Module
GazerModule : TypeAlias = Module
MotionExtractorModule : TypeAlias = Module
OptimizerConfig : TypeAlias = Any
Config : TypeAlias = Dict[str, Any]
OptimizerSet : TypeAlias = Any
WarpTemplate = Literal['vgg_face_hq_to_arcface_128_v2', 'arcface_128_v2_to_arcface_112_v2']
WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor]