UDora/udora/utils.py

import functools
import gc
import inspect
import torch
from torch import Tensor
INIT_CHARS = [
    ".", ",", "!", "?", ";", ":", "(", ")", "[", "]", "{", "}",
    "@", "#", "$", "%", "&", "*",
    "w", "x", "y", "z",
]

def combine_with_overlap(preceding_text, target_text):
    # Find the longest overlap between the end of preceding_text and the start of target_text
    for i in range(len(preceding_text)):
        if target_text.startswith(preceding_text[i:]):
            combined_text = preceding_text + target_text[len(preceding_text[i:]):]
            return combined_text, True  # Overlap found
    return preceding_text + target_text, False  # No overlap found
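
# Worked example (illustrative, not part of the original module): when the end of
# `preceding_text` overlaps the start of `target_text`, only the non-overlapping
# suffix of the target is appended.
# >>> combine_with_overlap("find the best", "the best answer")
# ('find the best answer', True)
# >>> combine_with_overlap("hello ", "world")
# ('hello world', False)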

def check_type(obj, expected_type):
    if not isinstance(obj, list):
        return False
    return all(isinstance(item, expected_type) for item in obj)
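
# Quick illustration (not in the original file): the check requires a list whose
# elements are all of the expected type; any non-list input is rejected.
# >>> check_type(["a", "b"], str)
# True
# >>> check_type(("a", "b"), str)
# False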

def get_nonascii_toks(tokenizer, device="cpu"):
    def is_ascii(s):
        return s.isascii() and s.isprintable()

    nonascii_toks = []
    for i in range(tokenizer.vocab_size):
        if not is_ascii(tokenizer.decode([i])):
            nonascii_toks.append(i)
    if tokenizer.bos_token_id is not None:
        nonascii_toks.append(tokenizer.bos_token_id)
    if tokenizer.eos_token_id is not None:
        nonascii_toks.append(tokenizer.eos_token_id)
    if tokenizer.pad_token_id is not None:
        nonascii_toks.append(tokenizer.pad_token_id)
    if tokenizer.unk_token_id is not None:
        nonascii_toks.append(tokenizer.unk_token_id)
    return torch.tensor(nonascii_toks, device=device)
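
# Usage sketch (hypothetical variable names, not from this file): the returned id
# tensor is typically used to exclude non-ASCII and special tokens from candidate
# selection, e.g. by masking their logits or gradients.
# >>> bad_ids = get_nonascii_toks(tokenizer, device=logits.device)
# >>> logits[..., bad_ids] = float("-inf")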

def mellowmax(t: Tensor, alpha=1.0, dim=-1):
    return 1.0 / alpha * (
        torch.logsumexp(alpha * t, dim=dim)
        - torch.log(torch.tensor(t.shape[-1], dtype=t.dtype, device=t.device))
    )
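
# Sanity check (illustrative): mellowmax is a soft maximum that interpolates
# between the mean (alpha -> 0) and the hard max (alpha -> +inf) over `dim`.
# >>> x = torch.tensor([1.0, 2.0, 3.0])
# >>> mellowmax(x, alpha=0.01)    # ~= mean(x) = 2.0
# >>> mellowmax(x, alpha=100.0)   # ~= max(x) = 3.0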

# borrowed from https://github.com/huggingface/accelerate/blob/85a75d4c3d0deffde2fc8b917d9b1ae1cb580eb2/src/accelerate/utils/memory.py#L69
def should_reduce_batch_size(exception: Exception) -> bool:
    """
    Checks if `exception` relates to CUDA out-of-memory, CUDNN not supported, or CPU out-of-memory

    Args:
        exception (`Exception`):
            An exception
    """
    _statements = [
        "CUDA out of memory.",  # CUDA OOM
        "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.",  # CUDNN SNAFU
        "DefaultCPUAllocator: can't allocate memory",  # CPU OOM
    ]
    if isinstance(exception, RuntimeError) and len(exception.args) == 1:
        return any(err in exception.args[0] for err in _statements)
    return False
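
# Illustrative behaviour (matches the string checks above): only a RuntimeError
# carrying one of the known OOM messages triggers a batch-size reduction.
# >>> should_reduce_batch_size(RuntimeError("CUDA out of memory. Tried to allocate ..."))
# True
# >>> should_reduce_batch_size(ValueError("CUDA out of memory."))
# False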

# modified from https://github.com/huggingface/accelerate/blob/85a75d4c3d0deffde2fc8b917d9b1ae1cb580eb2/src/accelerate/utils/memory.py#L87
def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
    """
    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
    CUDNN, the batch size is cut in half and passed to `function`.

    `function` must take in a `batch_size` parameter as its first argument.

    Args:
        function (`callable`, *optional*):
            A function to wrap
        starting_batch_size (`int`, *optional*):
            The batch size to try and fit into memory

    Example:

    ```python
    >>> from utils import find_executable_batch_size

    >>> @find_executable_batch_size(starting_batch_size=128)
    ... def train(batch_size, model, optimizer):
    ...     ...

    >>> train(model, optimizer)
    ```
    """
    if function is None:
        return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)

    batch_size = starting_batch_size

    def decorator(*args, **kwargs):
        nonlocal batch_size
        gc.collect()
        torch.cuda.empty_cache()
        params = list(inspect.signature(function).parameters.keys())
        # Guard against user error
        if len(params) < (len(args) + 1):
            arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
            raise TypeError(
                f"Batch size was passed into `{function.__name__}` as the first argument when called. "
                f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
            )
        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, *args, **kwargs)
            except Exception as e:
                if should_reduce_batch_size(e):
                    gc.collect()
                    torch.cuda.empty_cache()
                    batch_size //= 2
                    print(f"Decreasing batch size to: {batch_size}")
                else:
                    raise

    return decorator
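
# Non-decorator usage sketch (hypothetical names `compute_logits`, `input_ids`):
# the helper can also wrap an existing callable directly; the wrapper supplies
# `batch_size` as the first argument itself and halves it on OOM until the call
# succeeds or the batch size reaches zero.
# >>> run = find_executable_batch_size(compute_logits, starting_batch_size=256)
# >>> logits = run(input_ids)   # internally calls compute_logits(batch_size, input_ids)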