mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
fixed PR comments and added more trace attributes
This commit is contained in:
parent
c0cf877b4f
commit
28b674454b
21 changed files with 565 additions and 26 deletions
3
.github/workflows/rust_tests.yml
vendored
3
.github/workflows/rust_tests.yml
vendored
|
|
@ -29,3 +29,6 @@ jobs:
|
||||||
|
|
||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: cargo test --lib
|
run: cargo test --lib
|
||||||
|
|
||||||
|
- name: Run trace integration tests
|
||||||
|
run: cargo test -p common --features trace-collection traces::tests::trace_integration_test
|
||||||
|
|
|
||||||
|
|
@ -206,7 +206,7 @@ static_resources:
|
||||||
- name: outbound_api_traffic
|
- name: outbound_api_traffic
|
||||||
address:
|
address:
|
||||||
socket_address:
|
socket_address:
|
||||||
address: 0.0.0.0
|
address: 127.0.0.1
|
||||||
port_value: 11000
|
port_value: 11000
|
||||||
traffic_direction: OUTBOUND
|
traffic_direction: OUTBOUND
|
||||||
filter_chains:
|
filter_chains:
|
||||||
|
|
@ -225,7 +225,7 @@ static_resources:
|
||||||
envoy_grpc:
|
envoy_grpc:
|
||||||
cluster_name: opentelemetry_collector
|
cluster_name: opentelemetry_collector
|
||||||
timeout: 0.250s
|
timeout: 0.250s
|
||||||
service_name: tool
|
service_name: tools
|
||||||
random_sampling:
|
random_sampling:
|
||||||
value: {{ arch_tracing.random_sampling }}
|
value: {{ arch_tracing.random_sampling }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
@ -473,7 +473,7 @@ static_resources:
|
||||||
- name: otel_collector_proxy
|
- name: otel_collector_proxy
|
||||||
address:
|
address:
|
||||||
socket_address:
|
socket_address:
|
||||||
address: 0.0.0.0
|
address: 127.0.0.1
|
||||||
port_value: 9903
|
port_value: 9903
|
||||||
traffic_direction: OUTBOUND
|
traffic_direction: OUTBOUND
|
||||||
filter_chains:
|
filter_chains:
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@
|
||||||
"editor.defaultFormatter": "ms-python.black-formatter",
|
"editor.defaultFormatter": "ms-python.black-formatter",
|
||||||
"editor.formatOnSave": true
|
"editor.formatOnSave": true
|
||||||
},
|
},
|
||||||
|
"rust-analyzer.cargo.features": ["trace-collection"]
|
||||||
},
|
},
|
||||||
"extensions": {
|
"extensions": {
|
||||||
"recommendations": [
|
"recommendations": [
|
||||||
|
|
|
||||||
75
crates/Cargo.lock
generated
75
crates/Cargo.lock
generated
|
|
@ -167,6 +167,61 @@ dependencies = [
|
||||||
"time",
|
"time",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum"
|
||||||
|
version = "0.7.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"axum-core",
|
||||||
|
"bytes",
|
||||||
|
"futures-util",
|
||||||
|
"http 1.3.1",
|
||||||
|
"http-body 1.0.1",
|
||||||
|
"http-body-util",
|
||||||
|
"hyper 1.6.0",
|
||||||
|
"hyper-util",
|
||||||
|
"itoa",
|
||||||
|
"matchit",
|
||||||
|
"memchr",
|
||||||
|
"mime",
|
||||||
|
"percent-encoding",
|
||||||
|
"pin-project-lite",
|
||||||
|
"rustversion",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"serde_path_to_error",
|
||||||
|
"serde_urlencoded",
|
||||||
|
"sync_wrapper",
|
||||||
|
"tokio",
|
||||||
|
"tower 0.5.2",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum-core"
|
||||||
|
version = "0.4.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"bytes",
|
||||||
|
"futures-util",
|
||||||
|
"http 1.3.1",
|
||||||
|
"http-body 1.0.1",
|
||||||
|
"http-body-util",
|
||||||
|
"mime",
|
||||||
|
"pin-project-lite",
|
||||||
|
"rustversion",
|
||||||
|
"sync_wrapper",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "backoff"
|
name = "backoff"
|
||||||
version = "0.4.0"
|
version = "0.4.0"
|
||||||
|
|
@ -370,6 +425,7 @@ dependencies = [
|
||||||
name = "common"
|
name = "common"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"axum",
|
||||||
"derivative",
|
"derivative",
|
||||||
"duration-string",
|
"duration-string",
|
||||||
"governor",
|
"governor",
|
||||||
|
|
@ -384,6 +440,7 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_with",
|
"serde_with",
|
||||||
"serde_yaml",
|
"serde_yaml",
|
||||||
|
"serial_test",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"tiktoken-rs",
|
"tiktoken-rs",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
@ -1429,6 +1486,12 @@ dependencies = [
|
||||||
"regex-automata 0.1.10",
|
"regex-automata 0.1.10",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matchit"
|
||||||
|
version = "0.7.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "md5"
|
name = "md5"
|
||||||
version = "0.7.0"
|
version = "0.7.0"
|
||||||
|
|
@ -2461,6 +2524,16 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_path_to_error"
|
||||||
|
version = "0.1.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
|
||||||
|
dependencies = [
|
||||||
|
"itoa",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_urlencoded"
|
name = "serde_urlencoded"
|
||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
|
|
@ -2984,6 +3057,7 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -3022,6 +3096,7 @@ version = "0.1.41"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
|
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"log",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tracing-attributes",
|
"tracing-attributes",
|
||||||
"tracing-core",
|
"tracing-core",
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,7 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
members = ["llm_gateway", "prompt_gateway", "common", "brightstaff", "hermesllm"]
|
members = ["llm_gateway", "prompt_gateway", "common", "brightstaff", "hermesllm"]
|
||||||
|
|
||||||
|
[workspace.metadata.rust-analyzer]
|
||||||
|
# Enable features for better IDE support
|
||||||
|
cargo.features = ["trace-collection"]
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ use tokio::sync::RwLock;
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use crate::router::llm_router::RouterService;
|
use crate::router::llm_router::RouterService;
|
||||||
use crate::handlers::utils::{create_streaming_response, PassthroughProcessor, truncate_message};
|
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
|
||||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||||
use crate::tracing::operation_component;
|
use crate::tracing::operation_component;
|
||||||
|
|
||||||
|
|
@ -24,7 +24,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn chat(
|
pub async fn llm_chat(
|
||||||
request: Request<hyper::body::Incoming>,
|
request: Request<hyper::body::Incoming>,
|
||||||
router_service: Arc<RouterService>,
|
router_service: Arc<RouterService>,
|
||||||
full_qualified_llm_provider_url: String,
|
full_qualified_llm_provider_url: String,
|
||||||
|
|
@ -36,12 +36,19 @@ pub async fn chat(
|
||||||
let request_path = request.uri().path().to_string();
|
let request_path = request.uri().path().to_string();
|
||||||
let request_headers = request.headers().clone();
|
let request_headers = request.headers().clone();
|
||||||
|
|
||||||
// Extract traceparent header early (Envoy should have added this)
|
// Extract or generate traceparent - this establishes the trace context for all spans
|
||||||
let traceparent = request_headers
|
let traceparent: String = request_headers
|
||||||
.get("traceparent")
|
.get("traceparent")
|
||||||
.and_then(|h| h.to_str().ok())
|
.and_then(|h| h.to_str().ok())
|
||||||
.unwrap_or("00-00000000000000000000000000000000-0000000000000000-01")
|
.map(|s| s.to_string())
|
||||||
.to_string();
|
.unwrap_or_else(|| {
|
||||||
|
// No traceparent - this is a root span, generate a new trace ID
|
||||||
|
use uuid::Uuid;
|
||||||
|
let trace_id = Uuid::new_v4().to_string().replace("-", "");
|
||||||
|
let span_id = Uuid::new_v4().to_string().replace("-", "")[..16].to_string();
|
||||||
|
// Format: version-trace_id-parent_span_id-trace_flags
|
||||||
|
format!("00-{}-{}-01", trace_id, span_id)
|
||||||
|
});
|
||||||
|
|
||||||
let mut request_headers = request_headers;
|
let mut request_headers = request_headers;
|
||||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||||
|
|
@ -68,6 +75,7 @@ pub async fn chat(
|
||||||
// Model alias resolution: update model field in client_request immediately
|
// Model alias resolution: update model field in client_request immediately
|
||||||
// This ensures all downstream objects use the resolved model
|
// This ensures all downstream objects use the resolved model
|
||||||
let model_from_request = client_request.model().to_string();
|
let model_from_request = client_request.model().to_string();
|
||||||
|
let temperature = client_request.get_temperature();
|
||||||
let is_streaming_request = client_request.is_streaming();
|
let is_streaming_request = client_request.is_streaming();
|
||||||
let resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
let resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||||
|
|
||||||
|
|
@ -177,11 +185,12 @@ pub async fn chat(
|
||||||
request_start_system_time,
|
request_start_system_time,
|
||||||
tool_names,
|
tool_names,
|
||||||
user_message_preview,
|
user_message_preview,
|
||||||
|
temperature,
|
||||||
&llm_providers,
|
&llm_providers,
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
// Use PassthroughProcessor to track streaming metrics and finalize the span
|
// Use PassthroughProcessor to track streaming metrics and finalize the span
|
||||||
let processor = PassthroughProcessor::new(
|
let processor = ObservableStreamProcessor::new(
|
||||||
trace_collector,
|
trace_collector,
|
||||||
operation_component::LLM,
|
operation_component::LLM,
|
||||||
llm_span,
|
llm_span,
|
||||||
|
|
@ -230,6 +239,7 @@ async fn build_llm_span(
|
||||||
start_time: std::time::SystemTime,
|
start_time: std::time::SystemTime,
|
||||||
tool_names: Option<Vec<String>>,
|
tool_names: Option<Vec<String>>,
|
||||||
user_message_preview: Option<String>,
|
user_message_preview: Option<String>,
|
||||||
|
temperature: Option<f32>,
|
||||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||||
) -> common::traces::Span {
|
) -> common::traces::Span {
|
||||||
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
|
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
|
||||||
|
|
@ -274,6 +284,10 @@ async fn build_llm_span(
|
||||||
.with_attribute(llm::IS_STREAMING, is_streaming.to_string());
|
.with_attribute(llm::IS_STREAMING, is_streaming.to_string());
|
||||||
|
|
||||||
// Add optional attributes
|
// Add optional attributes
|
||||||
|
if let Some(temp) = temperature {
|
||||||
|
span_builder = span_builder.with_attribute(llm::TEMPERATURE, temp.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(tools) = tool_names {
|
if let Some(tools) = tool_names {
|
||||||
let formatted_tools = tools.iter()
|
let formatted_tools = tools.iter()
|
||||||
.map(|name| format!("{}(...)", name))
|
.map(|name| format!("{}(...)", name))
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
pub mod agent_chat_completions;
|
pub mod agent_chat_completions;
|
||||||
pub mod agent_selector;
|
pub mod agent_selector;
|
||||||
pub mod router;
|
pub mod llm;
|
||||||
pub mod router_chat;
|
pub mod router_chat;
|
||||||
pub mod models;
|
pub mod models;
|
||||||
pub mod function_calling;
|
pub mod function_calling;
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ pub trait StreamProcessor: Send + 'static {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A processor that tracks streaming metrics and finalizes the span
|
/// A processor that tracks streaming metrics and finalizes the span
|
||||||
pub struct PassthroughProcessor {
|
pub struct ObservableStreamProcessor {
|
||||||
collector: Arc<TraceCollector>,
|
collector: Arc<TraceCollector>,
|
||||||
service_name: String,
|
service_name: String,
|
||||||
span: Span,
|
span: Span,
|
||||||
|
|
@ -40,7 +40,7 @@ pub struct PassthroughProcessor {
|
||||||
time_to_first_token: Option<u128>,
|
time_to_first_token: Option<u128>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PassthroughProcessor {
|
impl ObservableStreamProcessor {
|
||||||
/// Create a new passthrough processor
|
/// Create a new passthrough processor
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
|
@ -66,7 +66,7 @@ impl PassthroughProcessor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamProcessor for PassthroughProcessor {
|
impl StreamProcessor for ObservableStreamProcessor {
|
||||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||||
self.total_bytes += chunk.len();
|
self.total_bytes += chunk.len();
|
||||||
self.chunk_count += 1;
|
self.chunk_count += 1;
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use brightstaff::handlers::agent_chat_completions::agent_chat;
|
use brightstaff::handlers::agent_chat_completions::agent_chat;
|
||||||
use brightstaff::handlers::router::chat;
|
use brightstaff::handlers::llm::llm_chat;
|
||||||
use brightstaff::handlers::models::list_models;
|
use brightstaff::handlers::models::list_models;
|
||||||
use brightstaff::handlers::function_calling::{function_calling_chat_handler};
|
use brightstaff::handlers::function_calling::{function_calling_chat_handler};
|
||||||
use brightstaff::router::llm_router::RouterService;
|
use brightstaff::router::llm_router::RouterService;
|
||||||
|
|
@ -130,7 +130,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||||
let fully_qualified_url =
|
let fully_qualified_url =
|
||||||
format!("{}{}", llm_provider_url, req.uri().path());
|
format!("{}{}", llm_provider_url, req.uri().path());
|
||||||
chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector)
|
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector)
|
||||||
.with_context(parent_cx)
|
.with_context(parent_cx)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -83,19 +83,19 @@ pub mod llm {
|
||||||
pub const TOTAL_TOKENS: &str = "llm.usage.total_tokens";
|
pub const TOTAL_TOKENS: &str = "llm.usage.total_tokens";
|
||||||
|
|
||||||
/// Temperature parameter used
|
/// Temperature parameter used
|
||||||
pub const TEMPERATURE: &str = "llm.request.temperature";
|
pub const TEMPERATURE: &str = "llm.temperature";
|
||||||
|
|
||||||
/// Max tokens parameter used
|
/// Max tokens parameter used
|
||||||
pub const MAX_TOKENS: &str = "llm.request.max_tokens";
|
pub const MAX_TOKENS: &str = "llm.max_tokens";
|
||||||
|
|
||||||
/// Top-p parameter used
|
/// Top-p parameter used
|
||||||
pub const TOP_P: &str = "llm.request.top_p";
|
pub const TOP_P: &str = "llm.top_p";
|
||||||
|
|
||||||
/// List of tool names provided in the request
|
/// List of tool names provided in the request
|
||||||
pub const TOOLS: &str = "llm.tools";
|
pub const TOOLS: &str = "llm.tools";
|
||||||
|
|
||||||
/// Preview of the user message (truncated)
|
/// Preview of the user message (truncated)
|
||||||
pub const USER_MESSAGE_PREVIEW: &str = "llm.request.user_message_preview";
|
pub const USER_MESSAGE_PREVIEW: &str = "llm.user_message_preview";
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
|
||||||
|
|
@ -33,4 +33,6 @@ trace-collection = ["tokio", "reqwest", "tracing"]
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
pretty_assertions = "1.4.1"
|
pretty_assertions = "1.4.1"
|
||||||
serde_json = "1.0.64"
|
serde_json = "1.0.64"
|
||||||
|
serial_test = "3.2"
|
||||||
|
axum = "0.7"
|
||||||
tokio = { version = "1.44", features = ["sync", "time", "macros", "rt"] }
|
tokio = { version = "1.44", features = ["sync", "time", "macros", "rt"] }
|
||||||
|
|
|
||||||
|
|
@ -52,13 +52,13 @@ impl TraceCollector {
|
||||||
/// - `None` - Check `OTEL_TRACING_ENABLED` env var (defaults to true if not set)
|
/// - `None` - Check `OTEL_TRACING_ENABLED` env var (defaults to true if not set)
|
||||||
///
|
///
|
||||||
/// Other parameters are read from environment variables:
|
/// Other parameters are read from environment variables:
|
||||||
/// - `TRACE_FLUSH_INTERVAL_SECS` - Flush interval in seconds (default: 1)
|
/// - `TRACE_FLUSH_INTERVAL_MS` - Flush interval in milliseconds (default: 1000)
|
||||||
/// - `OTEL_COLLECTOR_URL` - OTEL collector endpoint (default: http://localhost:9903/v1/traces)
|
/// - `OTEL_COLLECTOR_URL` - OTEL collector endpoint (default: http://localhost:9903/v1/traces)
|
||||||
pub fn new(enabled: Option<bool>) -> Self {
|
pub fn new(enabled: Option<bool>) -> Self {
|
||||||
let flush_interval_secs = std::env::var("TRACE_FLUSH_INTERVAL_SECS")
|
let flush_interval_ms = std::env::var("TRACE_FLUSH_INTERVAL_MS")
|
||||||
.ok()
|
.ok()
|
||||||
.and_then(|s| s.parse().ok())
|
.and_then(|s| s.parse().ok())
|
||||||
.unwrap_or(1);
|
.unwrap_or(1000);
|
||||||
|
|
||||||
let otel_url = std::env::var("OTEL_COLLECTOR_URL")
|
let otel_url = std::env::var("OTEL_COLLECTOR_URL")
|
||||||
.unwrap_or_else(|_| "http://localhost:9903/v1/traces".to_string());
|
.unwrap_or_else(|_| "http://localhost:9903/v1/traces".to_string());
|
||||||
|
|
@ -75,13 +75,13 @@ impl TraceCollector {
|
||||||
});
|
});
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"TraceCollector initialized: flush_interval={}s, url={}, enabled={}",
|
"TraceCollector initialized: flush_interval={}ms, url={}, enabled={}",
|
||||||
flush_interval_secs, otel_url, enabled
|
flush_interval_ms, otel_url, enabled
|
||||||
);
|
);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
spans_by_service: Arc::new(Mutex::new(HashMap::new())),
|
spans_by_service: Arc::new(Mutex::new(HashMap::new())),
|
||||||
flush_interval: Duration::from_secs(flush_interval_secs),
|
flush_interval: Duration::from_millis(flush_interval_ms),
|
||||||
otel_url,
|
otel_url,
|
||||||
enabled,
|
enabled,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,9 @@ mod constants;
|
||||||
#[cfg(feature = "trace-collection")]
|
#[cfg(feature = "trace-collection")]
|
||||||
mod collector;
|
mod collector;
|
||||||
|
|
||||||
|
#[cfg(all(test, feature = "trace-collection"))]
|
||||||
|
mod tests;
|
||||||
|
|
||||||
// Re-export original types
|
// Re-export original types
|
||||||
pub use shapes::{
|
pub use shapes::{
|
||||||
Span, Event, Traceparent, TraceparentNewError,
|
Span, Event, Traceparent, TraceparentNewError,
|
||||||
|
|
|
||||||
101
crates/common/src/traces/tests/mock_otel_collector.rs
Normal file
101
crates/common/src/traces/tests/mock_otel_collector.rs
Normal file
|
|
@ -0,0 +1,101 @@
|
||||||
|
//! Mock OTEL Collector for testing trace output
|
||||||
|
//!
|
||||||
|
//! This module provides a simple HTTP server that mimics an OTEL collector.
|
||||||
|
//! It exposes three endpoints:
|
||||||
|
//! - POST /v1/traces: Capture incoming OTLP JSON payloads
|
||||||
|
//! - GET /v1/traces: Return all captured payloads as JSON array
|
||||||
|
//! - DELETE /v1/traces: Clear all captured payloads
|
||||||
|
//!
|
||||||
|
//! Each test creates its own MockOtelCollector instance.
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
http::StatusCode,
|
||||||
|
routing::{delete, get, post},
|
||||||
|
Json, Router,
|
||||||
|
};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
type SharedTraces = Arc<RwLock<Vec<Value>>>;
|
||||||
|
|
||||||
|
/// POST /v1/traces - capture incoming OTLP payload
|
||||||
|
async fn post_traces(
|
||||||
|
State(traces): State<SharedTraces>,
|
||||||
|
Json(payload): Json<Value>,
|
||||||
|
) -> StatusCode {
|
||||||
|
traces.write().await.push(payload);
|
||||||
|
StatusCode::OK
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /v1/traces - return all captured payloads
|
||||||
|
async fn get_traces(State(traces): State<SharedTraces>) -> Json<Vec<Value>> {
|
||||||
|
Json(traces.read().await.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DELETE /v1/traces - clear all captured payloads
|
||||||
|
async fn delete_traces(State(traces): State<SharedTraces>) -> StatusCode {
|
||||||
|
traces.write().await.clear();
|
||||||
|
StatusCode::NO_CONTENT
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mock OTEL collector server
|
||||||
|
pub struct MockOtelCollector {
|
||||||
|
address: String,
|
||||||
|
client: reqwest::Client,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
server_handle: tokio::task::JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockOtelCollector {
|
||||||
|
/// Create and start a new mock collector on a random port
|
||||||
|
pub async fn start() -> Self {
|
||||||
|
let traces = Arc::new(RwLock::new(Vec::new()));
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/v1/traces", post(post_traces))
|
||||||
|
.route("/v1/traces", get(get_traces))
|
||||||
|
.route("/v1/traces", delete(delete_traces))
|
||||||
|
.with_state(traces.clone());
|
||||||
|
|
||||||
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||||
|
.await
|
||||||
|
.expect("Failed to bind to random port");
|
||||||
|
|
||||||
|
let addr = listener.local_addr().expect("Failed to get local address");
|
||||||
|
let address = format!("http://127.0.0.1:{}", addr.port());
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app)
|
||||||
|
.await
|
||||||
|
.expect("Server failed");
|
||||||
|
});
|
||||||
|
|
||||||
|
// Give server a moment to start
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
address,
|
||||||
|
client: reqwest::Client::new(),
|
||||||
|
server_handle,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the address of the collector
|
||||||
|
pub fn address(&self) -> &str {
|
||||||
|
&self.address
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /v1/traces - fetch all captured payloads
|
||||||
|
pub async fn get_traces(&self) -> Vec<Value> {
|
||||||
|
self.client
|
||||||
|
.get(format!("{}/v1/traces", self.address))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("Failed to GET traces")
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.expect("Failed to parse traces JSON")
|
||||||
|
}
|
||||||
|
}
|
||||||
4
crates/common/src/traces/tests/mod.rs
Normal file
4
crates/common/src/traces/tests/mod.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
mod mock_otel_collector;
|
||||||
|
mod trace_integration_test;
|
||||||
|
|
||||||
|
pub use mock_otel_collector::MockOtelCollector;
|
||||||
304
crates/common/src/traces/tests/trace_integration_test.rs
Normal file
304
crates/common/src/traces/tests/trace_integration_test.rs
Normal file
|
|
@ -0,0 +1,304 @@
|
||||||
|
//! Integration tests for OpenTelemetry tracing in router.rs
|
||||||
|
//!
|
||||||
|
//! These tests validate that the spans created for LLM requests contain
|
||||||
|
//! all expected attributes and events by checking the raw JSON payloads
|
||||||
|
//! sent to the mock OTEL collector.
|
||||||
|
//!
|
||||||
|
//! ## Test Design
|
||||||
|
//! Each test creates its own MockOtelCollector and TraceCollector:
|
||||||
|
//! 1. Start MockOtelCollector on random port
|
||||||
|
//! 2. Create TraceCollector with 500ms flush interval
|
||||||
|
//! 3. Record spans using TraceCollector
|
||||||
|
//! 4. Flush and wait (500ms + 200ms buffer = 700ms total) for spans to arrive
|
||||||
|
//! 5. Get raw JSON payloads (GET /v1/traces) and validate structure
|
||||||
|
//! 6. Test cleanup happens automatically when collectors are dropped
|
||||||
|
//!
|
||||||
|
//! ## Serial Execution
|
||||||
|
//! Tests use the `#[serial]` attribute to run sequentially because they
|
||||||
|
//! use global environment variables (OTEL_COLLECTOR_URL, OTEL_TRACING_ENABLED,
|
||||||
|
//! TRACE_FLUSH_INTERVAL_MS). This ensures test isolation without requiring
|
||||||
|
//! the `--test-threads=1` command line flag.
|
||||||
|
|
||||||
|
const FLUSH_INTERVAL_MS: u64 = 50;
|
||||||
|
const FLUSH_BUFFER_MS: u64 = 50;
|
||||||
|
const TOTAL_WAIT_MS: u64 = FLUSH_INTERVAL_MS + FLUSH_BUFFER_MS;
|
||||||
|
|
||||||
|
use crate::traces::{SpanBuilder, SpanKind, TraceCollector};
|
||||||
|
use serde_json::Value;
|
||||||
|
use serial_test::serial;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use super::MockOtelCollector;
|
||||||
|
|
||||||
|
/// Helper to extract all spans from OTLP JSON payloads
|
||||||
|
fn extract_spans(payloads: &[Value]) -> Vec<&Value> {
|
||||||
|
let mut spans = Vec::new();
|
||||||
|
for payload in payloads {
|
||||||
|
if let Some(resource_spans) = payload.get("resourceSpans").and_then(|v| v.as_array()) {
|
||||||
|
for resource_span in resource_spans {
|
||||||
|
if let Some(scope_spans) = resource_span.get("scopeSpans").and_then(|v| v.as_array()) {
|
||||||
|
for scope_span in scope_spans {
|
||||||
|
if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array()) {
|
||||||
|
spans.extend(span_list.iter());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
spans
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper to get string attribute value from a span
|
||||||
|
fn get_string_attr<'a>(span: &'a Value, key: &str) -> Option<&'a str> {
|
||||||
|
span.get("attributes")
|
||||||
|
.and_then(|attrs| attrs.as_array())
|
||||||
|
.and_then(|attrs| {
|
||||||
|
attrs.iter().find(|attr| {
|
||||||
|
attr.get("key").and_then(|k| k.as_str()) == Some(key)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.and_then(|attr| attr.get("value"))
|
||||||
|
.and_then(|v| v.get("stringValue"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn test_llm_span_contains_basic_attributes() {
|
||||||
|
// Start mock OTEL collector
|
||||||
|
let mock_collector = MockOtelCollector::start().await;
|
||||||
|
|
||||||
|
// Create TraceCollector pointing to mock with 500ms flush intervalc
|
||||||
|
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||||
|
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||||
|
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||||
|
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||||
|
|
||||||
|
// Create a test span simulating router.rs behavior
|
||||||
|
let span = SpanBuilder::new("POST /v1/chat/completions >> /v1/chat/completions")
|
||||||
|
.with_kind(SpanKind::Client)
|
||||||
|
.with_trace_id("test-trace-123")
|
||||||
|
.with_attribute("http.method", "POST")
|
||||||
|
.with_attribute("http.target", "/v1/chat/completions")
|
||||||
|
.with_attribute("http.upstream_target", "/v1/chat/completions")
|
||||||
|
.with_attribute("llm.model", "gpt-4o")
|
||||||
|
.with_attribute("llm.provider", "openai")
|
||||||
|
.with_attribute("llm.is_streaming", "true")
|
||||||
|
.with_attribute("llm.temperature", "0.7")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
trace_collector.record_span("archgw(llm)", span);
|
||||||
|
|
||||||
|
// Flush and wait for spans to arrive (500ms flush interval + 200ms buffer)
|
||||||
|
trace_collector.flush().await.expect("Failed to flush");
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||||
|
|
||||||
|
let payloads = mock_collector.get_traces().await;
|
||||||
|
let spans = extract_spans(&payloads);
|
||||||
|
|
||||||
|
assert_eq!(spans.len(), 1, "Expected exactly one span");
|
||||||
|
|
||||||
|
let span = spans[0];
|
||||||
|
// Validate HTTP attributes
|
||||||
|
assert_eq!(get_string_attr(span, "http.method"), Some("POST"));
|
||||||
|
assert_eq!(get_string_attr(span, "http.target"), Some("/v1/chat/completions"));
|
||||||
|
|
||||||
|
// Validate LLM attributes
|
||||||
|
assert_eq!(get_string_attr(span, "llm.model"), Some("gpt-4o"));
|
||||||
|
assert_eq!(get_string_attr(span, "llm.provider"), Some("openai"));
|
||||||
|
assert_eq!(get_string_attr(span, "llm.is_streaming"), Some("true"));
|
||||||
|
assert_eq!(get_string_attr(span, "llm.temperature"), Some("0.7"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn test_llm_span_contains_tool_information() {
|
||||||
|
let mock_collector = MockOtelCollector::start().await;
|
||||||
|
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||||
|
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||||
|
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||||
|
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||||
|
|
||||||
|
let tools_formatted = "get_weather(...)\nsearch_web(...)\ncalculate(...)";
|
||||||
|
|
||||||
|
let span = SpanBuilder::new("POST /v1/chat/completions")
|
||||||
|
.with_trace_id("test-trace-tools")
|
||||||
|
.with_attribute("llm.request.tools", tools_formatted)
|
||||||
|
.with_attribute("llm.model", "gpt-4o")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
trace_collector.record_span("archgw(llm)", span);
|
||||||
|
trace_collector.flush().await.expect("Failed to flush");
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||||
|
|
||||||
|
let payloads = mock_collector.get_traces().await;
|
||||||
|
let spans = extract_spans(&payloads);
|
||||||
|
|
||||||
|
assert!(!spans.is_empty(), "No spans captured");
|
||||||
|
|
||||||
|
let span = spans[0];
|
||||||
|
let tools = get_string_attr(span, "llm.request.tools");
|
||||||
|
|
||||||
|
assert!(tools.is_some(), "Tools attribute missing");
|
||||||
|
assert!(tools.unwrap().contains("get_weather(...)"));
|
||||||
|
assert!(tools.unwrap().contains("search_web(...)"));
|
||||||
|
assert!(tools.unwrap().contains("calculate(...)"));
|
||||||
|
assert!(tools.unwrap().contains('\n'), "Tools should be newline-separated");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn test_llm_span_contains_user_message_preview() {
|
||||||
|
let mock_collector = MockOtelCollector::start().await;
|
||||||
|
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||||
|
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||||
|
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||||
|
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||||
|
|
||||||
|
let long_message = "This is a very long user message that should be truncated to 50 characters in the span";
|
||||||
|
let preview = if long_message.len() > 50 {
|
||||||
|
format!("{}...", &long_message[..50])
|
||||||
|
} else {
|
||||||
|
long_message.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let span = SpanBuilder::new("POST /v1/messages")
|
||||||
|
.with_trace_id("test-trace-preview")
|
||||||
|
.with_attribute("llm.request.user_message_preview", &preview)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
trace_collector.record_span("archgw(llm)", span);
|
||||||
|
trace_collector.flush().await.expect("Failed to flush");
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||||
|
|
||||||
|
let payloads = mock_collector.get_traces().await;
|
||||||
|
let spans = extract_spans(&payloads);
|
||||||
|
let span = spans[0];
|
||||||
|
|
||||||
|
let message_preview = get_string_attr(span, "llm.request.user_message_preview");
|
||||||
|
|
||||||
|
assert!(message_preview.is_some());
|
||||||
|
assert!(message_preview.unwrap().len() <= 53); // 50 chars + "..."
|
||||||
|
assert!(message_preview.unwrap().contains("..."));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn test_llm_span_contains_time_to_first_token() {
|
||||||
|
let mock_collector = MockOtelCollector::start().await;
|
||||||
|
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||||
|
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||||
|
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||||
|
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||||
|
|
||||||
|
let ttft_ms = "245"; // milliseconds as string
|
||||||
|
|
||||||
|
let span = SpanBuilder::new("POST /v1/chat/completions")
|
||||||
|
.with_trace_id("test-trace-ttft")
|
||||||
|
.with_attribute("llm.is_streaming", "true")
|
||||||
|
.with_attribute("llm.time_to_first_token_ms", ttft_ms)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
trace_collector.record_span("archgw(llm)", span);
|
||||||
|
trace_collector.flush().await.expect("Failed to flush");
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||||
|
|
||||||
|
let payloads = mock_collector.get_traces().await;
|
||||||
|
let spans = extract_spans(&payloads);
|
||||||
|
let span = spans[0];
|
||||||
|
|
||||||
|
// Check TTFT attribute
|
||||||
|
let ttft_attr = get_string_attr(span, "llm.time_to_first_token_ms");
|
||||||
|
assert_eq!(ttft_attr, Some("245"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn test_llm_span_contains_upstream_path() {
|
||||||
|
let mock_collector = MockOtelCollector::start().await;
|
||||||
|
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||||
|
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||||
|
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||||
|
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||||
|
|
||||||
|
// Test Zhipu provider with path transformation
|
||||||
|
let span = SpanBuilder::new("POST /v1/chat/completions >> /api/paas/v4/chat/completions")
|
||||||
|
.with_trace_id("test-trace-upstream")
|
||||||
|
.with_attribute("http.upstream_target", "/api/paas/v4/chat/completions")
|
||||||
|
.with_attribute("llm.provider", "zhipu")
|
||||||
|
.with_attribute("llm.model", "glm-4")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
trace_collector.record_span("archgw(llm)", span);
|
||||||
|
trace_collector.flush().await.expect("Failed to flush");
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||||
|
|
||||||
|
let payloads = mock_collector.get_traces().await;
|
||||||
|
let spans = extract_spans(&payloads);
|
||||||
|
let span = spans[0];
|
||||||
|
|
||||||
|
// Operation name should show the transformation
|
||||||
|
let name = span.get("name").and_then(|v| v.as_str());
|
||||||
|
assert!(name.is_some());
|
||||||
|
assert!(name.unwrap().contains(">>"), "Operation name should show path transformation");
|
||||||
|
|
||||||
|
// Check upstream target attribute
|
||||||
|
let upstream = get_string_attr(span, "http.upstream_target");
|
||||||
|
assert_eq!(upstream, Some("/api/paas/v4/chat/completions"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn test_llm_span_multiple_services() {
|
||||||
|
let mock_collector = MockOtelCollector::start().await;
|
||||||
|
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||||
|
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||||
|
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||||
|
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||||
|
|
||||||
|
// Create spans for different services
|
||||||
|
let llm_span = SpanBuilder::new("LLM Request")
|
||||||
|
.with_trace_id("test-multi")
|
||||||
|
.with_attribute("service", "llm")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let routing_span = SpanBuilder::new("Routing Decision")
|
||||||
|
.with_trace_id("test-multi")
|
||||||
|
.with_attribute("service", "routing")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
trace_collector.record_span("archgw(llm)", llm_span);
|
||||||
|
trace_collector.record_span("archgw(routing)", routing_span);
|
||||||
|
trace_collector.flush().await.expect("Failed to flush");
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||||
|
|
||||||
|
let payloads = mock_collector.get_traces().await;
|
||||||
|
let all_spans = extract_spans(&payloads);
|
||||||
|
|
||||||
|
assert_eq!(all_spans.len(), 2, "Should have captured both spans");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn test_tracing_disabled_produces_no_spans() {
|
||||||
|
let mock_collector = MockOtelCollector::start().await;
|
||||||
|
|
||||||
|
// Create TraceCollector with tracing DISABLED
|
||||||
|
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||||
|
std::env::set_var("OTEL_TRACING_ENABLED", "false");
|
||||||
|
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||||
|
let trace_collector = Arc::new(TraceCollector::new(Some(false)));
|
||||||
|
|
||||||
|
let span = SpanBuilder::new("Test Span")
|
||||||
|
.with_trace_id("test-disabled")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
trace_collector.record_span("archgw(llm)", span);
|
||||||
|
trace_collector.flush().await.ok(); // Should be no-op when disabled
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||||
|
|
||||||
|
let payloads = mock_collector.get_traces().await;
|
||||||
|
let all_spans = extract_spans(&payloads);
|
||||||
|
assert_eq!(all_spans.len(), 0, "No spans should be captured when tracing is disabled");
|
||||||
|
}
|
||||||
|
|
@ -229,6 +229,10 @@ impl ProviderRequest for ConverseRequest {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_temperature(&self) -> Option<f32> {
|
||||||
|
self.inference_config.as_ref()?.temperature
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
|
||||||
|
|
@ -537,6 +537,10 @@ impl ProviderRequest for MessagesRequest {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_temperature(&self) -> Option<f32> {
|
||||||
|
self.temperature
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MessagesResponse {
|
impl MessagesResponse {
|
||||||
|
|
|
||||||
|
|
@ -731,6 +731,10 @@ impl ProviderRequest for ChatCompletionsRequest {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_temperature(&self) -> Option<f32> {
|
||||||
|
self.temperature
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implementation of ProviderResponse for ChatCompletionsResponse
|
/// Implementation of ProviderResponse for ChatCompletionsResponse
|
||||||
|
|
|
||||||
|
|
@ -1094,6 +1094,10 @@ impl ProviderRequest for ResponsesAPIRequest {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_temperature(&self) -> Option<f32> {
|
||||||
|
self.temperature
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,8 @@ pub trait ProviderRequest: Send + Sync {
|
||||||
|
|
||||||
/// Remove a metadata key from the request and return true if the key was present
|
/// Remove a metadata key from the request and return true if the key was present
|
||||||
fn remove_metadata_key(&mut self, key: &str) -> bool;
|
fn remove_metadata_key(&mut self, key: &str) -> bool;
|
||||||
|
|
||||||
|
fn get_temperature(&self) -> Option<f32>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProviderRequest for ProviderRequestType {
|
impl ProviderRequest for ProviderRequestType {
|
||||||
|
|
@ -137,6 +139,16 @@ impl ProviderRequest for ProviderRequestType {
|
||||||
Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key),
|
Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_temperature(&self) -> Option<f32> {
|
||||||
|
match self {
|
||||||
|
Self::ChatCompletionsRequest(r) => r.get_temperature(),
|
||||||
|
Self::MessagesRequest(r) => r.get_temperature(),
|
||||||
|
Self::BedrockConverse(r) => r.get_temperature(),
|
||||||
|
Self::BedrockConverseStream(r) => r.get_temperature(),
|
||||||
|
Self::ResponsesAPIRequest(r) => r.get_temperature(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse the client API from a byte slice.
|
/// Parse the client API from a byte slice.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue