mirror of
https://github.com/AI-secure/UDora.git
synced 2026-05-08 18:55:07 +02:00
init_v1
This commit is contained in:
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Text Processing Utilities for UDora Target Positioning
|
||||
|
||||
This module contains text processing functions used for analyzing generated
|
||||
text and identifying optimal positions for target insertion.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple

import torch

from .utils import combine_with_overlap
|
||||
|
||||
|
||||
def build_target_intervals(generated_ids: List[int],
                           targets: List[str],
                           tokenizer,
                           probs_list: List[torch.Tensor],
                           add_space_before_target: bool = False,
                           before_negative: bool = False) -> List[Dict[str, Any]]:
    """
    Build intervals for potential target insertion positions.

    Scans every position of the generated token sequence and, for each
    target string, re-tokenizes the preceding context combined with the
    target to locate where the target's tokens would sit.  Each candidate
    interval is scored via ``_compute_interval_score`` based on how well
    the model's token probabilities already agree with the target tokens.

    Args:
        generated_ids: List of generated token IDs.
        targets: List of target strings to consider.
        tokenizer: Model tokenizer for text processing (must provide
            ``decode``, ``encode`` and ``convert_ids_to_tokens``).
        probs_list: List of probability tensors, one per generated token.
        add_space_before_target: Whether to add a space before the target
            text when the token at the insertion point starts a new word.
        before_negative: Stop interval collection when encountering
            negative responses ('cannot' / "can't").

    Returns:
        List of interval dictionaries with keys:
            - 'start': Start position of interval
            - 'end': End position of interval (exclusive)
            - 'score': Quality score for this interval
            - 'target_ids': Token IDs for the target text
            - 'num_matched': Number of target tokens the model already predicts
            - 'target_text': The originating target string
    """
    intervals = []
    num_generated_tokens = len(generated_ids)

    for i in range(num_generated_tokens):
        # Check for negative response words if before_negative is enabled.
        if before_negative:
            lookahead_words = []
            try:
                # Look ahead at the next 2 tokens to check for negative words.
                lookahead_words = tokenizer.decode(generated_ids[i:i + 2]).strip().split()
            except Exception:
                # Best-effort lookahead: keep scanning if decoding fails.
                # (Was a bare `except:`; narrowed so e.g. KeyboardInterrupt
                # is no longer swallowed.)
                pass
            if lookahead_words and lookahead_words[0] in ('cannot', "can't"):
                break  # Stop collecting intervals

        # The preceding context and next token depend only on `i`, not on
        # the target, so compute them once per position (hoisted out of
        # the inner loop).
        preceding_ids = generated_ids[:i]
        preceding_text = tokenizer.decode(preceding_ids, skip_special_tokens=False)
        next_token = tokenizer.convert_ids_to_tokens(generated_ids[i])

        for target_text in targets:
            # Handle space insertion logic.
            # huggingface uses 'Ġ' to represent ' ' before the token
            if add_space_before_target and (next_token.startswith('Ġ') or next_token.startswith(' ')):
                combined_text, overlap = combine_with_overlap(preceding_text, ' ' + target_text)
            else:
                combined_text, overlap = combine_with_overlap(preceding_text, target_text)

            combined_ids = tokenizer.encode(combined_text, add_special_tokens=False)

            # Calculate position adjustments caused by re-tokenization at
            # the boundary between context and target.
            if overlap:
                differences = 1
            else:
                differences = sum(1 for x, y in zip(combined_ids[:i], preceding_ids) if x != y)

            target_ids_in_context = combined_ids[i - differences:]
            target_length = len(target_ids_in_context)

            # Compute matching score for this candidate placement.
            score_info = _compute_interval_score(
                target_ids_in_context,
                generated_ids,
                probs_list,
                i,
                differences,
                num_generated_tokens
            )

            if score_info is None:
                continue

            current_score, current_num_matched = score_info

            # Create interval
            start_pos = i - differences
            end_pos = start_pos + target_length

            intervals.append({
                'start': start_pos,
                'end': end_pos,
                'score': current_score,
                'target_ids': target_ids_in_context,
                'num_matched': current_num_matched,
                'target_text': target_text
            })

    return intervals


