mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
apply crossface
This commit is contained in:
+9
-1
@@ -45,10 +45,18 @@ target_path = .models/arcface_simswap.pt
|
||||
|
||||
```
|
||||
[training.trainer]
|
||||
learning_rate = 0.001
|
||||
max_epochs = 4096
|
||||
strategy = auto
|
||||
precision = 16-mixed
|
||||
```
|
||||
|
||||
```
|
||||
[training.optimizer]
|
||||
learning_rate = 0.001
|
||||
```
|
||||
|
||||
```
|
||||
[training.logger]
|
||||
logger_path = .logs
|
||||
logger_name = crossface_simswap
|
||||
```
|
||||
|
||||
@@ -11,10 +11,14 @@ source_path =
|
||||
target_path =
|
||||
|
||||
[training.trainer]
|
||||
learning_rate =
|
||||
max_epochs =
|
||||
strategy =
|
||||
precision =
|
||||
|
||||
[training.optimizer]
|
||||
learning_rate =
|
||||
|
||||
[training.logger]
|
||||
logger_path =
|
||||
logger_name =
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class CrossFaceTrainer(LightningModule):
|
||||
super().__init__()
|
||||
self.config_source_path = config_parser.get('training.model', 'source_path')
|
||||
self.config_target_path = config_parser.get('training.model', 'target_path')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.config_learning_rate = config_parser.getfloat('training.optimizer', 'learning_rate')
|
||||
self.crossface = CrossFace()
|
||||
self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu').eval()
|
||||
self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu').eval()
|
||||
@@ -50,7 +50,7 @@ class CrossFaceTrainer(LightningModule):
|
||||
return validation_score
|
||||
|
||||
def configure_optimizers(self) -> OptimizerSet:
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr = self.config_learning_rate)
|
||||
optimizer = torch.optim.AdamW(self.parameters(), lr = self.config_learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||
optimizer_set =\
|
||||
{
|
||||
@@ -91,8 +91,8 @@ def create_trainer() -> Trainer:
|
||||
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
|
||||
config_strategy = CONFIG_PARSER.get('training.trainer', 'strategy')
|
||||
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
|
||||
config_logger_path = CONFIG_PARSER.get('training.trainer', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.trainer', 'logger_name')
|
||||
config_logger_path = CONFIG_PARSER.get('training.logger', 'logger_path')
|
||||
config_logger_name = CONFIG_PARSER.get('training.logger', 'logger_name')
|
||||
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
|
||||
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
logger = TensorBoardLogger(config_logger_path, config_logger_name)
|
||||
|
||||
Reference in New Issue
Block a user