mirror of
https://github.com/katanemo/plano.git
synced 2026-05-08 23:32:43 +02:00
Improve end to end tracing (#628)
* adding canonical tracing support via bright-staff * improved formatting for tools in the traces * removing anthropic from the currency exchange demo * using Envoy to transport traces, not calling OTEL directly * moving otel collcetor cluster outside tracing if/else * minor fixes to not write to the OTEL collector if tracing is disabled * fixed PR comments and added more trace attributes * more fixes based on PR comments * more clean up based on PR comments --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
8adb9795d8
commit
a79f55f313
34 changed files with 2556 additions and 403 deletions
345
crates/brightstaff/src/handlers/llm.rs
Normal file
345
crates/brightstaff/src/handlers/llm.rs
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{LlmProvider, ModelAlias};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
|
||||
use common::traces::TraceCollector;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::tracing::operation_component;
|
||||
|
||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub async fn llm_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<Vec<LlmProvider>>>,
|
||||
trace_collector: Arc<TraceCollector>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
|
||||
// Extract or generate traceparent - this establishes the trace context for all spans
|
||||
let traceparent: String = request_headers
|
||||
.get("traceparent")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| {
|
||||
use uuid::Uuid;
|
||||
let trace_id = Uuid::new_v4().to_string().replace("-", "");
|
||||
format!("00-{}-0000000000000000-01", trace_id)
|
||||
});
|
||||
|
||||
let mut request_headers = request_headers;
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
"Received request body (raw utf8): {}",
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let temperature = client_request.get_temperature();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||
|
||||
// Extract tool names and user message preview for span attributes
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request.get_recent_user_message()
|
||||
.map(|msg| truncate_message(&msg, 50));
|
||||
|
||||
client_request.set_model(resolved_model.clone());
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!("Removed archgw_preference_config from metadata");
|
||||
}
|
||||
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// Determine routing using the dedicated router_chat module
|
||||
let routing_result = match router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request, // Pass the original request - router_chat will convert it
|
||||
&request_headers,
|
||||
trace_collector.clone(),
|
||||
&traceparent,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
let mut internal_error = Response::new(full(err.message));
|
||||
*internal_error.status_mut() = err.status_code;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let model_name = routing_result.model_name;
|
||||
|
||||
debug!(
|
||||
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
|
||||
full_qualified_llm_provider_url, model_name
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&model_name).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// Capture start time right before sending request to upstream
|
||||
let request_start_time = std::time::Instant::now();
|
||||
let request_start_system_time = std::time::SystemTime::now();
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
// copy over the headers and status code from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
// Build LLM span with actual status code using constants
|
||||
let byte_stream = llm_response.bytes_stream();
|
||||
|
||||
// Build the LLM span (will be finalized after streaming completes)
|
||||
let llm_span = build_llm_span(
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&resolved_model,
|
||||
&model_name,
|
||||
upstream_status.as_u16(),
|
||||
is_streaming_request,
|
||||
request_start_system_time,
|
||||
tool_names,
|
||||
user_message_preview,
|
||||
temperature,
|
||||
&llm_providers,
|
||||
).await;
|
||||
|
||||
// Use PassthroughProcessor to track streaming metrics and finalize the span
|
||||
let processor = ObservableStreamProcessor::new(
|
||||
trace_collector,
|
||||
operation_component::LLM,
|
||||
llm_span,
|
||||
request_start_time,
|
||||
);
|
||||
|
||||
let streaming_response = create_streaming_response(byte_stream, processor, 16);
|
||||
|
||||
match response.body(streaming_response.body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves model aliases by looking up the requested model in the model_aliases map.
|
||||
/// Returns the target model if an alias is found, otherwise returns the original model.
|
||||
fn resolve_model_alias(
|
||||
model_from_request: &str,
|
||||
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> String {
|
||||
if let Some(aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = aliases.get(model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To {}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
return model_alias.target.clone();
|
||||
}
|
||||
}
|
||||
model_from_request.to_string()
|
||||
}
|
||||
|
||||
/// Builds the LLM span with all required and optional attributes.
|
||||
async fn build_llm_span(
|
||||
traceparent: &str,
|
||||
request_path: &str,
|
||||
resolved_model: &str,
|
||||
model_name: &str,
|
||||
status_code: u16,
|
||||
is_streaming: bool,
|
||||
start_time: std::time::SystemTime,
|
||||
tool_names: Option<Vec<String>>,
|
||||
user_message_preview: Option<String>,
|
||||
temperature: Option<f32>,
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
) -> common::traces::Span {
|
||||
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
|
||||
use crate::tracing::{http, llm, OperationNameBuilder};
|
||||
|
||||
// Calculate the upstream path based on provider configuration
|
||||
let upstream_path = get_upstream_path(
|
||||
llm_providers,
|
||||
model_name,
|
||||
request_path,
|
||||
resolved_model,
|
||||
is_streaming,
|
||||
).await;
|
||||
|
||||
// Build operation name showing path transformation if different
|
||||
let operation_name = if request_path != upstream_path {
|
||||
OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(&format!("{} >> {}", request_path, upstream_path))
|
||||
.with_target(resolved_model)
|
||||
.build()
|
||||
} else {
|
||||
OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(request_path)
|
||||
.with_target(resolved_model)
|
||||
.build()
|
||||
};
|
||||
|
||||
let (trace_id, parent_span_id) = parse_traceparent(traceparent);
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_trace_id(&trace_id)
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::STATUS_CODE, status_code.to_string())
|
||||
.with_attribute(http::TARGET, request_path.to_string())
|
||||
.with_attribute(http::UPSTREAM_TARGET, upstream_path)
|
||||
.with_attribute(llm::MODEL_NAME, resolved_model.to_string())
|
||||
.with_attribute(llm::IS_STREAMING, is_streaming.to_string());
|
||||
|
||||
// Only set parent span ID if it exists (not a root span)
|
||||
if let Some(parent) = parent_span_id {
|
||||
span_builder = span_builder.with_parent_span_id(&parent);
|
||||
}
|
||||
|
||||
// Add optional attributes
|
||||
if let Some(temp) = temperature {
|
||||
span_builder = span_builder.with_attribute(llm::TEMPERATURE, temp.to_string());
|
||||
}
|
||||
|
||||
if let Some(tools) = tool_names {
|
||||
let formatted_tools = tools.iter()
|
||||
.map(|name| format!("{}(...)", name))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
span_builder = span_builder.with_attribute(llm::TOOLS, formatted_tools);
|
||||
}
|
||||
|
||||
if let Some(preview) = user_message_preview {
|
||||
span_builder = span_builder.with_attribute(llm::USER_MESSAGE_PREVIEW, preview);
|
||||
}
|
||||
|
||||
span_builder.build()
|
||||
}
|
||||
|
||||
/// Calculates the upstream path for the provider based on the model name.
|
||||
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
|
||||
/// then uses target_endpoint_for_provider to calculate the correct upstream path.
|
||||
async fn get_upstream_path(
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
model_name: &str,
|
||||
request_path: &str,
|
||||
resolved_model: &str,
|
||||
is_streaming: bool,
|
||||
) -> String {
|
||||
let providers_lock = llm_providers.read().await;
|
||||
|
||||
// First, try to find by model name or provider name
|
||||
let provider = providers_lock.iter().find(|p| {
|
||||
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|
||||
|| p.name == model_name
|
||||
});
|
||||
|
||||
let (provider_id, base_url_path_prefix) = if let Some(provider) = provider {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
(provider_id, prefix)
|
||||
} else {
|
||||
let default_provider = providers_lock.iter().find(|p| {
|
||||
p.default.unwrap_or(false)
|
||||
});
|
||||
|
||||
if let Some(provider) = default_provider {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
(provider_id, prefix)
|
||||
} else {
|
||||
// Last resort: use OpenAI as hardcoded fallback
|
||||
warn!("No default provider found, falling back to OpenAI");
|
||||
(hermesllm::ProviderId::OpenAI, None)
|
||||
}
|
||||
};
|
||||
|
||||
drop(providers_lock);
|
||||
|
||||
// Calculate the upstream path using the proper API
|
||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
|
||||
.expect("Should have valid API endpoint");
|
||||
|
||||
client_api.target_endpoint_for_provider(
|
||||
&provider_id,
|
||||
request_path,
|
||||
resolved_model,
|
||||
is_streaming,
|
||||
base_url_path_prefix.as_deref(),
|
||||
)
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod agent_chat_completions;
|
||||
pub mod agent_selector;
|
||||
pub mod router;
|
||||
pub mod llm;
|
||||
pub mod router_chat;
|
||||
pub mod models;
|
||||
pub mod function_calling;
|
||||
pub mod pipeline_processor;
|
||||
|
|
|
|||
|
|
@ -1,252 +0,0 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{ModelAlias, ModelUsagePreference};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::handlers::utils::{create_streaming_response, PassthroughProcessor};
|
||||
|
||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub async fn router_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
"Received request body (raw utf8): {}",
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = model_aliases.get(&model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To {}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
model_alias.target.clone()
|
||||
} else {
|
||||
model_from_request.clone()
|
||||
}
|
||||
} else {
|
||||
model_from_request.clone()
|
||||
};
|
||||
client_request.set_model(resolved_model.clone());
|
||||
|
||||
// Clone metadata for routing and remove archgw_preference_config from original
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!("Removed archgw_preference_config from metadata");
|
||||
}
|
||||
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
|
||||
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
|
||||
match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(
|
||||
hermesllm::apis::OpenAIApi::ChatCompletions,
|
||||
),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(
|
||||
ProviderRequestType::MessagesRequest(_)
|
||||
| ProviderRequestType::BedrockConverse(_)
|
||||
| ProviderRequestType::BedrockConverseStream(_)
|
||||
| ProviderRequestType::ResponsesAPIRequest(_),
|
||||
) => {
|
||||
// This should not happen after conversion to OpenAI format
|
||||
warn!("Unexpected: got non-ChatCompletions request after converting to OpenAI format");
|
||||
let err_msg = "Request conversion failed".to_string();
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Failed to convert request to ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
let err_msg = format!("Failed to convert request: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[ARCH_ROUTER REQ]: {}",
|
||||
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
|
||||
);
|
||||
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
.find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
let latest_message_for_log = chat_completions_request_for_arch_router
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
|
||||
let truncated: String = latest_message_for_log
|
||||
.chars()
|
||||
.take(MAX_MESSAGE_LENGTH)
|
||||
.collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
latest_message_for_log
|
||||
};
|
||||
|
||||
info!(
|
||||
"request received, request type: chat_completion, usage preferences from request: {}, request path: {}, latest message: {}",
|
||||
usage_preferences.is_some(),
|
||||
request_path,
|
||||
latest_message_for_log
|
||||
);
|
||||
|
||||
debug!("usage preferences from request: {:?}", usage_preferences);
|
||||
|
||||
let model_name = match router_service
|
||||
.determine_route(
|
||||
&chat_completions_request_for_arch_router.messages,
|
||||
trace_parent.clone(),
|
||||
usage_preferences,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(route) => match route {
|
||||
Some((_, model_name)) => model_name,
|
||||
None => {
|
||||
info!(
|
||||
"No route determined, using default model from request: {}",
|
||||
chat_completions_request_for_arch_router.model
|
||||
);
|
||||
chat_completions_request_for_arch_router.model.clone()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to determine route: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
|
||||
full_qualified_llm_provider_url, model_name
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&model_name).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
|
||||
if let Some(trace_parent) = trace_parent {
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static("traceparent"),
|
||||
header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
}
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
// copy over the headers and status code from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
// Use the streaming utility with a passthrough processor (no modification of chunks)
|
||||
let byte_stream = llm_response.bytes_stream();
|
||||
let processor = PassthroughProcessor;
|
||||
let streaming_response = create_streaming_response(byte_stream, processor, 16);
|
||||
|
||||
match response.body(streaming_response.body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
}
|
||||
243
crates/brightstaff/src/handlers/router_chat.rs
Normal file
243
crates/brightstaff/src/handlers/router_chat.rs
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hyper::StatusCode;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::tracing::{OperationNameBuilder, operation_component, http, routing};
|
||||
|
||||
pub struct RoutingResult {
|
||||
pub model_name: String
|
||||
}
|
||||
|
||||
pub struct RoutingError {
|
||||
pub message: String,
|
||||
pub status_code: StatusCode,
|
||||
}
|
||||
|
||||
impl RoutingError {
|
||||
pub fn internal_error(message: String) -> Self {
|
||||
Self {
|
||||
message,
|
||||
status_code: StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines the routing decision if
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
|
||||
/// * `Err(RoutingError)` - Contains error details and optional span ID
|
||||
pub async fn router_chat_get_upstream_model(
|
||||
router_service: Arc<RouterService>,
|
||||
client_request: ProviderRequestType,
|
||||
request_headers: &hyper::HeaderMap,
|
||||
trace_collector: Arc<TraceCollector>,
|
||||
traceparent: &str,
|
||||
request_path: &str,
|
||||
) -> Result<RoutingResult, RoutingError> {
|
||||
// Clone metadata for routing before converting (which consumes client_request)
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
||||
let chat_request = match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(
|
||||
hermesllm::apis::OpenAIApi::ChatCompletions,
|
||||
),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(
|
||||
ProviderRequestType::MessagesRequest(_)
|
||||
| ProviderRequestType::BedrockConverse(_)
|
||||
| ProviderRequestType::BedrockConverseStream(_)
|
||||
| ProviderRequestType::ResponsesAPIRequest(_),
|
||||
) => {
|
||||
warn!("Unexpected: got non-ChatCompletions request after converting to OpenAI format");
|
||||
return Err(RoutingError::internal_error(
|
||||
"Request conversion failed".to_string(),
|
||||
));
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
|
||||
return Err(RoutingError::internal_error(format!(
|
||||
"Failed to convert request: {}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[ARCH_ROUTER REQ]: {}",
|
||||
&serde_json::to_string(&chat_request).unwrap()
|
||||
);
|
||||
|
||||
// Extract trace_parent from headers
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
.find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
// Extract usage preferences from metadata
|
||||
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
// Prepare log message with latest message from chat request
|
||||
let latest_message_for_log = chat_request
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
|
||||
let truncated: String = latest_message_for_log
|
||||
.chars()
|
||||
.take(MAX_MESSAGE_LENGTH)
|
||||
.collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
latest_message_for_log
|
||||
};
|
||||
|
||||
info!(
|
||||
"request received, request type: chat_completion, usage preferences from request: {}, request path: {}, latest message: {}",
|
||||
usage_preferences.is_some(),
|
||||
request_path,
|
||||
latest_message_for_log
|
||||
);
|
||||
|
||||
debug!("usage preferences from request: {:?}", usage_preferences);
|
||||
|
||||
// Capture start time for routing span
|
||||
let routing_start_time = std::time::Instant::now();
|
||||
let routing_start_system_time = std::time::SystemTime::now();
|
||||
|
||||
// Attempt to determine route using the router service
|
||||
let routing_result = router_service
|
||||
.determine_route(&chat_request.messages, trace_parent, usage_preferences)
|
||||
.await;
|
||||
|
||||
match routing_result {
|
||||
Ok(route) => match route {
|
||||
Some((_, model_name)) => {
|
||||
// Record successful routing span
|
||||
let mut attrs: HashMap<String, String> = HashMap::new();
|
||||
attrs.insert("route.selected_model".to_string(), model_name.clone());
|
||||
record_routing_span(
|
||||
trace_collector,
|
||||
traceparent,
|
||||
routing_start_time,
|
||||
routing_start_system_time,
|
||||
attrs,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name
|
||||
})
|
||||
}
|
||||
None => {
|
||||
// No route determined, use default model from request
|
||||
info!(
|
||||
"No route determined, using default model from request: {}",
|
||||
chat_request.model
|
||||
);
|
||||
|
||||
let default_model = chat_request.model.clone();
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("route.selected_model".to_string(), default_model.clone());
|
||||
record_routing_span(
|
||||
trace_collector,
|
||||
traceparent,
|
||||
routing_start_time,
|
||||
routing_start_system_time,
|
||||
attrs,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name: default_model
|
||||
})
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
// Record failed routing span
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("route.selected_model".to_string(), "unknown".to_string());
|
||||
attrs.insert("error.message".to_string(), err.to_string());
|
||||
record_routing_span(
|
||||
trace_collector,
|
||||
traceparent,
|
||||
routing_start_time,
|
||||
routing_start_system_time,
|
||||
attrs,
|
||||
)
|
||||
.await;
|
||||
|
||||
Err(RoutingError::internal_error(
|
||||
format!("Failed to determine route: {}", err)
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to record a routing span with the given attributes.
|
||||
/// Reduces code duplication across different routing outcomes.
|
||||
async fn record_routing_span(
|
||||
trace_collector: Arc<TraceCollector>,
|
||||
traceparent: &str,
|
||||
start_time: std::time::Instant,
|
||||
start_system_time: std::time::SystemTime,
|
||||
attrs: HashMap<String, String>,
|
||||
) {
|
||||
// The routing always uses OpenAI Chat Completions format internally,
|
||||
// so we log that as the actual API being used for routing
|
||||
let routing_api_path = "/v1/chat/completions";
|
||||
|
||||
let routing_operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(routing_api_path)
|
||||
.with_target("Arch-Router-1.5B")
|
||||
.build();
|
||||
|
||||
let (trace_id, parent_span_id) = parse_traceparent(traceparent);
|
||||
|
||||
// Build the routing span directly using constants
|
||||
let mut span_builder = SpanBuilder::new(&routing_operation_name)
|
||||
.with_trace_id(&trace_id)
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_system_time)
|
||||
.with_end_time(std::time::SystemTime::now())
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, routing_api_path.to_string())
|
||||
.with_attribute(routing::ROUTE_DETERMINATION_MS, start_time.elapsed().as_millis().to_string());
|
||||
|
||||
// Only set parent span ID if it exists (not a root span)
|
||||
if let Some(parent) = parent_span_id {
|
||||
span_builder = span_builder.with_parent_span_id(&parent);
|
||||
}
|
||||
|
||||
// Add all custom attributes
|
||||
for (key, value) in attrs {
|
||||
span_builder = span_builder.with_attribute(key, value);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
|
||||
// Record the span directly to the collector
|
||||
trace_collector.record_span(operation_component::ROUTING, span);
|
||||
}
|
||||
|
|
@ -1,18 +1,27 @@
|
|||
use bytes::Bytes;
|
||||
use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::StreamBody;
|
||||
use hyper::body::Frame;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Instant, SystemTime};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::warn;
|
||||
|
||||
// Import tracing constants
|
||||
use crate::tracing::{llm, error};
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
pub trait StreamProcessor: Send + 'static {
|
||||
/// Process an incoming chunk of bytes
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String>;
|
||||
|
||||
/// Called when the first bytes are received (for time-to-first-token tracking)
|
||||
fn on_first_bytes(&mut self) {}
|
||||
|
||||
/// Called when streaming completes successfully
|
||||
fn on_complete(&mut self) {}
|
||||
|
||||
|
|
@ -20,13 +29,152 @@ pub trait StreamProcessor: Send + 'static {
|
|||
fn on_error(&mut self, _error: &str) {}
|
||||
}
|
||||
|
||||
/// A no-op processor that just forwards chunks as-is
|
||||
pub struct PassthroughProcessor;
|
||||
/// A processor that tracks streaming metrics and finalizes the span
|
||||
pub struct ObservableStreamProcessor {
|
||||
collector: Arc<TraceCollector>,
|
||||
service_name: String,
|
||||
span: Span,
|
||||
total_bytes: usize,
|
||||
chunk_count: usize,
|
||||
start_time: Instant,
|
||||
time_to_first_token: Option<u128>,
|
||||
}
|
||||
|
||||
impl StreamProcessor for PassthroughProcessor {
|
||||
impl ObservableStreamProcessor {
|
||||
/// Create a new passthrough processor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `collector` - The trace collector to record the span to
|
||||
/// * `service_name` - The service name for this span (e.g., "archgw(llm)")
|
||||
/// * `span` - The span to finalize after streaming completes
|
||||
/// * `start_time` - When the request started (for duration calculation)
|
||||
pub fn new(
|
||||
collector: Arc<TraceCollector>,
|
||||
service_name: impl Into<String>,
|
||||
span: Span,
|
||||
start_time: Instant,
|
||||
) -> Self {
|
||||
Self {
|
||||
collector,
|
||||
service_name: service_name.into(),
|
||||
span,
|
||||
total_bytes: 0,
|
||||
chunk_count: 0,
|
||||
start_time,
|
||||
time_to_first_token: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamProcessor for ObservableStreamProcessor {
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||
self.total_bytes += chunk.len();
|
||||
self.chunk_count += 1;
|
||||
Ok(Some(chunk))
|
||||
}
|
||||
|
||||
fn on_first_bytes(&mut self) {
|
||||
// Record time to first token (only for streaming)
|
||||
if self.time_to_first_token.is_none() {
|
||||
self.time_to_first_token = Some(self.start_time.elapsed().as_millis());
|
||||
}
|
||||
}
|
||||
|
||||
fn on_complete(&mut self) {
|
||||
// Update span with streaming metrics and end time
|
||||
let end_time_nanos = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos();
|
||||
|
||||
self.span.end_time_unix_nano = format!("{}", end_time_nanos);
|
||||
|
||||
// Add streaming metrics as attributes using constants
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::RESPONSE_BYTES.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(self.total_bytes.to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::DURATION_MS.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(self.start_time.elapsed().as_millis().to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
// Add time to first token if available (streaming only)
|
||||
if let Some(ttft) = self.time_to_first_token {
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(ttft.to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
// Add time to first token as a span event
|
||||
// Calculate the timestamp by adding ttft duration to span start time
|
||||
if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() {
|
||||
// Convert ttft from milliseconds to nanoseconds and add to start time
|
||||
let event_timestamp = start_time_nanos + (ttft * 1_000_000);
|
||||
let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
|
||||
event.add_attribute(
|
||||
llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
|
||||
ttft.to_string(),
|
||||
);
|
||||
|
||||
// Initialize events vector if needed
|
||||
if self.span.events.is_none() {
|
||||
self.span.events = Some(Vec::new());
|
||||
}
|
||||
|
||||
if let Some(ref mut events) = self.span.events {
|
||||
events.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Record the finalized span
|
||||
self.collector.record_span(&self.service_name, self.span.clone());
|
||||
}
|
||||
|
||||
fn on_error(&mut self, error_msg: &str) {
|
||||
warn!("Stream error in PassthroughProcessor: {}", error_msg);
|
||||
|
||||
// Update span with error info and end time
|
||||
let end_time_nanos = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos();
|
||||
|
||||
self.span.end_time_unix_nano = format!("{}", end_time_nanos);
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: error::ERROR.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some("true".to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: error::MESSAGE.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(error_msg.to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::DURATION_MS.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(self.start_time.elapsed().as_millis().to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
// Record the error span
|
||||
self.collector.record_span(&self.service_name, self.span.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of creating a streaming response
|
||||
|
|
@ -48,6 +196,8 @@ where
|
|||
|
||||
// Spawn a task to process and forward chunks
|
||||
let processor_handle = tokio::spawn(async move {
|
||||
let mut is_first_chunk = true;
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
|
|
@ -59,6 +209,12 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
// Call on_first_bytes for the first chunk
|
||||
if is_first_chunk {
|
||||
processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Process the chunk
|
||||
match processor.process_chunk(chunk) {
|
||||
Ok(Some(processed_chunk)) => {
|
||||
|
|
@ -91,3 +247,13 @@ where
|
|||
processor_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates a message to the specified maximum length, adding "..." if truncated.
|
||||
pub fn truncate_message(message: &str, max_length: usize) -> String {
|
||||
if message.chars().count() > max_length {
|
||||
let truncated: String = message.chars().take(max_length).collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
message.to_string()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue