mirror of
https://github.com/romovpa/claudini.git
synced 2026-06-07 20:33:54 +02:00
33 lines
1.1 KiB
Python
33 lines
1.1 KiB
Python
"""Claude v100: Nesterov + patience=50 + restore-from-best. Full combination of best techniques."""
|
|
|
|
import torch
|
|
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
|
|
|
from claudini.methods.claude.v90 import ClaudeV90Optimizer
|
|
|
|
|
|
class ClaudeV100Optimizer(ClaudeV90Optimizer):
    """ClaudeV90Optimizer variant that uses Nesterov SGD and a configurable patience.

    Differences from the parent that are visible in this class:

    * ``patience`` is a constructor argument. It defaults to 50, the value
      previously hard-coded. It is assigned to ``self.patience`` after the
      parent constructor runs, so it overrides any value the parent set.
    * ``setup`` runs the parent's setup and then overwrites ``self.optimizer``
      with ``torch.optim.SGD(..., nesterov=True)`` over ``self.soft_opt``.

    NOTE(review): the module docstring also mentions "restore-from-best". That
    logic is not in this class and is presumably inherited from
    ClaudeV90Optimizer; confirm.
    """

    method_name = "claude_v100"

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        optim_length: int = 20,
        lr: float = 10.0,
        momentum: float = 0.99,
        ema_alpha: float = 0.01,
        num_starts: int = 8,
        lsgm_gamma: float = 0.70,
        seed: int | None = None,
        allow_non_ascii: bool = False,
        patience: int = 50,
    ):
        """Construct the optimizer.

        Args:
            model, tokenizer, optim_length, lr, momentum, ema_alpha,
            num_starts, lsgm_gamma, seed, allow_non_ascii: forwarded
                unchanged (positionally) to ``ClaudeV90Optimizer.__init__``.
            patience: stored on ``self.patience`` (default 50). How it is
                consumed is defined by the parent class. It is presumably a
                stall / restore window; confirm against ClaudeV90Optimizer.
        """
        # Forwarded positionally: this order must match the parent's signature.
        super().__init__(
            model,
            tokenizer,
            optim_length,
            lr,
            momentum,
            ema_alpha,
            num_starts,
            lsgm_gamma,
            seed,
            allow_non_ascii,
        )
        # Set after the parent constructor so it overrides any parent default.
        self.patience = patience

    def setup(self, prompt: str, target: str) -> None:
        """Run the parent's setup, then switch the optimizer to Nesterov SGD.

        Args:
            prompt: forwarded to ``ClaudeV90Optimizer.setup``.
            target: forwarded to ``ClaudeV90Optimizer.setup``.
        """
        super().setup(prompt, target)
        # Overwrite ``self.optimizer`` after the parent's setup. This reads
        # ``self.soft_opt``, ``self.lr`` and ``self.momentum``, which the
        # parent is expected to have set.
        # NOTE(review): any other optimizer settings the parent may have used
        # are not carried over; confirm against ClaudeV90Optimizer.setup.
        # torch.optim.SGD rejects nesterov=True unless momentum > 0 (and
        # dampening == 0), so ``momentum`` must be positive here.
        self.optimizer = torch.optim.SGD(
            [self.soft_opt],
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
        )
|