diff --git a/scripts/run_benchmark_remote.sh b/scripts/run_benchmark_remote.sh index 62e5c8c..028e90b 100644 --- a/scripts/run_benchmark_remote.sh +++ b/scripts/run_benchmark_remote.sh @@ -144,12 +144,18 @@ def _patched_collect(self, layer_modules, prompts, label): torch.cuda.mem_get_info(i)[0] / (1024 ** 3) for i in range(torch.cuda.device_count()) ) - if free_gb < 2.0: + # Scale thresholds by model size (baseline: 7B with hidden=4096, 32 layers) + _h = self.handle.hidden_size if self.handle else 4096 + _l = n_layers if n_layers else 32 + _ms = (_h / 4096) * (_l / 32) + _tight = max(4.0 * _ms, 0.5) + _low = max(2.0 * _ms, 0.25) + if free_gb < _low: max_length = 64 - self.log(f" Low GPU memory ({free_gb:.1f} GB free), using max_length={max_length}") - elif free_gb < 4.0: + self.log(f" Low GPU memory ({free_gb:.1f} GB free, threshold {_low:.1f} GB), using max_length={max_length}") + elif free_gb < _tight: max_length = 128 - self.log(f" Tight GPU memory ({free_gb:.1f} GB free), using max_length={max_length}") + self.log(f" Tight GPU memory ({free_gb:.1f} GB free, threshold {_tight:.1f} GB), using max_length={max_length}") device = self._get_model_device(model)