mirror of
https://github.com/AI-secure/UDora.git
synced 2026-02-13 05:13:11 +00:00
222 lines
7.9 KiB
Python
222 lines
7.9 KiB
Python
"""
|
|
Text Processing Utilities for UDora Target Positioning
|
|
|
|
This module contains text processing functions used for analyzing generated
|
|
text and identifying optimal positions for target insertion.
|
|
"""
|
|
|
|
from typing import List, Dict, Any, Tuple
|
|
import torch
|
|
from .utils import combine_with_overlap
|
|
|
|
|
|
def build_target_intervals(generated_ids: List[int],
|
|
targets: List[str],
|
|
tokenizer,
|
|
probs_list: List[torch.Tensor],
|
|
add_space_before_target: bool = False,
|
|
before_negative: bool = False) -> List[Dict[str, Any]]:
|
|
"""
|
|
Build intervals for potential target insertion positions.
|
|
|
|
This function analyzes the generated token sequence to find all possible
|
|
positions where target strings could be inserted, computing scores based
|
|
on token probabilities and matching quality.
|
|
|
|
Args:
|
|
generated_ids: List of generated token IDs
|
|
targets: List of target strings to consider
|
|
tokenizer: Model tokenizer for text processing
|
|
probs_list: List of probability tensors for each generated token
|
|
add_space_before_target: Whether to add space before target text
|
|
before_negative: Stop interval collection when encountering negative responses
|
|
|
|
Returns:
|
|
List of interval dictionaries with keys:
|
|
- 'start': Start position of interval
|
|
- 'end': End position of interval
|
|
- 'score': Quality score for this interval
|
|
- 'target_ids': Token IDs for the target text
|
|
"""
|
|
intervals = []
|
|
num_generated_tokens = len(generated_ids)
|
|
|
|
for i in range(num_generated_tokens):
|
|
# Check for negative response words if before_negative is enabled
|
|
if before_negative:
|
|
try:
|
|
# Look ahead at the next 2 tokens to check for negative words
|
|
check_negative = tokenizer.decode(generated_ids[i:i+2]).strip().split()[0]
|
|
if check_negative in ['cannot', "can't"]:
|
|
break # Stop collecting intervals
|
|
except:
|
|
pass # Continue if decoding fails
|
|
|
|
for target_text in targets:
|
|
# Get preceding context
|
|
preceding_ids = generated_ids[:i]
|
|
preceding_text = tokenizer.decode(preceding_ids, skip_special_tokens=False)
|
|
|
|
next_token = tokenizer.convert_ids_to_tokens(generated_ids[i])
|
|
|
|
# Handle space insertion logic
|
|
# huggingface uses 'Ġ' to represent ' ' before the token
|
|
if add_space_before_target and (next_token.startswith('Ġ') or next_token.startswith(' ')):
|
|
combined_text, overlap = combine_with_overlap(preceding_text, ' ' + target_text)
|
|
else:
|
|
combined_text, overlap = combine_with_overlap(preceding_text, target_text)
|
|
|
|
combined_ids = tokenizer.encode(combined_text, add_special_tokens=False)
|
|
|
|
# Calculate position adjustments
|
|
if overlap:
|
|
differences = 1
|
|
else:
|
|
differences = sum(1 for x, y in zip(combined_ids[:i], preceding_ids) if x != y)
|
|
|
|
target_ids_in_context = combined_ids[i - differences:]
|
|
target_length = len(target_ids_in_context)
|
|
|
|
# Compute matching score
|
|
score_info = _compute_interval_score(
|
|
target_ids_in_context,
|
|
generated_ids,
|
|
probs_list,
|
|
i,
|
|
differences,
|
|
num_generated_tokens
|
|
)
|
|
|
|
if score_info is None:
|
|
continue
|
|
|
|
current_score, current_num_matched = score_info
|
|
|
|
# Create interval
|
|
start_pos = i - differences
|
|
end_pos = start_pos + target_length
|
|
|
|
intervals.append({
|
|
'start': start_pos,
|
|
'end': end_pos,
|
|
'score': current_score,
|
|
'target_ids': target_ids_in_context,
|
|
'num_matched': current_num_matched,
|
|
'target_text': target_text
|
|
})
|
|
|
|
return intervals
|
|
|
|
|
|
def _compute_interval_score(target_ids_in_context: List[int],
|
|
generated_ids: List[int],
|
|
probs_list: List[torch.Tensor],
|
|
start_idx: int,
|
|
differences: int,
|
|
num_generated_tokens: int) -> Tuple[float, int]:
|
|
"""
|
|
Compute quality score for a target interval.
|
|
|
|
Args:
|
|
target_ids_in_context: Target token IDs in the current context
|
|
generated_ids: Full list of generated token IDs
|
|
probs_list: List of probability tensors
|
|
start_idx: Starting index for evaluation
|
|
differences: Position adjustment offset
|
|
num_generated_tokens: Total number of generated tokens
|
|
|
|
Returns:
|
|
Tuple of (score, num_matched_tokens) or None if invalid
|
|
"""
|
|
target_length = len(target_ids_in_context)
|
|
current_num_matched = 0
|
|
current_prob = []
|
|
|
|
# Evaluate each token in the target sequence
|
|
for j in range(min(target_length, num_generated_tokens + differences - start_idx)):
|
|
target_id = target_ids_in_context[j]
|
|
prob_idx = start_idx + j - differences
|
|
|
|
if prob_idx < 0 or prob_idx >= len(probs_list):
|
|
break
|
|
|
|
current_prob.append(probs_list[prob_idx][target_id].item())
|
|
current_num_matched += 1
|
|
|
|
# Check if prediction matches target
|
|
if probs_list[prob_idx].argmax().item() != target_id:
|
|
current_num_matched -= 1
|
|
break
|
|
|
|
if len(current_prob) == 0:
|
|
return None
|
|
|
|
# Compute final score
|
|
avg_prob = sum(current_prob) / len(current_prob)
|
|
score = (current_num_matched + avg_prob) / (target_length + 1)
|
|
|
|
return score, current_num_matched
|
|
|
|
|
|
def count_matched_locations(intervals: List[Dict[str, Any]],
|
|
success_threshold: float = None) -> int:
|
|
"""
|
|
Count intervals that meet the success threshold.
|
|
|
|
Args:
|
|
intervals: List of interval dictionaries
|
|
success_threshold: Minimum score threshold (computed if None)
|
|
|
|
Returns:
|
|
Number of intervals meeting the threshold
|
|
"""
|
|
if not intervals:
|
|
return 0
|
|
|
|
matched_count = 0
|
|
|
|
for interval in intervals:
|
|
target_length = len(interval['target_ids'])
|
|
threshold = success_threshold or (target_length / (target_length + 1))
|
|
|
|
if interval['score'] >= threshold:
|
|
matched_count += 1
|
|
|
|
return matched_count
|
|
|
|
|
|
def format_interval_debug_info(intervals: List[Dict[str, Any]],
|
|
tokenizer,
|
|
max_intervals: int = 10) -> str:
|
|
"""
|
|
Format interval information for debugging output.
|
|
|
|
Args:
|
|
intervals: List of interval dictionaries
|
|
tokenizer: Model tokenizer for decoding
|
|
max_intervals: Maximum number of intervals to include
|
|
|
|
Returns:
|
|
Formatted debug string
|
|
"""
|
|
if not intervals:
|
|
return "No intervals found"
|
|
|
|
lines = [f"Found {len(intervals)} potential intervals:"]
|
|
|
|
# Sort by score (descending) and take top intervals
|
|
sorted_intervals = sorted(intervals, key=lambda x: x['score'], reverse=True)
|
|
top_intervals = sorted_intervals[:max_intervals]
|
|
|
|
for i, interval in enumerate(top_intervals):
|
|
target_text = tokenizer.decode(interval['target_ids'])
|
|
lines.append(
|
|
f" {i+1}. [{interval['start']}:{interval['end']}] "
|
|
f"score={interval['score']:.3f} "
|
|
f"text='{target_text}'"
|
|
)
|
|
|
|
if len(intervals) > max_intervals:
|
|
lines.append(f" ... and {len(intervals) - max_intervals} more")
|
|
|
|
return "\n".join(lines) |