import functools
import gc
import inspect

import torch
from torch import Tensor

INIT_CHARS = [
    ".", ",", "!", "?", ";", ":",
    "(", ")", "[", "]", "{", "}",
    "@", "#", "$", "%", "&", "*",
    "w", "x", "y", "z",
]


def combine_with_overlap(preceding_text, target_text):
    """Concatenate two strings, merging any overlap between the end of `preceding_text` and the start of `target_text`."""
    # Find the longest overlap between the end of preceding_text and the start of target_text
    for i in range(len(preceding_text)):
        if target_text.startswith(preceding_text[i:]):
            combined_text = preceding_text + target_text[len(preceding_text[i:]):]
            return combined_text, True  # Overlap found
    return preceding_text + target_text, False  # No overlap found


def check_type(obj, expected_type):
    """Return True if `obj` is a list whose items are all instances of `expected_type`."""
    if not isinstance(obj, list):
        return False
    return all(isinstance(item, expected_type) for item in obj)


def get_nonascii_toks(tokenizer, device="cpu"):
    """Return a tensor of token ids that decode to non-ASCII or non-printable text, plus any special tokens."""
    def is_ascii(s):
        return s.isascii() and s.isprintable()

    nonascii_toks = []
    for i in range(tokenizer.vocab_size):
        if not is_ascii(tokenizer.decode([i])):
            nonascii_toks.append(i)

    if tokenizer.bos_token_id is not None:
        nonascii_toks.append(tokenizer.bos_token_id)
    if tokenizer.eos_token_id is not None:
        nonascii_toks.append(tokenizer.eos_token_id)
    if tokenizer.pad_token_id is not None:
        nonascii_toks.append(tokenizer.pad_token_id)
    if tokenizer.unk_token_id is not None:
        nonascii_toks.append(tokenizer.unk_token_id)

    return torch.tensor(nonascii_toks, device=device)


def mellowmax(t: Tensor, alpha=1.0, dim=-1):
    """Mellowmax operator: a smooth, differentiable alternative to max over `dim`, controlled by `alpha`."""
    return 1.0 / alpha * (
        torch.logsumexp(alpha * t, dim=dim)
        - torch.log(torch.tensor(t.shape[dim], dtype=t.dtype, device=t.device))
    )


# borrowed from https://github.com/huggingface/accelerate/blob/85a75d4c3d0deffde2fc8b917d9b1ae1cb580eb2/src/accelerate/utils/memory.py#L69
def should_reduce_batch_size(exception: Exception) -> bool:
    """
    Checks if `exception` relates to CUDA out-of-memory, CUDNN not supported, or CPU out-of-memory

    Args:
        exception (`Exception`):
            An exception
    """
    _statements = [
        "CUDA out of memory.",  # CUDA OOM
        "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.",  # CUDNN SNAFU
        "DefaultCPUAllocator: can't allocate memory",  # CPU OOM
    ]
    if isinstance(exception, RuntimeError) and len(exception.args) == 1:
        return any(err in exception.args[0] for err in _statements)
    return False


# modified from https://github.com/huggingface/accelerate/blob/85a75d4c3d0deffde2fc8b917d9b1ae1cb580eb2/src/accelerate/utils/memory.py#L87
def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
    """
    A basic decorator that will try to execute `function`. If it fails from exceptions related to
    out-of-memory or CUDNN, the batch size is cut in half and passed to `function`.

    `function` must take in a `batch_size` parameter as its first argument.

    Args:
        function (`callable`, *optional*):
            A function to wrap
        starting_batch_size (`int`, *optional*):
            The batch size to try and fit into memory

    Example:

    ```python
    >>> from utils import find_executable_batch_size


    >>> @find_executable_batch_size(starting_batch_size=128)
    ... def train(batch_size, model, optimizer):
    ...     ...
    >>> train(model, optimizer)
    ```
    """
    if function is None:
        return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)

    batch_size = starting_batch_size

    def decorator(*args, **kwargs):
        nonlocal batch_size
        gc.collect()
        torch.cuda.empty_cache()
        params = list(inspect.signature(function).parameters.keys())
        # Guard against user error: `batch_size` is supplied by the decorator, not the caller
        if len(params) < (len(args) + 1):
            arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
            raise TypeError(
                f"Batch size was passed into `{function.__name__}` as the first argument when called. "
                f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
            )
        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, *args, **kwargs)
            except Exception as e:
                if should_reduce_batch_size(e):
                    # Free cached memory, then retry with half the batch size
                    gc.collect()
                    torch.cuda.empty_cache()
                    batch_size //= 2
                    print(f"Decreasing batch size to: {batch_size}")
                else:
                    raise

    return decorator
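

# Illustrative usage sketch (an assumption, not part of the original module): exercises
# `mellowmax`, `combine_with_overlap`, and `find_executable_batch_size` on toy inputs.
# `dummy_forward` and the literal values below are hypothetical placeholders.
if __name__ == "__main__":
    # mellowmax: smoothly interpolates between the mean (alpha -> 0) and the max (alpha -> inf)
    scores = torch.randn(4, 10)
    print(mellowmax(scores, alpha=2.0, dim=-1).shape)  # torch.Size([4])

    # combine_with_overlap: concatenates while de-duplicating a shared boundary
    print(combine_with_overlap("Sure, here", "here is the answer"))  # ('Sure, here is the answer', True)

    # find_executable_batch_size: retries with half the batch size on OOM-style failures
    @find_executable_batch_size(starting_batch_size=8)
    def dummy_forward(batch_size, dim):
        # Hypothetical workload; a real caller would run a forward pass / loss here
        return torch.randn(batch_size, dim).sum()

    print(dummy_forward(512))  # 512 binds to `dim`; `batch_size` is injected by the decorator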