add support for jaeger tracing (#229)

This commit is contained in:
Adil Hafeez 2024-11-07 22:11:00 -06:00 committed by GitHub
parent fb67788be0
commit a72bb804eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
64 changed files with 5032 additions and 1112 deletions

View file

@ -10,14 +10,14 @@ use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
PromptGuardResponse, ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::configuration::{Overrides, PromptGuards, PromptTarget, Tracing};
use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST,
HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
ZEROSHOT_INTERNAL_HOST,
HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE,
TRACE_PARENT_HEADER, USER_ROLE, ZEROSHOT_INTERNAL_HOST,
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -77,9 +77,12 @@ pub struct StreamContext {
pub chat_completions_request: Option<ChatCompletionsRequest>,
pub prompt_guards: Rc<PromptGuards>,
pub request_id: Option<String>,
pub traceparent: Option<String>,
pub tracing: Rc<Option<Tracing>>,
}
impl StreamContext {
#[allow(clippy::too_many_arguments)]
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
@ -88,6 +91,7 @@ impl StreamContext {
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
tracing: Rc<Option<Tracing>>,
) -> Self {
StreamContext {
context_id,
@ -107,6 +111,8 @@ impl StreamContext {
prompt_guards,
overrides,
request_id: None,
traceparent: None,
tracing,
}
}
@ -165,9 +171,15 @@ impl StreamContext {
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/embeddings",
@ -293,6 +305,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/zeroshot",
@ -309,6 +325,16 @@ impl StreamContext {
}
}
fn trace_arch_internal(&self) -> bool {
match self.tracing.as_ref() {
Some(tracing) => match tracing.trace_arch_internal.as_ref() {
Some(trace_arch_internal) => *trace_arch_internal,
None => false,
},
None => false,
}
}
pub fn hallucination_classification_resp_handler(
&mut self,
body: Vec<u8>,
@ -489,6 +515,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
@ -636,6 +666,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/v1/chat/completions",
@ -812,6 +846,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.trace_arch_internal() && self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/hallucination",
@ -863,6 +901,10 @@ impl StreamContext {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&path,
@ -1047,7 +1089,17 @@ impl StreamContext {
{
let default_target_response_str = if self.streaming_response {
let chat_completion_response =
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
match serde_json::from_slice::<ChatCompletionsResponse>(&body) {
Ok(chat_completion_response) => chat_completion_response,
Err(e) => {
warn!(
"error deserializing default target response: {}, body str: {}",
e,
String::from_utf8(body).unwrap()
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let chunks = vec![
ChatCompletionStreamResponse::new(