Files
claudini/claudini/methods/claude/v100/optimizer.py
T

33 lines
1.1 KiB
Python

"""Claude v100: Nesterov + patience=50 + restore-from-best. Full combination of best techniques."""
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from claudini.methods.claude.v90 import ClaudeV90Optimizer
class ClaudeV100Optimizer(ClaudeV90Optimizer):
    """Claude v100: the v90 optimizer with patience fixed at 50 and Nesterov SGD.

    Everything not overridden here is inherited from ``ClaudeV90Optimizer``.
    This subclass changes two things only: it sets ``self.patience`` to 50
    after the parent constructor runs, and it replaces ``self.optimizer``
    with a Nesterov-momentum ``torch.optim.SGD`` once the parent ``setup``
    has finished.
    """

    method_name = "claude_v100"

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        optim_length: int = 20,
        lr: float = 10.0,
        momentum: float = 0.99,
        ema_alpha: float = 0.01,
        num_starts: int = 8,
        lsgm_gamma: float = 0.70,
        seed: int | None = None,
        allow_non_ascii: bool = False,
    ):
        """Forward every argument to the v90 constructor, then pin patience.

        All arguments are passed to ``ClaudeV90Optimizer.__init__``
        positionally and in declaration order; their meaning is defined by
        the parent class.
        """
        # Positional forwarding is deliberate: the parent's parameter names
        # are not visible from this file, only their order.
        super().__init__(
            model,
            tokenizer,
            optim_length,
            lr,
            momentum,
            ema_alpha,
            num_starts,
            lsgm_gamma,
            seed,
            allow_non_ascii,
        )
        # Assigned after the parent constructor so it wins over any value
        # the parent may have set.
        # NOTE(review): presumably read by inherited patience/restore logic
        # -- confirm in ClaudeV90Optimizer.
        self.patience = 50

    def setup(self, prompt: str, target: str) -> None:
        """Run the parent setup, then install a Nesterov-momentum SGD optimizer.

        ``self.soft_opt``, ``self.lr`` and ``self.momentum`` are expected to
        be provided by the parent class by the time this method reads them.
        """
        super().setup(prompt, target)
        # Overwrite self.optimizer after the parent setup has run, so the
        # Nesterov variant is the one in effect afterwards.
        optimized_params = [self.soft_opt]
        self.optimizer = torch.optim.SGD(
            optimized_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
        )