This commit is contained in:
Adil Hafeez 2025-12-25 20:13:19 -08:00
parent 784d4afa62
commit 157714c42d
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
9 changed files with 105 additions and 115 deletions

View file

@ -13,19 +13,22 @@ repos:
name: cargo-fmt
language: system
types: [file, rust]
entry: bash -c "cd crates/llm_gateway && cargo fmt"
entry: bash -c "cd crates && cargo fmt --all -- --check"
pass_filenames: false
- id: cargo-clippy
name: cargo-clippy
language: system
types: [file, rust]
entry: bash -c "cd crates/llm_gateway && cargo clippy --all"
entry: bash -c "cd crates && cargo clippy --all-targets --all-features -- -D warnings"
pass_filenames: false
- id: cargo-test
name: cargo-test
language: system
types: [file, rust]
entry: bash -c "cd crates && cargo test --lib"
pass_filenames: false
- repo: https://github.com/psf/black
rev: 23.1.0

View file

@ -741,7 +741,7 @@ impl ArchFunctionHandler {
if let Some(instruction) = extra_instruction {
if let Some(last) = processed_messages.last_mut() {
if let MessageContent::Text(content) = &mut last.content {
content.push_str("\n");
content.push('\n');
content.push_str(instruction);
}
}
@ -774,13 +774,11 @@ impl ArchFunctionHandler {
for i in (conversation_idx..messages.len()).rev() {
if let MessageContent::Text(content) = &messages[i].content {
num_tokens += content.len() / 4;
if num_tokens >= max_tokens {
if messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
if num_tokens >= max_tokens && messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
}
// Only update message_idx if we haven't hit the token limit yet
@ -861,7 +859,7 @@ impl ArchFunctionHandler {
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
@ -916,7 +914,7 @@ impl ArchFunctionHandler {
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
@ -933,9 +931,9 @@ impl ArchFunctionHandler {
let response_text = response
.text()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
serde_json::from_str(&response_text).map_err(|e| FunctionCallingError::JsonParseError(e))
serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
}
pub async fn function_calling_chat(
@ -977,8 +975,7 @@ impl ArchFunctionHandler {
if use_agent_orchestrator {
while let Some(chunk_result) = stream.next().await {
let chunk =
chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
@ -993,90 +990,80 @@ impl ArchFunctionHandler {
}
}
info!("[Agent Orchestrator]: response received");
} else {
if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
} else if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
while let Some(chunk_result) = stream.next().await {
let chunk =
chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
// Extract logprobs
let logprobs: Vec<f64> = choice
.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
if hallucination_state
.append_and_check_token_hallucination(content.to_string(), logprobs)
{
// Extract logprobs
let logprobs: Vec<f64> = choice
.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| {
v.get("logprob").and_then(|lp| lp.as_f64())
})
.collect()
})
.unwrap_or_default();
has_hallucination = true;
break;
}
if hallucination_state.append_and_check_token_hallucination(
content.to_string(),
logprobs,
) {
has_hallucination = true;
break;
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none()
{
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
}
}
}
}
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
let clarify_messages =
self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request =
self.create_request_with_extra_body(clarify_messages, false);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let retry_response = self.make_non_streaming_request(clarify_request).await?;
let retry_response = self.make_non_streaming_request(clarify_request).await?;
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
} else {
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk =
chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
}
}
@ -2009,12 +1996,12 @@ mod hallucination_tests {
// Test integer types
assert!(handler.check_value_type(&json!(42), "integer"));
assert!(handler.check_value_type(&json!(42), "int"));
assert!(!handler.check_value_type(&json!(3.14), "integer"));
assert!(!handler.check_value_type(&json!(3.15), "integer"));
// Test number types (accepts both int and float)
assert!(handler.check_value_type(&json!(3.14), "number"));
assert!(handler.check_value_type(&json!(3.15), "number"));
assert!(handler.check_value_type(&json!(42), "number"));
assert!(handler.check_value_type(&json!(3.14), "float"));
assert!(handler.check_value_type(&json!(3.15), "float"));
// Test boolean
assert!(handler.check_value_type(&json!(true), "boolean"));
@ -2073,7 +2060,7 @@ mod hallucination_tests {
.validate_or_convert_parameter(&json!(42), "number")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(3.14), "number")
.validate_or_convert_parameter(&json!(3.15), "number")
.unwrap());
}

View file

@ -14,7 +14,7 @@ use crate::router::plano_orchestrator::OrchestratorService;
/// 2. PipelineProcessor - executes the agent pipeline
/// 3. ResponseHandler - handles response streaming
#[cfg(test)]
mod integration_tests {
mod tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};

View file

@ -348,6 +348,7 @@ fn resolve_model_alias(
}
/// Builds the LLM span with all required and optional attributes.
#[allow(clippy::too_many_arguments)]
async fn build_llm_span(
traceparent: &str,
request_path: &str,
@ -378,7 +379,7 @@ async fn build_llm_span(
let operation_name = if request_path != upstream_path {
OperationNameBuilder::new()
.with_method("POST")
.with_path(&format!("{} >> {}", request_path, upstream_path))
.with_path(format!("{} >> {}", request_path, upstream_path))
.with_target(resolved_model)
.build()
} else {

View file

@ -82,6 +82,7 @@ impl PipelineProcessor {
}
/// Record a span for filter execution
#[allow(clippy::too_many_arguments)]
fn record_filter_span(
&self,
collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -132,6 +133,7 @@ impl PipelineProcessor {
}
/// Record a span for MCP protocol interactions
#[allow(clippy::too_many_arguments)]
fn record_agent_filter_span(
&self,
collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -156,12 +158,12 @@ impl PipelineProcessor {
.build();
let mut span_builder = SpanBuilder::new(&operation_name)
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id()))
.with_span_id(span_id.unwrap_or_else(generate_random_span_id))
.with_kind(SpanKind::Client)
.with_start_time(start_time)
.with_end_time(end_time)
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string()))
.with_attribute(http::TARGET, format!("/mcp ({})", operation))
.with_attribute("mcp.operation", operation.to_string())
.with_attribute("mcp.agent_id", agent_id.to_string())
.with_attribute(
@ -188,6 +190,7 @@ impl PipelineProcessor {
}
/// Process the filter chain of agents (all except the terminal agent)
#[allow(clippy::too_many_arguments)]
pub async fn process_filter_chain(
&mut self,
chat_history: &[Message],
@ -1023,7 +1026,7 @@ mod tests {
}
});
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string());
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body);
let mut server = Server::new_async().await;
let _m = server

View file

@ -164,7 +164,7 @@ impl ResponseHandler {
match transformed_event.provider_response() {
Ok(provider_response) => {
if let Some(content) = provider_response.content_delta() {
accumulated_text.push_str(&content);
accumulated_text.push_str(content);
} else {
info!("No content delta in provider response");
}
@ -174,7 +174,7 @@ impl ResponseHandler {
}
}
}
return Ok(accumulated_text);
Ok(accumulated_text)
} else {
// If not SSE, treat as regular text response
let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| {

View file

@ -144,7 +144,7 @@ impl OrchestratorModelV1 {
// Format routes: each route as JSON on its own line with standard spacing
let agent_orchestration_json_str = agent_orchestration_values
.iter()
.map(|pref| to_spaced_json(pref))
.map(to_spaced_json)
.collect::<Vec<String>>()
.join("\n");
let agent_orchestration_to_model_map: HashMap<String, String> = agent_orchestrations
@ -382,7 +382,7 @@ fn convert_to_orchestrator_preferences(
// Format routes: each route as JSON on its own line with standard spacing
let routes_str = orchestration_preferences
.iter()
.map(|pref| to_spaced_json(pref))
.map(to_spaced_json)
.collect::<Vec<String>>()
.join("\n");

View file

@ -108,7 +108,7 @@ pub enum StorageBackend {
}
impl StorageBackend {
pub fn from_str(s: &str) -> Option<Self> {
pub fn parse_backend(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"memory" => Some(StorageBackend::Memory),
"supabase" => Some(StorageBackend::Supabase),

View file

@ -51,6 +51,7 @@ pub struct ResponsesStateProcessor<P: StreamProcessor> {
}
impl<P: StreamProcessor> ResponsesStateProcessor<P> {
#[allow(clippy::too_many_arguments)]
pub fn new(
inner: P,
storage: Arc<dyn StateStorage>,
@ -137,24 +138,19 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
for event in sse_iter {
// Only process data lines (skip event-only lines)
if let Some(data_str) = &event.data {
// Try to parse as ResponsesAPIStreamEvent
if let Ok(stream_event) =
// Try to parse as ResponsesAPIStreamEvent and check if it's a ResponseCompleted event
if let Ok(ResponsesAPIStreamEvent::ResponseCompleted { response, .. }) =
serde_json::from_str::<ResponsesAPIStreamEvent>(data_str)
{
// Check if this is a ResponseCompleted event
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } =
stream_event
{
info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
self.request_id,
response.id,
response.output.len()
);
self.response_id = Some(response.id.clone());
self.output_items = Some(response.output.clone());
return; // Found what we need, exit early
}
info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
self.request_id,
response.id,
response.output.len()
);
self.response_id = Some(response.id.clone());
self.output_items = Some(response.output.clone());
return; // Found what we need, exit early
}
}
}