""" Weighted Interval Scheduling for UDora Target Positioning This module implements the core weighted interval scheduling algorithm used by UDora to optimally position target actions within the agent's reasoning trace. """ from typing import List, Dict, Any, Tuple import torch def weighted_interval_scheduling(intervals: List[Dict[str, Any]], num_location: int) -> List[int]: """ Implement Weighted Interval Scheduling Algorithm for optimal target placement. This algorithm finds the optimal subset of non-overlapping intervals that maximizes the total score, subject to a maximum number of intervals constraint. Args: intervals: List of interval dictionaries with keys: - 'start': Start position of interval - 'end': End position of interval - 'score': Score/weight of interval - 'target_ids': Token IDs for this interval num_location: Maximum number of intervals to select Returns: List of selected interval indices in sorted order Algorithm: 1. Sort intervals by end position 2. Compute predecessor array p[j] for each interval 3. Use dynamic programming: M[j][l] = max score using first j intervals with ≤l selections 4. Reconstruct optimal solution via backtracking """ if not intervals: return [] # Step 1: Sort intervals by end position intervals.sort(key=lambda x: x['end']) n = len(intervals) # Step 2: Compute p[j] for each interval (latest non-overlapping predecessor) p = [] for j in range(n): p_j = None for i in range(j - 1, -1, -1): if intervals[i]['end'] <= intervals[j]['start']: p_j = i break p.append(p_j) # Step 3: Initialize DP table M[j][l] # M[j][l] = maximum score using first j intervals with at most l selections M = [[0] * (num_location + 1) for _ in range(n + 1)] # Step 4: Fill DP table for j in range(1, n + 1): for l in range(1, num_location + 1): interval = intervals[j - 1] # Option 1: Don't include current interval exclude_score = M[j - 1][l] # Option 2: Include current interval if p[j - 1] is not None: include_score = interval['score'] + M[p[j - 1] + 1][l - 1] else: include_score = interval['score'] M[j][l] = max(exclude_score, include_score) # Step 5: Reconstruct solution via backtracking selected_indices = _reconstruct_solution(M, intervals, p, n, num_location) return sorted(selected_indices) def _reconstruct_solution(M: List[List[float]], intervals: List[Dict[str, Any]], p: List[int], j: int, l: int) -> List[int]: """ Reconstruct the optimal solution from the DP table. Args: M: DP table from weighted interval scheduling intervals: List of interval dictionaries p: Predecessor array j: Current interval index l: Current location budget Returns: List of selected interval indices """ selected = [] while j > 0 and l > 0: interval = intervals[j - 1] # Calculate include score if p[j - 1] is not None: include_score = interval['score'] + M[p[j - 1] + 1][l - 1] else: include_score = interval['score'] # Check if current interval was included in optimal solution if M[j][l] == include_score: selected.append(j - 1) # Add interval index j = p[j - 1] + 1 if p[j - 1] is not None else 0 l -= 1 else: j -= 1 return selected[::-1] # Reverse to maintain correct order def filter_intervals_by_sequential_mode(intervals: List[Dict[str, Any]], selected_indices: List[int], sequential: bool = True) -> List[int]: """ Filter selected intervals based on sequential optimization mode. In sequential mode, only keep intervals that meet the success threshold, plus the best interval if none meet the threshold. Args: intervals: List of all intervals selected_indices: Indices of intervals selected by scheduling algorithm sequential: Whether to use sequential filtering Returns: Filtered list of interval indices """ if not sequential: return selected_indices filtered_indices = [] max_score = float('-inf') max_score_index = None for idx in selected_indices: interval = intervals[idx] target_length = len(interval['target_ids']) success_threshold = target_length / (target_length + 1) if interval['score'] >= success_threshold: filtered_indices.append(idx) elif interval['score'] > max_score: max_score = interval['score'] max_score_index = idx # If no intervals meet threshold, add the best one if not filtered_indices and max_score_index is not None: filtered_indices.append(max_score_index) return sorted(filtered_indices) def build_final_token_sequence(intervals: List[Dict[str, Any]], selected_indices: List[int], generated_ids: List[int], device: torch.device) -> List[torch.Tensor]: """ Build the final token sequence with selected intervals inserted. Args: intervals: List of all intervals selected_indices: Indices of selected intervals generated_ids: Original generated token sequence device: PyTorch device for tensor creation Returns: List of tensors alternating between context and target sequences """ final_generated_ids = [] prev_end = 0 for idx in selected_indices: interval = intervals[idx] # Add tokens before the interval (context) context_tokens = generated_ids[prev_end:interval['start']] target_tokens = interval['target_ids'] final_generated_ids.extend([ torch.tensor([context_tokens], device=device, dtype=torch.int64), torch.tensor([target_tokens], device=device, dtype=torch.int64) ]) prev_end = interval['end'] return final_generated_ids