move generator optimizer toggle

This commit is contained in:
harisreedhar
2025-03-11 19:37:05 +05:30
parent af09ee7ff3
commit 4f67e045a0
+2 -2
View File
@@ -98,13 +98,13 @@ class FaceSwapperTrainer(LightningModule):
source_tensor, target_tensor = batch
do_update = (batch_index + 1) % self.config_accumulate_size == 0
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
self.toggle_optimizer(generator_optimizer)
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_attributes = self.generator.get_attributes(target_tensor)
generator_output_tensor = self.generator(source_embedding, target_tensor)
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
self.toggle_optimizer(generator_optimizer)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)