Add files via upload
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
from typing import Generator
|
||||
|
||||
from trainer.trainer_utils import get_optimizer
|
||||
|
||||
|
||||
class CapacitronOptimizer:
|
||||
"""Double optimizer class for the Capacitron model."""
|
||||
|
||||
def __init__(self, config: dict, model_params: Generator) -> None:
|
||||
self.primary_params, self.secondary_params = self.split_model_parameters(model_params)
|
||||
|
||||
optimizer_names = list(config.optimizer_params.keys())
|
||||
optimizer_parameters = list(config.optimizer_params.values())
|
||||
|
||||
self.primary_optimizer = get_optimizer(
|
||||
optimizer_names[0],
|
||||
optimizer_parameters[0],
|
||||
config.lr,
|
||||
parameters=self.primary_params,
|
||||
)
|
||||
|
||||
self.secondary_optimizer = get_optimizer(
|
||||
optimizer_names[1],
|
||||
self.extract_optimizer_parameters(optimizer_parameters[1]),
|
||||
optimizer_parameters[1]["lr"],
|
||||
parameters=self.secondary_params,
|
||||
)
|
||||
|
||||
self.param_groups = self.primary_optimizer.param_groups
|
||||
|
||||
def first_step(self):
|
||||
self.secondary_optimizer.step()
|
||||
self.secondary_optimizer.zero_grad()
|
||||
self.primary_optimizer.zero_grad()
|
||||
|
||||
def step(self):
|
||||
# Update param groups to display the correct learning rate
|
||||
self.param_groups = self.primary_optimizer.param_groups
|
||||
self.primary_optimizer.step()
|
||||
|
||||
def zero_grad(self, set_to_none=False):
|
||||
self.primary_optimizer.zero_grad(set_to_none)
|
||||
self.secondary_optimizer.zero_grad(set_to_none)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.primary_optimizer.load_state_dict(state_dict[0])
|
||||
self.secondary_optimizer.load_state_dict(state_dict[1])
|
||||
|
||||
def state_dict(self):
|
||||
return [self.primary_optimizer.state_dict(), self.secondary_optimizer.state_dict()]
|
||||
|
||||
@staticmethod
|
||||
def split_model_parameters(model_params: Generator) -> list:
|
||||
primary_params = []
|
||||
secondary_params = []
|
||||
for name, param in model_params:
|
||||
if param.requires_grad:
|
||||
if name == "capacitron_vae_layer.beta":
|
||||
secondary_params.append(param)
|
||||
else:
|
||||
primary_params.append(param)
|
||||
return [iter(primary_params), iter(secondary_params)]
|
||||
|
||||
@staticmethod
|
||||
def extract_optimizer_parameters(params: dict) -> dict:
|
||||
"""Extract parameters that are not the learning rate"""
|
||||
return {k: v for k, v in params.items() if k != "lr"}
|
||||
Reference in New Issue
Block a user