Skip to main content

aios_agent/
router.rs

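//! DecisionRouter — multi-tier decision routing.
//!
//! Routing priority:
//! 1. Circuit breaker — too many recent consecutive errors → `FallbackNoOp`
//! 2. Privacy sensitivity — too many sensitive signals → downgrade to `RuleBased`
//! 3. Semantic complexity — the number of distinct signal types selects the backend
//!    (currently every tier resolves to `RuleBased`)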
7
8use std::cell::RefCell;
9use std::collections::HashSet;
10use std::time::{Duration, Instant};
11
12use aios_spec::{
13    DecisionBackendResult, DecisionRoute, SanitizedEventType, SemanticHint, StructuredContext,
14};
15
16use crate::backends::fallback::FallbackNoOpBackend;
17use crate::backends::rule_based::RuleBasedBackend;
18use crate::DecisionBackend;
19
20// ============================================================
21// RouterConfig
22// ============================================================
23
/// Tunable thresholds consulted by the decision router.
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Privacy score (count of privacy-sensitive signals) that must be exceeded
    /// before cloud routing is blocked and the context stays on the rule-based backend.
    pub privacy_score_threshold: usize,
    /// Number of recent consecutive errors at which the circuit breaker trips.
    pub circuit_breaker_threshold: u32,
    /// Length of the window, in seconds, within which errors are counted.
    pub circuit_breaker_window_secs: u64,
}

impl Default for RouterConfig {
    /// Privacy threshold of 3; the breaker trips at 5 errors within 60 seconds.
    fn default() -> Self {
        let privacy_score_threshold = 3;
        let circuit_breaker_threshold = 5;
        let circuit_breaker_window_secs = 60;

        Self {
            privacy_score_threshold,
            circuit_breaker_threshold,
            circuit_breaker_window_secs,
        }
    }
}
43
44// ============================================================
45// Circuit breaker state
46// ============================================================
47
/// One failed backend evaluation, timestamped so it can age out of the window.
#[derive(Debug, Clone)]
struct ErrorRecord {
    timestamp: Instant,
}

/// Error history backing the circuit breaker.
///
/// Errors accumulate until the next recorded success, which clears them all;
/// only errors inside the requested window are counted.
#[derive(Debug, Clone, Default)]
struct CircuitState {
    errors: Vec<ErrorRecord>,
}

impl CircuitState {
    /// Records a failed evaluation at the current instant.
    fn record_error(&mut self) {
        self.errors.push(ErrorRecord {
            timestamp: Instant::now(),
        });
    }

    /// Records a successful evaluation, discarding the whole error history.
    fn record_success(&mut self) {
        self.errors.clear();
    }

    /// Counts errors recorded within the last `window_secs` seconds (inclusive).
    ///
    /// Each error's *age* is compared with the window rather than computing a
    /// cutoff as `now - window`. `Instant` cannot represent arbitrarily early
    /// points, so that subtraction can fail for large windows. Any fallback
    /// cutoff (such as `now`) would then silently discard errors and keep the
    /// breaker from ever tripping.
    fn count_recent_errors(&self, window_secs: u64) -> u32 {
        let window = Duration::from_secs(window_secs);
        // Read the clock once so every record is judged against the same instant.
        let now = Instant::now();
        self.errors
            .iter()
            .filter(|record| now.saturating_duration_since(record.timestamp) <= window)
            .count() as u32
    }
}
76
77// ============================================================
78// Routing reason
79// ============================================================
80
/// Why a particular route was chosen; rendered into the result's rationale tags.
#[derive(Debug, Clone)]
enum RoutingReason {
    /// The recent error count reached the circuit-breaker threshold.
    CircuitBreakerTripped { failure_count: u32 },
    /// The privacy score exceeded the configured threshold.
    PrivacySensitive { score: usize },
    /// Zero or one distinct semantic hint type.
    LowComplexity,
    /// Two or three distinct semantic hint types.
    MediumComplexity,
    /// Four or more distinct semantic hint types.
    HighComplexity,
}

impl RoutingReason {
    /// Renders this reason as a `routing:`-prefixed rationale tag.
    fn tag(&self) -> String {
        match *self {
            Self::CircuitBreakerTripped { failure_count } => {
                format!("routing:circuit_breaker_fallback(errors={})", failure_count)
            }
            Self::PrivacySensitive { score } => {
                format!("routing:privacy_sensitive(score={})", score)
            }
            Self::LowComplexity => String::from("routing:low_complexity"),
            Self::MediumComplexity => {
                String::from("routing:medium_complexity(rule_based_fallback)")
            }
            Self::HighComplexity => String::from("routing:high_complexity(rule_based_fallback)"),
        }
    }
}
107
108// ============================================================
109// DecisionRouter
110// ============================================================
111
112pub struct DecisionRouter {
113    config: RouterConfig,
114    rule_based: RuleBasedBackend,
115    fallback: FallbackNoOpBackend,
116    circuit_state: RefCell<CircuitState>,
117}
118
119impl DecisionRouter {
120    pub fn new(config: RouterConfig) -> Self {
121        Self {
122            config,
123            rule_based: RuleBasedBackend,
124            fallback: FallbackNoOpBackend,
125            circuit_state: RefCell::new(CircuitState::default()),
126        }
127    }
128
129    /// Evaluate a StructuredContext through the routing pipeline.
130    ///
131    /// Uses interior mutability (`RefCell`) to track circuit breaker state
132    /// across calls without requiring `&mut self`.
133    pub fn evaluate(&self, context: &StructuredContext) -> DecisionBackendResult {
134        let (route, reason) = self.determine_route(context);
135
136        let mut result = match route {
137            DecisionRoute::RuleBased => self.rule_based.evaluate(context),
138            DecisionRoute::FallbackNoOp => self.fallback.evaluate(context),
139            // Future routes (LocalEvaluator, CloudLlm) fall back to RuleBased
140            _ => self.rule_based.evaluate(context),
141        };
142
143        // Inject routing reason tag
144        result.rationale_tags.push(reason.tag());
145
146        // Update circuit breaker state
147        let mut state = self.circuit_state.borrow_mut();
148        if result.error.is_some() {
149            state.record_error();
150        } else {
151            state.record_success();
152        }
153
154        result
155    }
156
157    // --- Private routing logic ---
158
159    fn determine_route(&self, context: &StructuredContext) -> (DecisionRoute, RoutingReason) {
160        // Priority 1: Circuit breaker
161        let error_count = self
162            .circuit_state
163            .borrow()
164            .count_recent_errors(self.config.circuit_breaker_window_secs);
165        if error_count >= self.config.circuit_breaker_threshold {
166            return (
167                DecisionRoute::FallbackNoOp,
168                RoutingReason::CircuitBreakerTripped {
169                    failure_count: error_count,
170                },
171            );
172        }
173
174        // Priority 2: Privacy sensitivity
175        let privacy_score = Self::compute_privacy_score(context);
176        if privacy_score > self.config.privacy_score_threshold {
177            return (
178                DecisionRoute::RuleBased,
179                RoutingReason::PrivacySensitive {
180                    score: privacy_score,
181                },
182            );
183        }
184
185        // Priority 3: Semantic complexity
186        let unique_types = Self::count_unique_semantic_hint_types(context);
187        match unique_types {
188            0 | 1 => (DecisionRoute::RuleBased, RoutingReason::LowComplexity),
189            2 | 3 => (DecisionRoute::RuleBased, RoutingReason::MediumComplexity),
190            _ => (DecisionRoute::RuleBased, RoutingReason::HighComplexity),
191        }
192    }
193
194    /// Count privacy-sensitive signals:
195    /// - Notification events with VerificationCode or FinancialContext hints
196    /// - AppTransition events
197    fn compute_privacy_score(context: &StructuredContext) -> usize {
198        context
199            .events
200            .iter()
201            .map(|event| match &event.event_type {
202                SanitizedEventType::Notification { semantic_hints, .. } => semantic_hints
203                    .iter()
204                    .filter(|h| {
205                        matches!(
206                            h,
207                            SemanticHint::VerificationCode | SemanticHint::FinancialContext
208                        )
209                    })
210                    .count(),
211                SanitizedEventType::AppTransition { .. } => 1,
212                _ => 0,
213            })
214            .sum()
215    }
216
217    /// Count unique SemanticHint variants across all notification events.
218    fn count_unique_semantic_hint_types(context: &StructuredContext) -> usize {
219        let mut seen: HashSet<&SemanticHint> = HashSet::new();
220        for event in &context.events {
221            if let SanitizedEventType::Notification { semantic_hints, .. } = &event.event_type {
222                for hint in semantic_hints {
223                    seen.insert(hint);
224                }
225            }
226        }
227        seen.len()
228    }
229}
230
231impl Default for DecisionRouter {
232    fn default() -> Self {
233        Self::new(RouterConfig::default())
234    }
235}