1use std::cell::RefCell;
9use std::collections::HashSet;
10use std::time::{Duration, Instant};
11
12use aios_spec::{
13 DecisionBackendResult, DecisionRoute, SanitizedEventType, SemanticHint, StructuredContext,
14};
15
16use crate::backends::fallback::FallbackNoOpBackend;
17use crate::backends::rule_based::RuleBasedBackend;
18use crate::DecisionBackend;
19
/// Tunable thresholds used by the decision router when choosing a route.
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// A context whose privacy score is strictly greater than this value is
    /// routed as privacy-sensitive.
    pub privacy_score_threshold: usize,
    /// The circuit breaker trips once the number of recent errors reaches
    /// (is greater than or equal to) this value.
    pub circuit_breaker_threshold: u32,
    /// Size, in seconds, of the look-back window in which errors are counted
    /// for the circuit breaker.
    pub circuit_breaker_window_secs: u64,
}

impl Default for RouterConfig {
    /// Privacy threshold 3, trip after 5 recent errors, 60-second error window.
    fn default() -> Self {
        RouterConfig {
            circuit_breaker_window_secs: 60,
            circuit_breaker_threshold: 5,
            privacy_score_threshold: 3,
        }
    }
}
43
/// A single backend failure observed by the router.
#[derive(Debug, Clone)]
struct ErrorRecord {
    /// Moment the failure was recorded.
    timestamp: Instant,
}

/// Failure history backing the router's circuit breaker.
///
/// Errors accumulate until a success is recorded, at which point the whole
/// history is discarded.
#[derive(Debug, Clone, Default)]
struct CircuitState {
    errors: Vec<ErrorRecord>,
}

impl CircuitState {
    /// Records a failure stamped with the current instant.
    fn record_error(&mut self) {
        self.errors.push(ErrorRecord {
            timestamp: Instant::now(),
        });
    }

    /// Forgets every recorded failure.
    fn record_success(&mut self) {
        self.errors.clear();
    }

    /// Counts failures recorded within the last `window_secs` seconds
    /// (an error exactly `window_secs` old is still counted).
    ///
    /// Each error's age is measured as `now - timestamp` instead of computing a
    /// cutoff `now - window`. Subtracting a large window from `Instant::now()`
    /// can underflow (for a huge window, or shortly after boot on platforms whose
    /// monotonic clock starts near zero). The previous fallback to "cutoff = now"
    /// then silently excluded every error, so the breaker could never trip. With
    /// age-based comparison, an oversized window simply covers all recorded errors.
    ///
    /// The count saturates at `u32::MAX` rather than truncating.
    fn count_recent_errors(&self, window_secs: u64) -> u32 {
        let window = Duration::from_secs(window_secs);
        // Capture `now` once so every error is judged against the same instant.
        let now = Instant::now();
        self.errors
            .iter()
            .filter(|e| now.saturating_duration_since(e.timestamp) <= window)
            .fold(0u32, |count, _| count.saturating_add(1))
    }
}
76
/// Why the router picked a route; rendered into the result's rationale tags.
#[derive(Debug, Clone)]
enum RoutingReason {
    /// Too many recent backend errors; carries the observed error count.
    CircuitBreakerTripped { failure_count: u32 },
    /// The privacy score exceeded the configured threshold; carries the score.
    PrivacySensitive { score: usize },
    /// Zero or one distinct semantic hint in the context.
    LowComplexity,
    /// Two or three distinct semantic hints in the context.
    MediumComplexity,
    /// Four or more distinct semantic hints in the context.
    HighComplexity,
}

impl RoutingReason {
    /// Renders this reason as a `routing:`-prefixed rationale tag string.
    fn tag(&self) -> String {
        match *self {
            Self::CircuitBreakerTripped { failure_count } => format!(
                "routing:circuit_breaker_fallback(errors={n})",
                n = failure_count
            ),
            Self::PrivacySensitive { score } => {
                format!("routing:privacy_sensitive(score={s})", s = score)
            }
            Self::LowComplexity => String::from("routing:low_complexity"),
            Self::MediumComplexity => {
                String::from("routing:medium_complexity(rule_based_fallback)")
            }
            Self::HighComplexity => String::from("routing:high_complexity(rule_based_fallback)"),
        }
    }
}
107
108pub struct DecisionRouter {
113 config: RouterConfig,
114 rule_based: RuleBasedBackend,
115 fallback: FallbackNoOpBackend,
116 circuit_state: RefCell<CircuitState>,
117}
118
119impl DecisionRouter {
120 pub fn new(config: RouterConfig) -> Self {
121 Self {
122 config,
123 rule_based: RuleBasedBackend,
124 fallback: FallbackNoOpBackend,
125 circuit_state: RefCell::new(CircuitState::default()),
126 }
127 }
128
129 pub fn evaluate(&self, context: &StructuredContext) -> DecisionBackendResult {
134 let (route, reason) = self.determine_route(context);
135
136 let mut result = match route {
137 DecisionRoute::RuleBased => self.rule_based.evaluate(context),
138 DecisionRoute::FallbackNoOp => self.fallback.evaluate(context),
139 _ => self.rule_based.evaluate(context),
141 };
142
143 result.rationale_tags.push(reason.tag());
145
146 let mut state = self.circuit_state.borrow_mut();
148 if result.error.is_some() {
149 state.record_error();
150 } else {
151 state.record_success();
152 }
153
154 result
155 }
156
157 fn determine_route(&self, context: &StructuredContext) -> (DecisionRoute, RoutingReason) {
160 let error_count = self
162 .circuit_state
163 .borrow()
164 .count_recent_errors(self.config.circuit_breaker_window_secs);
165 if error_count >= self.config.circuit_breaker_threshold {
166 return (
167 DecisionRoute::FallbackNoOp,
168 RoutingReason::CircuitBreakerTripped {
169 failure_count: error_count,
170 },
171 );
172 }
173
174 let privacy_score = Self::compute_privacy_score(context);
176 if privacy_score > self.config.privacy_score_threshold {
177 return (
178 DecisionRoute::RuleBased,
179 RoutingReason::PrivacySensitive {
180 score: privacy_score,
181 },
182 );
183 }
184
185 let unique_types = Self::count_unique_semantic_hint_types(context);
187 match unique_types {
188 0 | 1 => (DecisionRoute::RuleBased, RoutingReason::LowComplexity),
189 2 | 3 => (DecisionRoute::RuleBased, RoutingReason::MediumComplexity),
190 _ => (DecisionRoute::RuleBased, RoutingReason::HighComplexity),
191 }
192 }
193
194 fn compute_privacy_score(context: &StructuredContext) -> usize {
198 context
199 .events
200 .iter()
201 .map(|event| match &event.event_type {
202 SanitizedEventType::Notification { semantic_hints, .. } => semantic_hints
203 .iter()
204 .filter(|h| {
205 matches!(
206 h,
207 SemanticHint::VerificationCode | SemanticHint::FinancialContext
208 )
209 })
210 .count(),
211 SanitizedEventType::AppTransition { .. } => 1,
212 _ => 0,
213 })
214 .sum()
215 }
216
217 fn count_unique_semantic_hint_types(context: &StructuredContext) -> usize {
219 let mut seen: HashSet<&SemanticHint> = HashSet::new();
220 for event in &context.events {
221 if let SanitizedEventType::Notification { semantic_hints, .. } = &event.event_type {
222 for hint in semantic_hints {
223 seen.insert(hint);
224 }
225 }
226 }
227 seen.len()
228 }
229}
230
231impl Default for DecisionRouter {
232 fn default() -> Self {
233 Self::new(RouterConfig::default())
234 }
235}