From 28b674454b3a2d1b643fa490a32ff9942f7dc623 Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Thu, 11 Dec 2025 13:53:44 -0800 Subject: [PATCH] fixed PR comments and added more trace attributes --- .github/workflows/rust_tests.yml | 3 + arch/envoy.template.yaml | 6 +- archgw.code-workspace | 1 + crates/Cargo.lock | 75 +++++ crates/Cargo.toml | 4 + .../src/handlers/{router.rs => llm.rs} | 28 +- crates/brightstaff/src/handlers/mod.rs | 2 +- crates/brightstaff/src/handlers/utils.rs | 6 +- crates/brightstaff/src/main.rs | 4 +- crates/brightstaff/src/tracing/constants.rs | 8 +- crates/common/Cargo.toml | 2 + crates/common/src/traces/collector.rs | 12 +- crates/common/src/traces/mod.rs | 3 + .../src/traces/tests/mock_otel_collector.rs | 101 ++++++ crates/common/src/traces/tests/mod.rs | 4 + .../traces/tests/trace_integration_test.rs | 304 ++++++++++++++++++ crates/hermesllm/src/apis/amazon_bedrock.rs | 4 + crates/hermesllm/src/apis/anthropic.rs | 4 + crates/hermesllm/src/apis/openai.rs | 4 + crates/hermesllm/src/apis/openai_responses.rs | 4 + crates/hermesllm/src/providers/request.rs | 12 + 21 files changed, 565 insertions(+), 26 deletions(-) rename crates/brightstaff/src/handlers/{router.rs => llm.rs} (92%) create mode 100644 crates/common/src/traces/tests/mock_otel_collector.rs create mode 100644 crates/common/src/traces/tests/mod.rs create mode 100644 crates/common/src/traces/tests/trace_integration_test.rs diff --git a/.github/workflows/rust_tests.yml b/.github/workflows/rust_tests.yml index 9837531d..e75a6b60 100644 --- a/.github/workflows/rust_tests.yml +++ b/.github/workflows/rust_tests.yml @@ -29,3 +29,6 @@ jobs: - name: Run unit tests run: cargo test --lib + + - name: Run trace integration tests + run: cargo test -p common --features trace-collection traces::tests::trace_integration_test diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index cf30a07d..3d618faf 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -206,7 +206,7 @@ static_resources: - name: outbound_api_traffic address: socket_address: - address: 0.0.0.0 + address: 127.0.0.1 port_value: 11000 traffic_direction: OUTBOUND filter_chains: @@ -225,7 +225,7 @@ static_resources: envoy_grpc: cluster_name: opentelemetry_collector timeout: 0.250s - service_name: tool + service_name: tools random_sampling: value: {{ arch_tracing.random_sampling }} {% endif %} @@ -473,7 +473,7 @@ static_resources: - name: otel_collector_proxy address: socket_address: - address: 0.0.0.0 + address: 127.0.0.1 port_value: 9903 traffic_direction: OUTBOUND filter_chains: diff --git a/archgw.code-workspace b/archgw.code-workspace index f94e67b0..bd24f82a 100644 --- a/archgw.code-workspace +++ b/archgw.code-workspace @@ -34,6 +34,7 @@ "editor.defaultFormatter": "ms-python.black-formatter", "editor.formatOnSave": true }, + "rust-analyzer.cargo.features": ["trace-collection"] }, "extensions": { "recommendations": [ diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 7aa85f59..09c86861 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -167,6 +167,61 @@ dependencies = [ "time", ] +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower 0.5.2", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backoff" version = "0.4.0" @@ -370,6 +425,7 @@ dependencies = [ name = "common" version = "0.1.0" dependencies = [ + "axum", "derivative", "duration-string", "governor", @@ -384,6 +440,7 @@ dependencies = [ "serde_json", "serde_with", "serde_yaml", + "serial_test", "thiserror 1.0.69", "tiktoken-rs", "tokio", @@ -1429,6 +1486,12 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "md5" version = "0.7.0" @@ -2461,6 +2524,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2984,6 +3057,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -3022,6 +3096,7 @@ version = "0.1.41" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", diff --git a/crates/Cargo.toml b/crates/Cargo.toml index 5cd6b29c..c22e252e 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -1,3 +1,7 @@ [workspace] resolver = "2" members = ["llm_gateway", "prompt_gateway", "common", "brightstaff", "hermesllm"] + +[workspace.metadata.rust-analyzer] +# Enable features for better IDE support +cargo.features = ["trace-collection"] diff --git a/crates/brightstaff/src/handlers/router.rs b/crates/brightstaff/src/handlers/llm.rs similarity index 92% rename from crates/brightstaff/src/handlers/router.rs rename to crates/brightstaff/src/handlers/llm.rs index cb4b2aa7..ba674c3e 100644 --- a/crates/brightstaff/src/handlers/router.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -14,7 +14,7 @@ use tokio::sync::RwLock; use tracing::{debug, warn}; use crate::router::llm_router::RouterService; -use crate::handlers::utils::{create_streaming_response, PassthroughProcessor, truncate_message}; +use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message}; use crate::handlers::router_chat::router_chat_get_upstream_model; use crate::tracing::operation_component; @@ -24,7 +24,7 @@ fn full>(chunk: T) -> BoxBody { .boxed() } -pub async fn chat( +pub async fn llm_chat( request: Request, router_service: Arc, full_qualified_llm_provider_url: String, @@ -36,12 +36,19 @@ pub async fn chat( let request_path = request.uri().path().to_string(); let request_headers = request.headers().clone(); - // Extract traceparent header early (Envoy should have added this) - let traceparent = request_headers + // Extract or generate traceparent - this establishes the trace context for all spans + let traceparent: String = request_headers .get("traceparent") .and_then(|h| h.to_str().ok()) - .unwrap_or("00-00000000000000000000000000000000-0000000000000000-01") - .to_string(); + .map(|s| s.to_string()) + .unwrap_or_else(|| { + // No traceparent - this is a root span, generate a new trace ID + use uuid::Uuid; + let trace_id = Uuid::new_v4().to_string().replace("-", ""); + let span_id = Uuid::new_v4().to_string().replace("-", "")[..16].to_string(); + // Format: version-trace_id-parent_span_id-trace_flags + format!("00-{}-{}-01", trace_id, span_id) + }); let mut request_headers = request_headers; let chat_request_bytes = request.collect().await?.to_bytes(); @@ -68,6 +75,7 @@ pub async fn chat( // Model alias resolution: update model field in client_request immediately // This ensures all downstream objects use the resolved model let model_from_request = client_request.model().to_string(); + let temperature = client_request.get_temperature(); let is_streaming_request = client_request.is_streaming(); let resolved_model = resolve_model_alias(&model_from_request, &model_aliases); @@ -177,11 +185,12 @@ pub async fn chat( request_start_system_time, tool_names, user_message_preview, + temperature, &llm_providers, ).await; // Use PassthroughProcessor to track streaming metrics and finalize the span - let processor = PassthroughProcessor::new( + let processor = ObservableStreamProcessor::new( trace_collector, operation_component::LLM, llm_span, @@ -230,6 +239,7 @@ async fn build_llm_span( start_time: std::time::SystemTime, tool_names: Option>, user_message_preview: Option, + temperature: Option, llm_providers: &Arc>>, ) -> common::traces::Span { use common::traces::{SpanBuilder, SpanKind, parse_traceparent}; @@ -274,6 +284,10 @@ async fn build_llm_span( .with_attribute(llm::IS_STREAMING, is_streaming.to_string()); // Add optional attributes + if let Some(temp) = temperature { + span_builder = span_builder.with_attribute(llm::TEMPERATURE, temp.to_string()); + } + if let Some(tools) = tool_names { let formatted_tools = tools.iter() .map(|name| format!("{}(...)", name)) diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index 177ec8a1..b7762916 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,6 +1,6 @@ pub mod agent_chat_completions; pub mod agent_selector; -pub mod router; +pub mod llm; pub mod router_chat; pub mod models; pub mod function_calling; diff --git a/crates/brightstaff/src/handlers/utils.rs b/crates/brightstaff/src/handlers/utils.rs index 15c02ce6..6f84c1f3 100644 --- a/crates/brightstaff/src/handlers/utils.rs +++ b/crates/brightstaff/src/handlers/utils.rs @@ -30,7 +30,7 @@ pub trait StreamProcessor: Send + 'static { } /// A processor that tracks streaming metrics and finalizes the span -pub struct PassthroughProcessor { +pub struct ObservableStreamProcessor { collector: Arc, service_name: String, span: Span, @@ -40,7 +40,7 @@ pub struct PassthroughProcessor { time_to_first_token: Option, } -impl PassthroughProcessor { +impl ObservableStreamProcessor { /// Create a new passthrough processor /// /// # Arguments @@ -66,7 +66,7 @@ impl PassthroughProcessor { } } -impl StreamProcessor for PassthroughProcessor { +impl StreamProcessor for ObservableStreamProcessor { fn process_chunk(&mut self, chunk: Bytes) -> Result, String> { self.total_bytes += chunk.len(); self.chunk_count += 1; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 4f76f5df..d0241fa3 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,5 +1,5 @@ use brightstaff::handlers::agent_chat_completions::agent_chat; -use brightstaff::handlers::router::chat; +use brightstaff::handlers::llm::llm_chat; use brightstaff::handlers::models::list_models; use brightstaff::handlers::function_calling::{function_calling_chat_handler}; use brightstaff::router::llm_router::RouterService; @@ -130,7 +130,7 @@ async fn main() -> Result<(), Box> { (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => { let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path()); - chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector) + llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector) .with_context(parent_cx) .await } diff --git a/crates/brightstaff/src/tracing/constants.rs b/crates/brightstaff/src/tracing/constants.rs index 8557e5de..bd946aac 100644 --- a/crates/brightstaff/src/tracing/constants.rs +++ b/crates/brightstaff/src/tracing/constants.rs @@ -83,19 +83,19 @@ pub mod llm { pub const TOTAL_TOKENS: &str = "llm.usage.total_tokens"; /// Temperature parameter used - pub const TEMPERATURE: &str = "llm.request.temperature"; + pub const TEMPERATURE: &str = "llm.temperature"; /// Max tokens parameter used - pub const MAX_TOKENS: &str = "llm.request.max_tokens"; + pub const MAX_TOKENS: &str = "llm.max_tokens"; /// Top-p parameter used - pub const TOP_P: &str = "llm.request.top_p"; + pub const TOP_P: &str = "llm.top_p"; /// List of tool names provided in the request pub const TOOLS: &str = "llm.tools"; /// Preview of the user message (truncated) - pub const USER_MESSAGE_PREVIEW: &str = "llm.request.user_message_preview"; + pub const USER_MESSAGE_PREVIEW: &str = "llm.user_message_preview"; } // ============================================================================= diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index f0def5d9..4c659bfe 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -33,4 +33,6 @@ trace-collection = ["tokio", "reqwest", "tracing"] [dev-dependencies] pretty_assertions = "1.4.1" serde_json = "1.0.64" +serial_test = "3.2" +axum = "0.7" tokio = { version = "1.44", features = ["sync", "time", "macros", "rt"] } diff --git a/crates/common/src/traces/collector.rs b/crates/common/src/traces/collector.rs index b3a58ce0..c2188339 100644 --- a/crates/common/src/traces/collector.rs +++ b/crates/common/src/traces/collector.rs @@ -52,13 +52,13 @@ impl TraceCollector { /// - `None` - Check `OTEL_TRACING_ENABLED` env var (defaults to true if not set) /// /// Other parameters are read from environment variables: - /// - `TRACE_FLUSH_INTERVAL_SECS` - Flush interval in seconds (default: 1) + /// - `TRACE_FLUSH_INTERVAL_MS` - Flush interval in milliseconds (default: 1000) /// - `OTEL_COLLECTOR_URL` - OTEL collector endpoint (default: http://localhost:9903/v1/traces) pub fn new(enabled: Option) -> Self { - let flush_interval_secs = std::env::var("TRACE_FLUSH_INTERVAL_SECS") + let flush_interval_ms = std::env::var("TRACE_FLUSH_INTERVAL_MS") .ok() .and_then(|s| s.parse().ok()) - .unwrap_or(1); + .unwrap_or(1000); let otel_url = std::env::var("OTEL_COLLECTOR_URL") .unwrap_or_else(|_| "http://localhost:9903/v1/traces".to_string()); @@ -75,13 +75,13 @@ impl TraceCollector { }); debug!( - "TraceCollector initialized: flush_interval={}s, url={}, enabled={}", - flush_interval_secs, otel_url, enabled + "TraceCollector initialized: flush_interval={}ms, url={}, enabled={}", + flush_interval_ms, otel_url, enabled ); Self { spans_by_service: Arc::new(Mutex::new(HashMap::new())), - flush_interval: Duration::from_secs(flush_interval_secs), + flush_interval: Duration::from_millis(flush_interval_ms), otel_url, enabled, } diff --git a/crates/common/src/traces/mod.rs b/crates/common/src/traces/mod.rs index a8bc6ca5..c0d042fa 100644 --- a/crates/common/src/traces/mod.rs +++ b/crates/common/src/traces/mod.rs @@ -8,6 +8,9 @@ mod constants; #[cfg(feature = "trace-collection")] mod collector; +#[cfg(all(test, feature = "trace-collection"))] +mod tests; + // Re-export original types pub use shapes::{ Span, Event, Traceparent, TraceparentNewError, diff --git a/crates/common/src/traces/tests/mock_otel_collector.rs b/crates/common/src/traces/tests/mock_otel_collector.rs new file mode 100644 index 00000000..8a154145 --- /dev/null +++ b/crates/common/src/traces/tests/mock_otel_collector.rs @@ -0,0 +1,101 @@ +//! Mock OTEL Collector for testing trace output +//! +//! This module provides a simple HTTP server that mimics an OTEL collector. +//! It exposes three endpoints: +//! - POST /v1/traces: Capture incoming OTLP JSON payloads +//! - GET /v1/traces: Return all captured payloads as JSON array +//! - DELETE /v1/traces: Clear all captured payloads +//! +//! Each test creates its own MockOtelCollector instance. + +use axum::{ + extract::State, + http::StatusCode, + routing::{delete, get, post}, + Json, Router, +}; +use serde_json::Value; +use std::sync::Arc; +use tokio::sync::RwLock; + +type SharedTraces = Arc>>; + +/// POST /v1/traces - capture incoming OTLP payload +async fn post_traces( + State(traces): State, + Json(payload): Json, +) -> StatusCode { + traces.write().await.push(payload); + StatusCode::OK +} + +/// GET /v1/traces - return all captured payloads +async fn get_traces(State(traces): State) -> Json> { + Json(traces.read().await.clone()) +} + +/// DELETE /v1/traces - clear all captured payloads +async fn delete_traces(State(traces): State) -> StatusCode { + traces.write().await.clear(); + StatusCode::NO_CONTENT +} + +/// Mock OTEL collector server +pub struct MockOtelCollector { + address: String, + client: reqwest::Client, + #[allow(dead_code)] + server_handle: tokio::task::JoinHandle<()>, +} + +impl MockOtelCollector { + /// Create and start a new mock collector on a random port + pub async fn start() -> Self { + let traces = Arc::new(RwLock::new(Vec::new())); + + let app = Router::new() + .route("/v1/traces", post(post_traces)) + .route("/v1/traces", get(get_traces)) + .route("/v1/traces", delete(delete_traces)) + .with_state(traces.clone()); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("Failed to bind to random port"); + + let addr = listener.local_addr().expect("Failed to get local address"); + let address = format!("http://127.0.0.1:{}", addr.port()); + + let server_handle = tokio::spawn(async move { + axum::serve(listener, app) + .await + .expect("Server failed"); + }); + + // Give server a moment to start + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + Self { + address, + client: reqwest::Client::new(), + server_handle, + } + } + + /// Get the address of the collector + pub fn address(&self) -> &str { + &self.address + } + + /// GET /v1/traces - fetch all captured payloads + pub async fn get_traces(&self) -> Vec { + self.client + .get(format!("{}/v1/traces", self.address)) + .send() + .await + .expect("Failed to GET traces") + .json() + .await + .expect("Failed to parse traces JSON") + } +} diff --git a/crates/common/src/traces/tests/mod.rs b/crates/common/src/traces/tests/mod.rs new file mode 100644 index 00000000..7bba42f8 --- /dev/null +++ b/crates/common/src/traces/tests/mod.rs @@ -0,0 +1,4 @@ +mod mock_otel_collector; +mod trace_integration_test; + +pub use mock_otel_collector::MockOtelCollector; diff --git a/crates/common/src/traces/tests/trace_integration_test.rs b/crates/common/src/traces/tests/trace_integration_test.rs new file mode 100644 index 00000000..a3c8a6ba --- /dev/null +++ b/crates/common/src/traces/tests/trace_integration_test.rs @@ -0,0 +1,304 @@ +//! Integration tests for OpenTelemetry tracing in router.rs +//! +//! These tests validate that the spans created for LLM requests contain +//! all expected attributes and events by checking the raw JSON payloads +//! sent to the mock OTEL collector. +//! +//! ## Test Design +//! Each test creates its own MockOtelCollector and TraceCollector: +//! 1. Start MockOtelCollector on random port +//! 2. Create TraceCollector with 500ms flush interval +//! 3. Record spans using TraceCollector +//! 4. Flush and wait (500ms + 200ms buffer = 700ms total) for spans to arrive +//! 5. Get raw JSON payloads (GET /v1/traces) and validate structure +//! 6. Test cleanup happens automatically when collectors are dropped +//! +//! ## Serial Execution +//! Tests use the `#[serial]` attribute to run sequentially because they +//! use global environment variables (OTEL_COLLECTOR_URL, OTEL_TRACING_ENABLED, +//! TRACE_FLUSH_INTERVAL_MS). This ensures test isolation without requiring +//! the `--test-threads=1` command line flag. + +const FLUSH_INTERVAL_MS: u64 = 50; +const FLUSH_BUFFER_MS: u64 = 50; +const TOTAL_WAIT_MS: u64 = FLUSH_INTERVAL_MS + FLUSH_BUFFER_MS; + +use crate::traces::{SpanBuilder, SpanKind, TraceCollector}; +use serde_json::Value; +use serial_test::serial; +use std::sync::Arc; + +use super::MockOtelCollector; + +/// Helper to extract all spans from OTLP JSON payloads +fn extract_spans(payloads: &[Value]) -> Vec<&Value> { + let mut spans = Vec::new(); + for payload in payloads { + if let Some(resource_spans) = payload.get("resourceSpans").and_then(|v| v.as_array()) { + for resource_span in resource_spans { + if let Some(scope_spans) = resource_span.get("scopeSpans").and_then(|v| v.as_array()) { + for scope_span in scope_spans { + if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array()) { + spans.extend(span_list.iter()); + } + } + } + } + } + } + spans +} + +/// Helper to get string attribute value from a span +fn get_string_attr<'a>(span: &'a Value, key: &str) -> Option<&'a str> { + span.get("attributes") + .and_then(|attrs| attrs.as_array()) + .and_then(|attrs| { + attrs.iter().find(|attr| { + attr.get("key").and_then(|k| k.as_str()) == Some(key) + }) + }) + .and_then(|attr| attr.get("value")) + .and_then(|v| v.get("stringValue")) + .and_then(|v| v.as_str()) +} + +#[tokio::test] +#[serial] +async fn test_llm_span_contains_basic_attributes() { + // Start mock OTEL collector + let mock_collector = MockOtelCollector::start().await; + + // Create TraceCollector pointing to mock with 500ms flush intervalc + std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var("OTEL_TRACING_ENABLED", "true"); + std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); + let trace_collector = Arc::new(TraceCollector::new(Some(true))); + + // Create a test span simulating router.rs behavior + let span = SpanBuilder::new("POST /v1/chat/completions >> /v1/chat/completions") + .with_kind(SpanKind::Client) + .with_trace_id("test-trace-123") + .with_attribute("http.method", "POST") + .with_attribute("http.target", "/v1/chat/completions") + .with_attribute("http.upstream_target", "/v1/chat/completions") + .with_attribute("llm.model", "gpt-4o") + .with_attribute("llm.provider", "openai") + .with_attribute("llm.is_streaming", "true") + .with_attribute("llm.temperature", "0.7") + .build(); + + trace_collector.record_span("archgw(llm)", span); + + // Flush and wait for spans to arrive (500ms flush interval + 200ms buffer) + trace_collector.flush().await.expect("Failed to flush"); + tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await; + + let payloads = mock_collector.get_traces().await; + let spans = extract_spans(&payloads); + + assert_eq!(spans.len(), 1, "Expected exactly one span"); + + let span = spans[0]; + // Validate HTTP attributes + assert_eq!(get_string_attr(span, "http.method"), Some("POST")); + assert_eq!(get_string_attr(span, "http.target"), Some("/v1/chat/completions")); + + // Validate LLM attributes + assert_eq!(get_string_attr(span, "llm.model"), Some("gpt-4o")); + assert_eq!(get_string_attr(span, "llm.provider"), Some("openai")); + assert_eq!(get_string_attr(span, "llm.is_streaming"), Some("true")); + assert_eq!(get_string_attr(span, "llm.temperature"), Some("0.7")); +} + +#[tokio::test] +#[serial] +async fn test_llm_span_contains_tool_information() { + let mock_collector = MockOtelCollector::start().await; + std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var("OTEL_TRACING_ENABLED", "true"); + std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); + let trace_collector = Arc::new(TraceCollector::new(Some(true))); + + let tools_formatted = "get_weather(...)\nsearch_web(...)\ncalculate(...)"; + + let span = SpanBuilder::new("POST /v1/chat/completions") + .with_trace_id("test-trace-tools") + .with_attribute("llm.request.tools", tools_formatted) + .with_attribute("llm.model", "gpt-4o") + .build(); + + trace_collector.record_span("archgw(llm)", span); + trace_collector.flush().await.expect("Failed to flush"); + tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await; + + let payloads = mock_collector.get_traces().await; + let spans = extract_spans(&payloads); + + assert!(!spans.is_empty(), "No spans captured"); + + let span = spans[0]; + let tools = get_string_attr(span, "llm.request.tools"); + + assert!(tools.is_some(), "Tools attribute missing"); + assert!(tools.unwrap().contains("get_weather(...)")); + assert!(tools.unwrap().contains("search_web(...)")); + assert!(tools.unwrap().contains("calculate(...)")); + assert!(tools.unwrap().contains('\n'), "Tools should be newline-separated"); +} + +#[tokio::test] +#[serial] +async fn test_llm_span_contains_user_message_preview() { + let mock_collector = MockOtelCollector::start().await; + std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var("OTEL_TRACING_ENABLED", "true"); + std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); + let trace_collector = Arc::new(TraceCollector::new(Some(true))); + + let long_message = "This is a very long user message that should be truncated to 50 characters in the span"; + let preview = if long_message.len() > 50 { + format!("{}...", &long_message[..50]) + } else { + long_message.to_string() + }; + + let span = SpanBuilder::new("POST /v1/messages") + .with_trace_id("test-trace-preview") + .with_attribute("llm.request.user_message_preview", &preview) + .build(); + + trace_collector.record_span("archgw(llm)", span); + trace_collector.flush().await.expect("Failed to flush"); + tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await; + + let payloads = mock_collector.get_traces().await; + let spans = extract_spans(&payloads); + let span = spans[0]; + + let message_preview = get_string_attr(span, "llm.request.user_message_preview"); + + assert!(message_preview.is_some()); + assert!(message_preview.unwrap().len() <= 53); // 50 chars + "..." + assert!(message_preview.unwrap().contains("...")); +} + +#[tokio::test] +#[serial] +async fn test_llm_span_contains_time_to_first_token() { + let mock_collector = MockOtelCollector::start().await; + std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var("OTEL_TRACING_ENABLED", "true"); + std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); + let trace_collector = Arc::new(TraceCollector::new(Some(true))); + + let ttft_ms = "245"; // milliseconds as string + + let span = SpanBuilder::new("POST /v1/chat/completions") + .with_trace_id("test-trace-ttft") + .with_attribute("llm.is_streaming", "true") + .with_attribute("llm.time_to_first_token_ms", ttft_ms) + .build(); + + trace_collector.record_span("archgw(llm)", span); + trace_collector.flush().await.expect("Failed to flush"); + tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await; + + let payloads = mock_collector.get_traces().await; + let spans = extract_spans(&payloads); + let span = spans[0]; + + // Check TTFT attribute + let ttft_attr = get_string_attr(span, "llm.time_to_first_token_ms"); + assert_eq!(ttft_attr, Some("245")); +} + +#[tokio::test] +#[serial] +async fn test_llm_span_contains_upstream_path() { + let mock_collector = MockOtelCollector::start().await; + std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var("OTEL_TRACING_ENABLED", "true"); + std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); + let trace_collector = Arc::new(TraceCollector::new(Some(true))); + + // Test Zhipu provider with path transformation + let span = SpanBuilder::new("POST /v1/chat/completions >> /api/paas/v4/chat/completions") + .with_trace_id("test-trace-upstream") + .with_attribute("http.upstream_target", "/api/paas/v4/chat/completions") + .with_attribute("llm.provider", "zhipu") + .with_attribute("llm.model", "glm-4") + .build(); + + trace_collector.record_span("archgw(llm)", span); + trace_collector.flush().await.expect("Failed to flush"); + tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await; + + let payloads = mock_collector.get_traces().await; + let spans = extract_spans(&payloads); + let span = spans[0]; + + // Operation name should show the transformation + let name = span.get("name").and_then(|v| v.as_str()); + assert!(name.is_some()); + assert!(name.unwrap().contains(">>"), "Operation name should show path transformation"); + + // Check upstream target attribute + let upstream = get_string_attr(span, "http.upstream_target"); + assert_eq!(upstream, Some("/api/paas/v4/chat/completions")); +} + +#[tokio::test] +#[serial] +async fn test_llm_span_multiple_services() { + let mock_collector = MockOtelCollector::start().await; + std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var("OTEL_TRACING_ENABLED", "true"); + std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); + let trace_collector = Arc::new(TraceCollector::new(Some(true))); + + // Create spans for different services + let llm_span = SpanBuilder::new("LLM Request") + .with_trace_id("test-multi") + .with_attribute("service", "llm") + .build(); + + let routing_span = SpanBuilder::new("Routing Decision") + .with_trace_id("test-multi") + .with_attribute("service", "routing") + .build(); + + trace_collector.record_span("archgw(llm)", llm_span); + trace_collector.record_span("archgw(routing)", routing_span); + trace_collector.flush().await.expect("Failed to flush"); + tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await; + + let payloads = mock_collector.get_traces().await; + let all_spans = extract_spans(&payloads); + + assert_eq!(all_spans.len(), 2, "Should have captured both spans"); +} + +#[tokio::test] +#[serial] +async fn test_tracing_disabled_produces_no_spans() { + let mock_collector = MockOtelCollector::start().await; + + // Create TraceCollector with tracing DISABLED + std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var("OTEL_TRACING_ENABLED", "false"); + std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); + let trace_collector = Arc::new(TraceCollector::new(Some(false))); + + let span = SpanBuilder::new("Test Span") + .with_trace_id("test-disabled") + .build(); + + trace_collector.record_span("archgw(llm)", span); + trace_collector.flush().await.ok(); // Should be no-op when disabled + tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await; + + let payloads = mock_collector.get_traces().await; + let all_spans = extract_spans(&payloads); + assert_eq!(all_spans.len(), 0, "No spans should be captured when tracing is disabled"); +} diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs index 8afd8af0..7b4a511f 100644 --- a/crates/hermesllm/src/apis/amazon_bedrock.rs +++ b/crates/hermesllm/src/apis/amazon_bedrock.rs @@ -229,6 +229,10 @@ impl ProviderRequest for ConverseRequest { false } } + + fn get_temperature(&self) -> Option { + self.inference_config.as_ref()?.temperature + } } // ============================================================================ diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index c3c8bc1a..06d632d9 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -537,6 +537,10 @@ impl ProviderRequest for MessagesRequest { false } } + + fn get_temperature(&self) -> Option { + self.temperature + } } impl MessagesResponse { diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 0bfd43c2..4e006c3a 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -731,6 +731,10 @@ impl ProviderRequest for ChatCompletionsRequest { false } } + + fn get_temperature(&self) -> Option { + self.temperature + } } /// Implementation of ProviderResponse for ChatCompletionsResponse diff --git a/crates/hermesllm/src/apis/openai_responses.rs b/crates/hermesllm/src/apis/openai_responses.rs index fa9976e3..91c4b0cc 100644 --- a/crates/hermesllm/src/apis/openai_responses.rs +++ b/crates/hermesllm/src/apis/openai_responses.rs @@ -1094,6 +1094,10 @@ impl ProviderRequest for ResponsesAPIRequest { false } } + + fn get_temperature(&self) -> Option { + self.temperature + } } // ============================================================================ diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index c087398f..eb8f0788 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -45,6 +45,8 @@ pub trait ProviderRequest: Send + Sync { /// Remove a metadata key from the request and return true if the key was present fn remove_metadata_key(&mut self, key: &str) -> bool; + + fn get_temperature(&self) -> Option; } impl ProviderRequest for ProviderRequestType { @@ -137,6 +139,16 @@ impl ProviderRequest for ProviderRequestType { Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key), } } + + fn get_temperature(&self) -> Option { + match self { + Self::ChatCompletionsRequest(r) => r.get_temperature(), + Self::MessagesRequest(r) => r.get_temperature(), + Self::BedrockConverse(r) => r.get_temperature(), + Self::BedrockConverseStream(r) => r.get_temperature(), + Self::ResponsesAPIRequest(r) => r.get_temperature(), + } + } } /// Parse the client API from a byte slice.