def _compute_interval_score(target_ids_in_context: List[int],
|
||||
generated_ids: List[int],
|
||||
probs_list: List[torch.Tensor],
|
||||
start_idx: int,
|
||||
differences: int,
|
||||
num_generated_tokens: int) -> Tuple[float, int]:
|
||||
"""
|
||||
Compute quality score for a target interval.
|
||||
|
||||
Args:
|
||||
target_ids_in_context: Target token IDs in the current context
|
||||
generated_ids: Full list of generated token IDs
|
||||
probs_list: List of probability tensors
|
||||
start_idx: Starting index for evaluation
|
||||
differences: Position adjustment offset
|
||||
num_generated_tokens: Total number of generated tokens
|
||||
|
||||
Returns:
|
||||
Tuple of (score, num_matched_tokens) or None if invalid
|
||||
"""
|
||||
target_length = len(target_ids_in_context)
|
||||
current_num_matched = 0
|
||||
current_prob = []
|
||||
|
||||
# Evaluate each token in the target sequence
|
||||
for j in range(min(target_length, num_generated_tokens + differences - start_idx)):
|
||||
target_id = target_ids_in_context[j]
|
||||
prob_idx = start_idx + j - differences
|
||||
|
||||
if prob_idx < 0 or prob_idx >= len(probs_list):
|
||||
break
|
||||
|
||||
current_prob.append(probs_list[prob_idx][target_id].item())
|
||||
current_num_matched += 1
|
||||
|
||||
# Check if prediction matches target
|
||||
if probs_list[prob_idx].argmax().item() != target_id:
|
||||
current_num_matched -= 1
|
||||
break
|
||||
|
||||
if len(current_prob) == 0:
|
||||
return None
|
||||
|
||||
# Compute final score
|
||||
avg_prob = sum(current_prob) / len(current_prob)
|
||||
score = (current_num_matched + avg_prob) / (target_length + 1)
|
||||
|
||||
return score, current_num_matched
|
||||
|
||||
|
||||
def count_matched_locations(intervals: List[Dict[str, Any]],
                            success_threshold: Optional[float] = None) -> int:
    """
    Count intervals that meet the success threshold.

    Args:
        intervals: List of interval dictionaries (each with 'target_ids'
            and 'score' keys).
        success_threshold: Minimum score threshold.  When None, a
            per-interval default of ``n / (n + 1)`` is used, where ``n``
            is the interval's target length.

    Returns:
        Number of intervals meeting the threshold.
    """
    if not intervals:
        return 0

    matched_count = 0

    for interval in intervals:
        target_length = len(interval['target_ids'])
        # Use `is not None` so an explicit threshold of 0.0 is honored
        # (the old `or` expression silently discarded falsy thresholds).
        if success_threshold is not None:
            threshold = success_threshold
        else:
            threshold = target_length / (target_length + 1)

        if interval['score'] >= threshold:
            matched_count += 1

    return matched_count


def format_interval_debug_info(intervals: List[Dict[str, Any]],
                               tokenizer,
                               max_intervals: int = 10) -> str:
    """
    Render a human-readable summary of candidate intervals for debugging.

    The highest-scoring intervals are listed first, at most
    ``max_intervals`` of them, each with its token span, score, and the
    decoded target text.

    Args:
        intervals: List of interval dictionaries.
        tokenizer: Model tokenizer used to decode 'target_ids'.
        max_intervals: Maximum number of intervals to include.

    Returns:
        Formatted debug string.
    """
    if not intervals:
        return "No intervals found"

    report = [f"Found {len(intervals)} potential intervals:"]

    # Best intervals first; show only the top `max_intervals` entries.
    best_first = sorted(intervals, key=lambda entry: entry['score'], reverse=True)

    for rank, entry in enumerate(best_first[:max_intervals], start=1):
        decoded_text = tokenizer.decode(entry['target_ids'])
        report.append(
            f"  {rank}. [{entry['start']}:{entry['end']}] "
            f"score={entry['score']:.3f} "
            f"text='{decoded_text}'"
        )

    overflow = len(intervals) - max_intervals
    if overflow > 0:
        report.append(f"  ... and {overflow} more")

    return "\n".join(report)
|
||||
Reference in New Issue
Block a user