Files
UDora/udora/scheduling.py
javyduck 0d259fc3aa init_v1
2025-06-24 01:50:12 +00:00

191 lines
6.4 KiB
Python

"""
Weighted Interval Scheduling for UDora Target Positioning
This module implements the core weighted interval scheduling algorithm used by UDora
to optimally position target actions within the agent's reasoning trace.
"""
from typing import List, Dict, Any, Tuple
import torch
def weighted_interval_scheduling(intervals: List[Dict[str, Any]], num_location: int) -> List[int]:
"""
Implement Weighted Interval Scheduling Algorithm for optimal target placement.
This algorithm finds the optimal subset of non-overlapping intervals that maximizes
the total score, subject to a maximum number of intervals constraint.
Args:
intervals: List of interval dictionaries with keys:
- 'start': Start position of interval
- 'end': End position of interval
- 'score': Score/weight of interval
- 'target_ids': Token IDs for this interval
num_location: Maximum number of intervals to select
Returns:
List of selected interval indices in sorted order
Algorithm:
1. Sort intervals by end position
2. Compute predecessor array p[j] for each interval
3. Use dynamic programming: M[j][l] = max score using first j intervals with ≤l selections
4. Reconstruct optimal solution via backtracking
"""
if not intervals:
return []
# Step 1: Sort intervals by end position
intervals.sort(key=lambda x: x['end'])
n = len(intervals)
# Step 2: Compute p[j] for each interval (latest non-overlapping predecessor)
p = []
for j in range(n):
p_j = None
for i in range(j - 1, -1, -1):
if intervals[i]['end'] <= intervals[j]['start']:
p_j = i
break
p.append(p_j)
# Step 3: Initialize DP table M[j][l]
# M[j][l] = maximum score using first j intervals with at most l selections
M = [[0] * (num_location + 1) for _ in range(n + 1)]
# Step 4: Fill DP table
for j in range(1, n + 1):
for l in range(1, num_location + 1):
interval = intervals[j - 1]
# Option 1: Don't include current interval
exclude_score = M[j - 1][l]
# Option 2: Include current interval
if p[j - 1] is not None:
include_score = interval['score'] + M[p[j - 1] + 1][l - 1]
else:
include_score = interval['score']
M[j][l] = max(exclude_score, include_score)
# Step 5: Reconstruct solution via backtracking
selected_indices = _reconstruct_solution(M, intervals, p, n, num_location)
return sorted(selected_indices)
def _reconstruct_solution(M: List[List[float]], intervals: List[Dict[str, Any]],
p: List[int], j: int, l: int) -> List[int]:
"""
Reconstruct the optimal solution from the DP table.
Args:
M: DP table from weighted interval scheduling
intervals: List of interval dictionaries
p: Predecessor array
j: Current interval index
l: Current location budget
Returns:
List of selected interval indices
"""
selected = []
while j > 0 and l > 0:
interval = intervals[j - 1]
# Calculate include score
if p[j - 1] is not None:
include_score = interval['score'] + M[p[j - 1] + 1][l - 1]
else:
include_score = interval['score']
# Check if current interval was included in optimal solution
if M[j][l] == include_score:
selected.append(j - 1) # Add interval index
j = p[j - 1] + 1 if p[j - 1] is not None else 0
l -= 1
else:
j -= 1
return selected[::-1] # Reverse to maintain correct order
def filter_intervals_by_sequential_mode(intervals: List[Dict[str, Any]],
selected_indices: List[int],
sequential: bool = True) -> List[int]:
"""
Filter selected intervals based on sequential optimization mode.
In sequential mode, only keep intervals that meet the success threshold,
plus the best interval if none meet the threshold.
Args:
intervals: List of all intervals
selected_indices: Indices of intervals selected by scheduling algorithm
sequential: Whether to use sequential filtering
Returns:
Filtered list of interval indices
"""
if not sequential:
return selected_indices
filtered_indices = []
max_score = float('-inf')
max_score_index = None
for idx in selected_indices:
interval = intervals[idx]
target_length = len(interval['target_ids'])
success_threshold = target_length / (target_length + 1)
if interval['score'] >= success_threshold:
filtered_indices.append(idx)
elif interval['score'] > max_score:
max_score = interval['score']
max_score_index = idx
# If no intervals meet threshold, add the best one
if not filtered_indices and max_score_index is not None:
filtered_indices.append(max_score_index)
return sorted(filtered_indices)
def build_final_token_sequence(intervals: List[Dict[str, Any]],
selected_indices: List[int],
generated_ids: List[int],
device: torch.device) -> List[torch.Tensor]:
"""
Build the final token sequence with selected intervals inserted.
Args:
intervals: List of all intervals
selected_indices: Indices of selected intervals
generated_ids: Original generated token sequence
device: PyTorch device for tensor creation
Returns:
List of tensors alternating between context and target sequences
"""
final_generated_ids = []
prev_end = 0
for idx in selected_indices:
interval = intervals[idx]
# Add tokens before the interval (context)
context_tokens = generated_ids[prev_end:interval['start']]
target_tokens = interval['target_ids']
final_generated_ids.extend([
torch.tensor([context_tokens], device=device, dtype=torch.int64),
torch.tensor([target_tokens], device=device, dtype=torch.int64)
])
prev_end = interval['end']
return final_generated_ids