add support for jaeger tracing (#229)

This commit is contained in:
Adil Hafeez 2024-11-07 22:11:00 -06:00 committed by GitHub
parent fb67788be0
commit a72bb804eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
64 changed files with 5032 additions and 1112 deletions

View file

@ -12,6 +12,7 @@ pub struct Overrides {
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Tracing {
pub sampling_rate: Option<f64>,
pub trace_arch_internal: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]

View file

@ -22,6 +22,7 @@ pub const HEALTHZ_PATH: &str = "/healthz";
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
pub const REQUEST_ID_HEADER: &str = "x-request-id";
pub const TRACE_PARENT_HEADER: &str = "traceparent";
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";

View file

@ -157,7 +157,6 @@ impl HttpContext for StreamContext {
);
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
Action::Continue
}

View file

@ -34,14 +34,14 @@ impl Context for StreamContext {
on_http_request_body prompt received get embeddings zeroshot intent
arch guard
arch guard
continue from zeroshot intent

View file

@ -1,6 +1,6 @@
use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType;
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL;
use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST};
@ -55,6 +55,7 @@ pub struct FilterContext {
embeddings_store: Option<Rc<EmbeddingsStore>>,
temp_embeddings_store: EmbeddingsStore,
active_embedding_calls_count: u32,
tracing: Rc<Option<Tracing>>,
}
impl FilterContext {
@ -69,6 +70,7 @@ impl FilterContext {
embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(),
active_embedding_calls_count: 0,
tracing: Rc::new(None),
}
}
@ -270,6 +272,8 @@ impl RootContext for FilterContext {
self.prompt_guards = Rc::new(prompt_guards)
}
self.tracing = Rc::new(config.tracing);
true
}
@ -288,6 +292,7 @@ impl RootContext for FilterContext {
Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides),
embedding_store,
Rc::clone(&self.tracing),
)))
}

View file

@ -10,7 +10,7 @@ use common::{
consts::{
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST,
HEALTHZ_PATH, REQUEST_ID_HEADER, TOOL_ROLE, USER_ROLE,
HEALTHZ_PATH, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
},
errors::ServerError,
http::{CallArgs, Client},
@ -52,13 +52,14 @@ impl HttpContext for StreamContext {
);
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
Action::Continue
}
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
// Let the client send the gateway all the data before sending to the LLM_provider.
// TODO: consider a streaming API.
if !end_of_stream {
return Action::Pause;
}
@ -195,6 +196,10 @@ impl HttpContext for StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/guard",
@ -241,7 +246,7 @@ impl HttpContext for StreamContext {
);
if !self.is_chat_completions_request {
debug!("non-streaming request");
debug!("non-gpt request");
return Action::Continue;
}

View file

@ -10,14 +10,14 @@ use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
PromptGuardResponse, ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::configuration::{Overrides, PromptGuards, PromptTarget, Tracing};
use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST,
HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
ZEROSHOT_INTERNAL_HOST,
HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE,
TRACE_PARENT_HEADER, USER_ROLE, ZEROSHOT_INTERNAL_HOST,
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -77,9 +77,12 @@ pub struct StreamContext {
pub chat_completions_request: Option<ChatCompletionsRequest>,
pub prompt_guards: Rc<PromptGuards>,
pub request_id: Option<String>,
pub traceparent: Option<String>,
pub tracing: Rc<Option<Tracing>>,
}
impl StreamContext {
#[allow(clippy::too_many_arguments)]
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
@ -88,6 +91,7 @@ impl StreamContext {
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
tracing: Rc<Option<Tracing>>,
) -> Self {
StreamContext {
context_id,
@ -107,6 +111,8 @@ impl StreamContext {
prompt_guards,
overrides,
request_id: None,
traceparent: None,
tracing,
}
}
@ -165,9 +171,15 @@ impl StreamContext {
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/embeddings",
@ -293,6 +305,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/zeroshot",
@ -309,6 +325,16 @@ impl StreamContext {
}
}
fn trace_arch_internal(&self) -> bool {
match self.tracing.as_ref() {
Some(tracing) => match tracing.trace_arch_internal.as_ref() {
Some(trace_arch_internal) => *trace_arch_internal,
None => false,
},
None => false,
}
}
pub fn hallucination_classification_resp_handler(
&mut self,
body: Vec<u8>,
@ -489,6 +515,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
@ -636,6 +666,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/v1/chat/completions",
@ -812,6 +846,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/hallucination",
@ -863,6 +901,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&path,
@ -1047,7 +1089,17 @@ impl StreamContext {
{
let default_target_response_str = if self.streaming_response {
let chat_completion_response =
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
match serde_json::from_slice::<ChatCompletionsResponse>(&body) {
Ok(chat_completion_response) => chat_completion_response,
Err(e) => {
warn!(
"error deserializing default target response: {}, body str: {}",
e,
String::from_utf8(body).unwrap()
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let chunks = vec![
ChatCompletionStreamResponse::new(

View file

@ -36,6 +36,8 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
.expect_log(Some(LogLevel::Trace), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
.returning(None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}