import functools
import gc
import inspect

import torch
from torch import Tensor

# Character pool (punctuation, symbols, and a few letters) used to initialize
# candidate strings.
INIT_CHARS = [
    ".", ",", "!", "?", ";", ":", "(", ")", "[", "]", "{", "}",
    "@", "#", "$", "%", "&", "*",
    "w", "x", "y", "z",
]

def combine_with_overlap(preceding_text, target_text):
    """Concatenate two strings, merging the longest overlap between the end of
    `preceding_text` and the start of `target_text`.

    Returns the combined string and a flag indicating whether an overlap was found.
    """
    # Scan suffixes of preceding_text from longest to shortest; the first match
    # is therefore the longest overlap.
    for i in range(len(preceding_text)):
        if target_text.startswith(preceding_text[i:]):
            # Drop the overlapping prefix from target_text before appending.
            combined_text = preceding_text + target_text[len(preceding_text[i:]):]
            return combined_text, True  # Overlap found
    return preceding_text + target_text, False  # No overlap found
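
# Illustrative examples (not part of the original module): the merge avoids
# duplicating text shared at the seam.
#   combine_with_overlap("The answer is", "is 42")  -> ("The answer is 42", True)
#   combine_with_overlap("Hello", "world")          -> ("Helloworld", False)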

def check_type(obj, expected_type):
    """Return True if `obj` is a list whose items are all of `expected_type`."""
    if not isinstance(obj, list):
        return False
    return all(isinstance(item, expected_type) for item in obj)
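
# Illustrative examples:
#   check_type(["a", "b"], str)  -> True
#   check_type(["a", 1], str)    -> False
#   check_type("ab", str)        -> False (the object itself must be a list)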

def get_nonascii_toks(tokenizer, device="cpu"):
    """Return a tensor of token ids that decode to non-ASCII or non-printable
    strings, together with the tokenizer's special tokens (BOS/EOS/PAD/UNK)."""

    def is_ascii(s):
        return s.isascii() and s.isprintable()

    nonascii_toks = []
    for i in range(tokenizer.vocab_size):
        if not is_ascii(tokenizer.decode([i])):
            nonascii_toks.append(i)

    # Special tokens are excluded as well, even when they decode to ASCII.
    if tokenizer.bos_token_id is not None:
        nonascii_toks.append(tokenizer.bos_token_id)
    if tokenizer.eos_token_id is not None:
        nonascii_toks.append(tokenizer.eos_token_id)
    if tokenizer.pad_token_id is not None:
        nonascii_toks.append(tokenizer.pad_token_id)
    if tokenizer.unk_token_id is not None:
        nonascii_toks.append(tokenizer.unk_token_id)

    return torch.tensor(nonascii_toks, device=device)
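
# Illustrative usage (assumes a Hugging Face tokenizer such as "gpt2"):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   banned_ids = get_nonascii_toks(tok, device="cuda")
# The returned ids can then be masked out when sampling candidate token
# replacements, so that optimized strings stay printable ASCII.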

def mellowmax(t: Tensor, alpha=1.0, dim=-1):
    """Mellowmax: (1 / alpha) * (logsumexp(alpha * t, dim) - log(n)), a smooth,
    differentiable alternative to max along `dim`."""
    # Normalize by the size of the reduced dimension rather than a hard-coded -1.
    n = torch.tensor(t.shape[dim], dtype=t.dtype, device=t.device)
    return 1.0 / alpha * (torch.logsumexp(alpha * t, dim=dim) - torch.log(n))
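
# Sanity check (illustrative): mellowmax interpolates between mean and max.
#   t = torch.tensor([1.0, 2.0, 3.0])
#   mellowmax(t, alpha=0.01)  -> ~2.01 (close to the mean, 2.0)
#   mellowmax(t, alpha=100.0) -> ~2.99 (close to the max, 3.0)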

# borrowed from https://github.com/huggingface/accelerate/blob/85a75d4c3d0deffde2fc8b917d9b1ae1cb580eb2/src/accelerate/utils/memory.py#L69
def should_reduce_batch_size(exception: Exception) -> bool:
    """
    Checks if `exception` relates to CUDA out-of-memory, CUDNN not supported, or CPU out-of-memory

    Args:
        exception (`Exception`):
            An exception
    """
    _statements = [
        "CUDA out of memory.",  # CUDA OOM
        "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.",  # CUDNN SNAFU
        "DefaultCPUAllocator: can't allocate memory",  # CPU OOM
    ]
    if isinstance(exception, RuntimeError) and len(exception.args) == 1:
        return any(err in exception.args[0] for err in _statements)
    return False
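
# Illustrative examples:
#   should_reduce_batch_size(RuntimeError("CUDA out of memory. Tried to allocate 2 GiB"))  -> True
#   should_reduce_batch_size(ValueError("CUDA out of memory."))  -> False (not a RuntimeError)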

# modified from https://github.com/huggingface/accelerate/blob/85a75d4c3d0deffde2fc8b917d9b1ae1cb580eb2/src/accelerate/utils/memory.py#L87
def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
    """
    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
    CUDNN, the batch size is cut in half and passed to `function`.

    `function` must take in a `batch_size` parameter as its first argument.

    Args:
        function (`callable`, *optional*):
            A function to wrap
        starting_batch_size (`int`, *optional*):
            The batch size to try and fit into memory

    Example:

    ```python
    >>> from utils import find_executable_batch_size


    >>> @find_executable_batch_size(starting_batch_size=128)
    ... def train(batch_size, model, optimizer):
    ...     ...


    >>> train(model, optimizer)
    ```
    """
    if function is None:
        return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)

    batch_size = starting_batch_size

    def decorator(*args, **kwargs):
        nonlocal batch_size
        gc.collect()
        torch.cuda.empty_cache()
        params = list(inspect.signature(function).parameters.keys())
        # Guard against user error: `batch_size` is injected by the decorator.
        if len(params) < (len(args) + 1):
            arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
            raise TypeError(
                f"Batch size was passed into `{function.__name__}` as the first argument when called. "
                f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
            )
        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, *args, **kwargs)
            except Exception as e:
                if should_reduce_batch_size(e):
                    # Free cached memory before retrying at half the batch size.
                    gc.collect()
                    torch.cuda.empty_cache()
                    batch_size //= 2
                    print(f"Decreasing batch size to: {batch_size}")
                else:
                    raise

    return decorator
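
# Illustrative demo (hypothetical `run` function; simulates OOM to show halving):
#   @find_executable_batch_size(starting_batch_size=64)
#   def run(batch_size, data):
#       if batch_size > 16:
#           raise RuntimeError("CUDA out of memory.")
#       return batch_size
#
#   run(data)  # fails at 64 and 32, succeeds at 16 -> returns 16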