Files
UDora/udora/text_processing.py
javyduck 0d259fc3aa init_v1
2025-06-24 01:50:12 +00:00

222 lines
7.9 KiB
Python

"""
Text Processing Utilities for UDora Target Positioning
This module contains text processing functions used for analyzing generated
text and identifying optimal positions for target insertion.
"""
from typing import List, Dict, Any, Tuple
import torch
from .utils import combine_with_overlap
def build_target_intervals(generated_ids: List[int],
targets: List[str],
tokenizer,
probs_list: List[torch.Tensor],
add_space_before_target: bool = False,
before_negative: bool = False) -> List[Dict[str, Any]]:
"""
Build intervals for potential target insertion positions.
This function analyzes the generated token sequence to find all possible
positions where target strings could be inserted, computing scores based
on token probabilities and matching quality.
Args:
generated_ids: List of generated token IDs
targets: List of target strings to consider
tokenizer: Model tokenizer for text processing
probs_list: List of probability tensors for each generated token
add_space_before_target: Whether to add space before target text
before_negative: Stop interval collection when encountering negative responses
Returns:
List of interval dictionaries with keys:
- 'start': Start position of interval
- 'end': End position of interval
- 'score': Quality score for this interval
- 'target_ids': Token IDs for the target text
"""
intervals = []
num_generated_tokens = len(generated_ids)
for i in range(num_generated_tokens):
# Check for negative response words if before_negative is enabled
if before_negative:
try:
# Look ahead at the next 2 tokens to check for negative words
check_negative = tokenizer.decode(generated_ids[i:i+2]).strip().split()[0]
if check_negative in ['cannot', "can't"]:
break # Stop collecting intervals
except:
pass # Continue if decoding fails
for target_text in targets:
# Get preceding context
preceding_ids = generated_ids[:i]
preceding_text = tokenizer.decode(preceding_ids, skip_special_tokens=False)
next_token = tokenizer.convert_ids_to_tokens(generated_ids[i])
# Handle space insertion logic
# huggingface uses 'Ġ' to represent ' ' before the token
if add_space_before_target and (next_token.startswith('Ġ') or next_token.startswith(' ')):
combined_text, overlap = combine_with_overlap(preceding_text, ' ' + target_text)
else:
combined_text, overlap = combine_with_overlap(preceding_text, target_text)
combined_ids = tokenizer.encode(combined_text, add_special_tokens=False)
# Calculate position adjustments
if overlap:
differences = 1
else:
differences = sum(1 for x, y in zip(combined_ids[:i], preceding_ids) if x != y)
target_ids_in_context = combined_ids[i - differences:]
target_length = len(target_ids_in_context)
# Compute matching score
score_info = _compute_interval_score(
target_ids_in_context,
generated_ids,
probs_list,
i,
differences,
num_generated_tokens
)
if score_info is None:
continue
current_score, current_num_matched = score_info
# Create interval
start_pos = i - differences
end_pos = start_pos + target_length
intervals.append({
'start': start_pos,
'end': end_pos,
'score': current_score,
'target_ids': target_ids_in_context,
'num_matched': current_num_matched,
'target_text': target_text
})
return intervals
def _compute_interval_score(target_ids_in_context: List[int],
generated_ids: List[int],
probs_list: List[torch.Tensor],
start_idx: int,
differences: int,
num_generated_tokens: int) -> Tuple[float, int]:
"""
Compute quality score for a target interval.
Args:
target_ids_in_context: Target token IDs in the current context
generated_ids: Full list of generated token IDs
probs_list: List of probability tensors
start_idx: Starting index for evaluation
differences: Position adjustment offset
num_generated_tokens: Total number of generated tokens
Returns:
Tuple of (score, num_matched_tokens) or None if invalid
"""
target_length = len(target_ids_in_context)
current_num_matched = 0
current_prob = []
# Evaluate each token in the target sequence
for j in range(min(target_length, num_generated_tokens + differences - start_idx)):
target_id = target_ids_in_context[j]
prob_idx = start_idx + j - differences
if prob_idx < 0 or prob_idx >= len(probs_list):
break
current_prob.append(probs_list[prob_idx][target_id].item())
current_num_matched += 1
# Check if prediction matches target
if probs_list[prob_idx].argmax().item() != target_id:
current_num_matched -= 1
break
if len(current_prob) == 0:
return None
# Compute final score
avg_prob = sum(current_prob) / len(current_prob)
score = (current_num_matched + avg_prob) / (target_length + 1)
return score, current_num_matched
def count_matched_locations(intervals: List[Dict[str, Any]],
success_threshold: float = None) -> int:
"""
Count intervals that meet the success threshold.
Args:
intervals: List of interval dictionaries
success_threshold: Minimum score threshold (computed if None)
Returns:
Number of intervals meeting the threshold
"""
if not intervals:
return 0
matched_count = 0
for interval in intervals:
target_length = len(interval['target_ids'])
threshold = success_threshold or (target_length / (target_length + 1))
if interval['score'] >= threshold:
matched_count += 1
return matched_count
def format_interval_debug_info(intervals: List[Dict[str, Any]],
tokenizer,
max_intervals: int = 10) -> str:
"""
Format interval information for debugging output.
Args:
intervals: List of interval dictionaries
tokenizer: Model tokenizer for decoding
max_intervals: Maximum number of intervals to include
Returns:
Formatted debug string
"""
if not intervals:
return "No intervals found"
lines = [f"Found {len(intervals)} potential intervals:"]
# Sort by score (descending) and take top intervals
sorted_intervals = sorted(intervals, key=lambda x: x['score'], reverse=True)
top_intervals = sorted_intervals[:max_intervals]
for i, interval in enumerate(top_intervals):
target_text = tokenizer.decode(interval['target_ids'])
lines.append(
f" {i+1}. [{interval['start']}:{interval['end']}] "
f"score={interval['score']:.3f} "
f"text='{target_text}'"
)
if len(intervals) > max_intervals:
lines.append(f" ... and {len(intervals) - max_intervals} more")
return "\n".join(lines)