mirror of
https://github.com/AI-secure/UDora.git
synced 2026-02-12 21:03:21 +00:00
191 lines
6.4 KiB
Python
191 lines
6.4 KiB
Python
"""
|
|
Weighted Interval Scheduling for UDora Target Positioning
|
|
|
|
This module implements the core weighted interval scheduling algorithm used by UDora
|
|
to optimally position target actions within the agent's reasoning trace.
|
|
"""
|
|
|
|
from typing import List, Dict, Any, Tuple
|
|
import torch
|
|
|
|
|
|
def weighted_interval_scheduling(intervals: List[Dict[str, Any]], num_location: int) -> List[int]:
|
|
"""
|
|
Implement Weighted Interval Scheduling Algorithm for optimal target placement.
|
|
|
|
This algorithm finds the optimal subset of non-overlapping intervals that maximizes
|
|
the total score, subject to a maximum number of intervals constraint.
|
|
|
|
Args:
|
|
intervals: List of interval dictionaries with keys:
|
|
- 'start': Start position of interval
|
|
- 'end': End position of interval
|
|
- 'score': Score/weight of interval
|
|
- 'target_ids': Token IDs for this interval
|
|
num_location: Maximum number of intervals to select
|
|
|
|
Returns:
|
|
List of selected interval indices in sorted order
|
|
|
|
Algorithm:
|
|
1. Sort intervals by end position
|
|
2. Compute predecessor array p[j] for each interval
|
|
3. Use dynamic programming: M[j][l] = max score using first j intervals with ≤l selections
|
|
4. Reconstruct optimal solution via backtracking
|
|
"""
|
|
if not intervals:
|
|
return []
|
|
|
|
# Step 1: Sort intervals by end position
|
|
intervals.sort(key=lambda x: x['end'])
|
|
n = len(intervals)
|
|
|
|
# Step 2: Compute p[j] for each interval (latest non-overlapping predecessor)
|
|
p = []
|
|
for j in range(n):
|
|
p_j = None
|
|
for i in range(j - 1, -1, -1):
|
|
if intervals[i]['end'] <= intervals[j]['start']:
|
|
p_j = i
|
|
break
|
|
p.append(p_j)
|
|
|
|
# Step 3: Initialize DP table M[j][l]
|
|
# M[j][l] = maximum score using first j intervals with at most l selections
|
|
M = [[0] * (num_location + 1) for _ in range(n + 1)]
|
|
|
|
# Step 4: Fill DP table
|
|
for j in range(1, n + 1):
|
|
for l in range(1, num_location + 1):
|
|
interval = intervals[j - 1]
|
|
|
|
# Option 1: Don't include current interval
|
|
exclude_score = M[j - 1][l]
|
|
|
|
# Option 2: Include current interval
|
|
if p[j - 1] is not None:
|
|
include_score = interval['score'] + M[p[j - 1] + 1][l - 1]
|
|
else:
|
|
include_score = interval['score']
|
|
|
|
M[j][l] = max(exclude_score, include_score)
|
|
|
|
# Step 5: Reconstruct solution via backtracking
|
|
selected_indices = _reconstruct_solution(M, intervals, p, n, num_location)
|
|
|
|
return sorted(selected_indices)
|
|
|
|
|
|
def _reconstruct_solution(M: List[List[float]], intervals: List[Dict[str, Any]],
|
|
p: List[int], j: int, l: int) -> List[int]:
|
|
"""
|
|
Reconstruct the optimal solution from the DP table.
|
|
|
|
Args:
|
|
M: DP table from weighted interval scheduling
|
|
intervals: List of interval dictionaries
|
|
p: Predecessor array
|
|
j: Current interval index
|
|
l: Current location budget
|
|
|
|
Returns:
|
|
List of selected interval indices
|
|
"""
|
|
selected = []
|
|
|
|
while j > 0 and l > 0:
|
|
interval = intervals[j - 1]
|
|
|
|
# Calculate include score
|
|
if p[j - 1] is not None:
|
|
include_score = interval['score'] + M[p[j - 1] + 1][l - 1]
|
|
else:
|
|
include_score = interval['score']
|
|
|
|
# Check if current interval was included in optimal solution
|
|
if M[j][l] == include_score:
|
|
selected.append(j - 1) # Add interval index
|
|
j = p[j - 1] + 1 if p[j - 1] is not None else 0
|
|
l -= 1
|
|
else:
|
|
j -= 1
|
|
|
|
return selected[::-1] # Reverse to maintain correct order
|
|
|
|
|
|
def filter_intervals_by_sequential_mode(intervals: List[Dict[str, Any]],
|
|
selected_indices: List[int],
|
|
sequential: bool = True) -> List[int]:
|
|
"""
|
|
Filter selected intervals based on sequential optimization mode.
|
|
|
|
In sequential mode, only keep intervals that meet the success threshold,
|
|
plus the best interval if none meet the threshold.
|
|
|
|
Args:
|
|
intervals: List of all intervals
|
|
selected_indices: Indices of intervals selected by scheduling algorithm
|
|
sequential: Whether to use sequential filtering
|
|
|
|
Returns:
|
|
Filtered list of interval indices
|
|
"""
|
|
if not sequential:
|
|
return selected_indices
|
|
|
|
filtered_indices = []
|
|
max_score = float('-inf')
|
|
max_score_index = None
|
|
|
|
for idx in selected_indices:
|
|
interval = intervals[idx]
|
|
target_length = len(interval['target_ids'])
|
|
success_threshold = target_length / (target_length + 1)
|
|
|
|
if interval['score'] >= success_threshold:
|
|
filtered_indices.append(idx)
|
|
elif interval['score'] > max_score:
|
|
max_score = interval['score']
|
|
max_score_index = idx
|
|
|
|
# If no intervals meet threshold, add the best one
|
|
if not filtered_indices and max_score_index is not None:
|
|
filtered_indices.append(max_score_index)
|
|
|
|
return sorted(filtered_indices)
|
|
|
|
|
|
def build_final_token_sequence(intervals: List[Dict[str, Any]],
|
|
selected_indices: List[int],
|
|
generated_ids: List[int],
|
|
device: torch.device) -> List[torch.Tensor]:
|
|
"""
|
|
Build the final token sequence with selected intervals inserted.
|
|
|
|
Args:
|
|
intervals: List of all intervals
|
|
selected_indices: Indices of selected intervals
|
|
generated_ids: Original generated token sequence
|
|
device: PyTorch device for tensor creation
|
|
|
|
Returns:
|
|
List of tensors alternating between context and target sequences
|
|
"""
|
|
final_generated_ids = []
|
|
prev_end = 0
|
|
|
|
for idx in selected_indices:
|
|
interval = intervals[idx]
|
|
|
|
# Add tokens before the interval (context)
|
|
context_tokens = generated_ids[prev_end:interval['start']]
|
|
target_tokens = interval['target_ids']
|
|
|
|
final_generated_ids.extend([
|
|
torch.tensor([context_tokens], device=device, dtype=torch.int64),
|
|
torch.tensor([target_tokens], device=device, dtype=torch.int64)
|
|
])
|
|
|
|
prev_end = interval['end']
|
|
|
|
return final_generated_ids |