""" Text Processing Utilities for UDora Target Positioning This module contains text processing functions used for analyzing generated text and identifying optimal positions for target insertion. """ from typing import List, Dict, Any, Tuple import torch from .utils import combine_with_overlap def build_target_intervals(generated_ids: List[int], targets: List[str], tokenizer, probs_list: List[torch.Tensor], add_space_before_target: bool = False, before_negative: bool = False) -> List[Dict[str, Any]]: """ Build intervals for potential target insertion positions. This function analyzes the generated token sequence to find all possible positions where target strings could be inserted, computing scores based on token probabilities and matching quality. Args: generated_ids: List of generated token IDs targets: List of target strings to consider tokenizer: Model tokenizer for text processing probs_list: List of probability tensors for each generated token add_space_before_target: Whether to add space before target text before_negative: Stop interval collection when encountering negative responses Returns: List of interval dictionaries with keys: - 'start': Start position of interval - 'end': End position of interval - 'score': Quality score for this interval - 'target_ids': Token IDs for the target text """ intervals = [] num_generated_tokens = len(generated_ids) for i in range(num_generated_tokens): # Check for negative response words if before_negative is enabled if before_negative: try: # Look ahead at the next 2 tokens to check for negative words check_negative = tokenizer.decode(generated_ids[i:i+2]).strip().split()[0] if check_negative in ['cannot', "can't"]: break # Stop collecting intervals except: pass # Continue if decoding fails for target_text in targets: # Get preceding context preceding_ids = generated_ids[:i] preceding_text = tokenizer.decode(preceding_ids, skip_special_tokens=False) next_token = tokenizer.convert_ids_to_tokens(generated_ids[i]) # Handle space insertion logic # huggingface uses 'Ġ' to represent ' ' before the token if add_space_before_target and (next_token.startswith('Ġ') or next_token.startswith(' ')): combined_text, overlap = combine_with_overlap(preceding_text, ' ' + target_text) else: combined_text, overlap = combine_with_overlap(preceding_text, target_text) combined_ids = tokenizer.encode(combined_text, add_special_tokens=False) # Calculate position adjustments if overlap: differences = 1 else: differences = sum(1 for x, y in zip(combined_ids[:i], preceding_ids) if x != y) target_ids_in_context = combined_ids[i - differences:] target_length = len(target_ids_in_context) # Compute matching score score_info = _compute_interval_score( target_ids_in_context, generated_ids, probs_list, i, differences, num_generated_tokens ) if score_info is None: continue current_score, current_num_matched = score_info # Create interval start_pos = i - differences end_pos = start_pos + target_length intervals.append({ 'start': start_pos, 'end': end_pos, 'score': current_score, 'target_ids': target_ids_in_context, 'num_matched': current_num_matched, 'target_text': target_text }) return intervals def _compute_interval_score(target_ids_in_context: List[int], generated_ids: List[int], probs_list: List[torch.Tensor], start_idx: int, differences: int, num_generated_tokens: int) -> Tuple[float, int]: """ Compute quality score for a target interval. 


def _compute_interval_score(target_ids_in_context: List[int],
                            generated_ids: List[int],
                            probs_list: List[torch.Tensor],
                            start_idx: int,
                            differences: int,
                            num_generated_tokens: int) -> Optional[Tuple[float, int]]:
    """
    Compute quality score for a target interval.

    Args:
        target_ids_in_context: Target token IDs in the current context
        generated_ids: Full list of generated token IDs
        probs_list: List of probability tensors
        start_idx: Starting index for evaluation
        differences: Position adjustment offset
        num_generated_tokens: Total number of generated tokens

    Returns:
        Tuple of (score, num_matched_tokens), or None if no token could be
        evaluated
    """
    target_length = len(target_ids_in_context)
    current_num_matched = 0
    current_prob = []

    # Evaluate each token in the target sequence
    for j in range(min(target_length, num_generated_tokens + differences - start_idx)):
        target_id = target_ids_in_context[j]
        prob_idx = start_idx + j - differences

        if prob_idx < 0 or prob_idx >= len(probs_list):
            break

        current_prob.append(probs_list[prob_idx][target_id].item())
        current_num_matched += 1

        # Stop at the first position where the greedy prediction diverges
        # from the target; that token does not count as matched.
        if probs_list[prob_idx].argmax().item() != target_id:
            current_num_matched -= 1
            break

    if len(current_prob) == 0:
        return None

    # Final score: matched-token count plus average target-token probability,
    # normalized by (target_length + 1) so a fully matched target scores
    # close to 1.0.
    avg_prob = sum(current_prob) / len(current_prob)
    score = (current_num_matched + avg_prob) / (target_length + 1)

    return score, current_num_matched


def count_matched_locations(intervals: List[Dict[str, Any]],
                            success_threshold: Optional[float] = None) -> int:
    """
    Count intervals that meet the success threshold.

    Args:
        intervals: List of interval dictionaries
        success_threshold: Minimum score threshold (derived per interval from
            the target length if None)

    Returns:
        Number of intervals meeting the threshold
    """
    if not intervals:
        return 0

    matched_count = 0
    for interval in intervals:
        target_length = len(interval['target_ids'])
        # The default threshold target_length / (target_length + 1) is met
        # whenever every target token is matched greedily, regardless of the
        # probability term.
        threshold = (success_threshold if success_threshold is not None
                     else target_length / (target_length + 1))
        if interval['score'] >= threshold:
            matched_count += 1

    return matched_count


def format_interval_debug_info(intervals: List[Dict[str, Any]],
                               tokenizer,
                               max_intervals: int = 10) -> str:
    """
    Format interval information for debugging output.

    Args:
        intervals: List of interval dictionaries
        tokenizer: Model tokenizer for decoding
        max_intervals: Maximum number of intervals to include

    Returns:
        Formatted debug string
    """
    if not intervals:
        return "No intervals found"

    lines = [f"Found {len(intervals)} potential intervals:"]

    # Sort by score (descending) and take the top intervals
    sorted_intervals = sorted(intervals, key=lambda x: x['score'], reverse=True)
    top_intervals = sorted_intervals[:max_intervals]

    for i, interval in enumerate(top_intervals):
        target_text = tokenizer.decode(interval['target_ids'])
        lines.append(
            f"  {i + 1}. [{interval['start']}:{interval['end']}] "
            f"score={interval['score']:.3f} "
            f"text='{target_text}'"
        )

    if len(intervals) > max_intervals:
        lines.append(f"  ... and {len(intervals) - max_intervals} more")

    return "\n".join(lines)
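

# Follow-up sketch (hypothetical, continuing the usage example after
# build_target_intervals above): summarize how many targets already match
# and emit a human-readable report of the best candidates.
#
#     n_hits = count_matched_locations(intervals)  # per-target default threshold
#     print(format_interval_debug_info(intervals, tokenizer, max_intervals=5))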