mirror of
https://github.com/katanemo/plano.git
synced 2026-05-08 07:12:42 +02:00
Improve logging (#209)
* improve logging * fix int tests * better * fix more logs * fix more * fix int
This commit is contained in:
parent
2f374df034
commit
ea76d85b43
6 changed files with 319 additions and 309 deletions
|
|
@ -7,7 +7,6 @@ pub const SYSTEM_ROLE: &str = "system";
|
||||||
pub const USER_ROLE: &str = "user";
|
pub const USER_ROLE: &str = "user";
|
||||||
pub const TOOL_ROLE: &str = "tool";
|
pub const TOOL_ROLE: &str = "tool";
|
||||||
pub const ASSISTANT_ROLE: &str = "assistant";
|
pub const ASSISTANT_ROLE: &str = "assistant";
|
||||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
|
||||||
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||||
pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";
|
pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ use crate::{
|
||||||
stats::{Gauge, IncrementingMetric},
|
stats::{Gauge, IncrementingMetric},
|
||||||
};
|
};
|
||||||
use derivative::Derivative;
|
use derivative::Derivative;
|
||||||
use log::debug;
|
use log::{debug, trace};
|
||||||
use proxy_wasm::{traits::Context, types::Status};
|
use proxy_wasm::{traits::Context, types::Status};
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
|
use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
|
||||||
|
|
@ -48,9 +48,10 @@ pub trait Client: Context {
|
||||||
call_args: CallArgs,
|
call_args: CallArgs,
|
||||||
call_context: Self::CallContext,
|
call_context: Self::CallContext,
|
||||||
) -> Result<u32, ClientError> {
|
) -> Result<u32, ClientError> {
|
||||||
debug!(
|
trace!(
|
||||||
"dispatching http call with args={:?} context={:?}",
|
"dispatching http call with args={:?} context={:?}",
|
||||||
call_args, call_context
|
call_args,
|
||||||
|
call_context
|
||||||
);
|
);
|
||||||
|
|
||||||
match self.dispatch_http_call(
|
match self.dispatch_http_call(
|
||||||
|
|
|
||||||
|
|
@ -74,24 +74,15 @@ impl Context for StreamContext {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||||
|
#[cfg_attr(any(), rustfmt::skip)]
|
||||||
match callout_context.response_handler_type {
|
match callout_context.response_handler_type {
|
||||||
ResponseHandlerType::GetEmbeddings => {
|
|
||||||
self.embeddings_handler(body, callout_context)
|
|
||||||
}
|
|
||||||
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
||||||
ResponseHandlerType::ZeroShotIntent => {
|
ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context),
|
||||||
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context),
|
||||||
}
|
|
||||||
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
|
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
|
||||||
ResponseHandlerType::HallucinationDetect => {
|
ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context),
|
||||||
self.hallucination_classification_resp_handler(body, callout_context)
|
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
|
||||||
}
|
ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
|
||||||
ResponseHandlerType::FunctionCall => {
|
|
||||||
self.function_call_response_handler(body, callout_context)
|
|
||||||
}
|
|
||||||
ResponseHandlerType::DefaultTarget => {
|
|
||||||
self.default_target_handler(body, callout_context)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
self.send_server_error(
|
self.send_server_error(
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ use common::{
|
||||||
http::{CallArgs, Client},
|
http::{CallArgs, Client},
|
||||||
};
|
};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use log::{debug, warn};
|
use log::{debug, trace, warn};
|
||||||
use proxy_wasm::{traits::HttpContext, types::Action};
|
use proxy_wasm::{traits::HttpContext, types::Action};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
|
|
@ -36,7 +36,7 @@ impl HttpContext for StreamContext {
|
||||||
self.is_chat_completions_request =
|
self.is_chat_completions_request =
|
||||||
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
|
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
|
||||||
|
|
||||||
debug!(
|
trace!(
|
||||||
"on_http_request_headers S[{}] req_headers={:?}",
|
"on_http_request_headers S[{}] req_headers={:?}",
|
||||||
self.context_id,
|
self.context_id,
|
||||||
self.get_http_request_headers()
|
self.get_http_request_headers()
|
||||||
|
|
@ -60,25 +60,14 @@ impl HttpContext for StreamContext {
|
||||||
|
|
||||||
self.request_body_size = body_size;
|
self.request_body_size = body_size;
|
||||||
|
|
||||||
debug!(
|
trace!(
|
||||||
"on_http_request_body S[{}] body_size={}",
|
"on_http_request_body S[{}] body_size={}",
|
||||||
self.context_id, body_size
|
self.context_id,
|
||||||
|
body_size
|
||||||
);
|
);
|
||||||
|
|
||||||
// Deserialize body into spec.
|
let body_bytes = match self.get_http_request_body(0, body_size) {
|
||||||
// Currently OpenAI API.
|
Some(body_bytes) => body_bytes,
|
||||||
let mut deserialized_body: ChatCompletionsRequest =
|
|
||||||
match self.get_http_request_body(0, body_size) {
|
|
||||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
|
||||||
Ok(deserialized) => deserialized,
|
|
||||||
Err(e) => {
|
|
||||||
self.send_server_error(
|
|
||||||
ServerError::Deserialization(e),
|
|
||||||
Some(StatusCode::BAD_REQUEST),
|
|
||||||
);
|
|
||||||
return Action::Pause;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
None => {
|
None => {
|
||||||
self.send_server_error(
|
self.send_server_error(
|
||||||
ServerError::LogicError(format!(
|
ServerError::LogicError(format!(
|
||||||
|
|
@ -91,6 +80,22 @@ impl HttpContext for StreamContext {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
debug!("developer => archgw: {}", String::from_utf8_lossy(&body_bytes));
|
||||||
|
|
||||||
|
// Deserialize body into spec.
|
||||||
|
// Currently OpenAI API.
|
||||||
|
let mut deserialized_body: ChatCompletionsRequest =
|
||||||
|
match serde_json::from_slice(&body_bytes) {
|
||||||
|
Ok(deserialized) => deserialized,
|
||||||
|
Err(e) => {
|
||||||
|
self.send_server_error(
|
||||||
|
ServerError::Deserialization(e),
|
||||||
|
Some(StatusCode::BAD_REQUEST),
|
||||||
|
);
|
||||||
|
return Action::Pause;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
self.arch_state = match deserialized_body.metadata {
|
self.arch_state = match deserialized_body.metadata {
|
||||||
Some(ref metadata) => {
|
Some(ref metadata) => {
|
||||||
if metadata.contains_key(ARCH_STATE_HEADER) {
|
if metadata.contains_key(ARCH_STATE_HEADER) {
|
||||||
|
|
@ -145,7 +150,6 @@ impl HttpContext for StreamContext {
|
||||||
similarity_scores: None,
|
similarity_scores: None,
|
||||||
upstream_cluster: None,
|
upstream_cluster: None,
|
||||||
upstream_cluster_path: None,
|
upstream_cluster_path: None,
|
||||||
tool_calls: None,
|
|
||||||
};
|
};
|
||||||
self.get_embeddings(callout_context);
|
self.get_embeddings(callout_context);
|
||||||
return Action::Pause;
|
return Action::Pause;
|
||||||
|
|
@ -201,7 +205,6 @@ impl HttpContext for StreamContext {
|
||||||
similarity_scores: None,
|
similarity_scores: None,
|
||||||
upstream_cluster: None,
|
upstream_cluster: None,
|
||||||
upstream_cluster_path: None,
|
upstream_cluster_path: None,
|
||||||
tool_calls: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(e) = self.http_call(call_args, call_context) {
|
if let Err(e) = self.http_call(call_args, call_context) {
|
||||||
|
|
@ -212,7 +215,7 @@ impl HttpContext for StreamContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||||
debug!(
|
trace!(
|
||||||
"on_http_response_headers recv [S={}] headers={:?}",
|
"on_http_response_headers recv [S={}] headers={:?}",
|
||||||
self.context_id,
|
self.context_id,
|
||||||
self.get_http_response_headers()
|
self.get_http_response_headers()
|
||||||
|
|
@ -224,9 +227,11 @@ impl HttpContext for StreamContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||||
debug!(
|
trace!(
|
||||||
"recv [S={}] bytes={} end_stream={}",
|
"recv [S={}] bytes={} end_stream={}",
|
||||||
self.context_id, body_size, end_of_stream
|
self.context_id,
|
||||||
|
body_size,
|
||||||
|
end_of_stream
|
||||||
);
|
);
|
||||||
|
|
||||||
if !self.is_chat_completions_request {
|
if !self.is_chat_completions_request {
|
||||||
|
|
@ -248,14 +253,14 @@ impl HttpContext for StreamContext {
|
||||||
.expect("cant get response body");
|
.expect("cant get response body");
|
||||||
|
|
||||||
if self.streaming_response {
|
if self.streaming_response {
|
||||||
debug!("streaming response");
|
trace!("streaming response");
|
||||||
} else {
|
} else {
|
||||||
debug!("non streaming response");
|
trace!("non streaming response");
|
||||||
let chat_completions_response: ChatCompletionsResponse =
|
let chat_completions_response: ChatCompletionsResponse =
|
||||||
match serde_json::from_slice(&body) {
|
match serde_json::from_slice(&body) {
|
||||||
Ok(de) => de,
|
Ok(de) => de,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!(
|
trace!(
|
||||||
"invalid response: {}, {}",
|
"invalid response: {}, {}",
|
||||||
String::from_utf8_lossy(&body),
|
String::from_utf8_lossy(&body),
|
||||||
e
|
e
|
||||||
|
|
@ -316,16 +321,18 @@ impl HttpContext for StreamContext {
|
||||||
serde_json::Value::String(arch_state_str),
|
serde_json::Value::String(arch_state_str),
|
||||||
);
|
);
|
||||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||||
debug!("arch => user: {}", data_serialized);
|
debug!("archgw <= developer: {}", data_serialized);
|
||||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
debug!(
|
trace!(
|
||||||
"recv [S={}] total_tokens={} end_stream={}",
|
"recv [S={}] total_tokens={} end_stream={}",
|
||||||
self.context_id, self.response_tokens, end_of_stream
|
self.context_id,
|
||||||
|
self.response_tokens,
|
||||||
|
end_of_stream
|
||||||
);
|
);
|
||||||
|
|
||||||
Action::Continue
|
Action::Continue
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ use common::consts::{
|
||||||
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
|
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
|
||||||
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER,
|
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER,
|
||||||
ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
|
ARCH_UPSTREAM_HOST_HEADER, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
|
||||||
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO,
|
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST,
|
||||||
HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
|
HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
|
||||||
ZEROSHOT_INTERNAL_HOST,
|
ZEROSHOT_INTERNAL_HOST,
|
||||||
};
|
};
|
||||||
|
|
@ -27,7 +27,7 @@ use common::http::{CallArgs, Client};
|
||||||
use common::stats::Gauge;
|
use common::stats::Gauge;
|
||||||
use derivative::Derivative;
|
use derivative::Derivative;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use log::{debug, info, warn};
|
use log::{debug, info, trace, warn};
|
||||||
use proxy_wasm::traits::*;
|
use proxy_wasm::traits::*;
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -37,11 +37,11 @@ use std::time::Duration;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ResponseHandlerType {
|
pub enum ResponseHandlerType {
|
||||||
GetEmbeddings,
|
Embeddings,
|
||||||
ArchFC,
|
ArchFC,
|
||||||
FunctionCall,
|
FunctionCall,
|
||||||
ZeroShotIntent,
|
ZeroShotIntent,
|
||||||
HallucinationDetect,
|
Hallucination,
|
||||||
ArchGuard,
|
ArchGuard,
|
||||||
DefaultTarget,
|
DefaultTarget,
|
||||||
}
|
}
|
||||||
|
|
@ -54,7 +54,6 @@ pub struct StreamCallContext {
|
||||||
pub prompt_target_name: Option<String>,
|
pub prompt_target_name: Option<String>,
|
||||||
#[derivative(Debug = "ignore")]
|
#[derivative(Debug = "ignore")]
|
||||||
pub request_body: ChatCompletionsRequest,
|
pub request_body: ChatCompletionsRequest,
|
||||||
pub tool_calls: Option<Vec<ToolCall>>,
|
|
||||||
pub similarity_scores: Option<Vec<(String, f64)>>,
|
pub similarity_scores: Option<Vec<(String, f64)>>,
|
||||||
pub upstream_cluster: Option<String>,
|
pub upstream_cluster: Option<String>,
|
||||||
pub upstream_cluster_path: Option<String>,
|
pub upstream_cluster_path: Option<String>,
|
||||||
|
|
@ -129,18 +128,77 @@ impl StreamContext {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_embeddings(&mut self, callout_context: StreamCallContext) {
|
||||||
|
let user_message = callout_context.user_message.unwrap();
|
||||||
|
let get_embeddings_input = CreateEmbeddingRequest {
|
||||||
|
// Need to clone into input because user_message is used below.
|
||||||
|
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
|
||||||
|
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||||
|
encoding_format: None,
|
||||||
|
dimensions: None,
|
||||||
|
user: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let embeddings_request_str: String = match serde_json::to_string(&get_embeddings_input) {
|
||||||
|
Ok(json_data) => json_data,
|
||||||
|
Err(error) => {
|
||||||
|
warn!("error serializing get embeddings request: {}", error);
|
||||||
|
return self.send_server_error(ServerError::Deserialization(error), None);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut headers = vec![
|
||||||
|
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
|
||||||
|
(":method", "POST"),
|
||||||
|
(":path", "/embeddings"),
|
||||||
|
(":authority", EMBEDDINGS_INTERNAL_HOST),
|
||||||
|
("content-type", "application/json"),
|
||||||
|
("x-envoy-max-retries", "3"),
|
||||||
|
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||||
|
];
|
||||||
|
if self.request_id.is_some() {
|
||||||
|
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||||
|
}
|
||||||
|
let call_args = CallArgs::new(
|
||||||
|
ARCH_INTERNAL_CLUSTER_NAME,
|
||||||
|
"/embeddings",
|
||||||
|
headers,
|
||||||
|
Some(embeddings_request_str.as_bytes()),
|
||||||
|
vec![],
|
||||||
|
Duration::from_secs(5),
|
||||||
|
);
|
||||||
|
let call_context = StreamCallContext {
|
||||||
|
response_handler_type: ResponseHandlerType::Embeddings,
|
||||||
|
user_message: Some(user_message),
|
||||||
|
prompt_target_name: None,
|
||||||
|
request_body: callout_context.request_body,
|
||||||
|
similarity_scores: None,
|
||||||
|
upstream_cluster: None,
|
||||||
|
upstream_cluster_path: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"archgw => get embeddings request: {}",
|
||||||
|
embeddings_request_str
|
||||||
|
);
|
||||||
|
if let Err(e) = self.http_call(call_args, call_context) {
|
||||||
|
warn!("error dispatching get embeddings request: {}", e);
|
||||||
|
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
pub fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
||||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||||
Ok(embedding_response) => embedding_response,
|
Ok(embedding_response) => embedding_response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("error deserializing embedding response: {}", e);
|
warn!("error deserializing embedding response: {}", e);
|
||||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let prompt_embeddings_vector = &embedding_response.data[0].embedding;
|
let prompt_embeddings_vector = &embedding_response.data[0].embedding;
|
||||||
|
|
||||||
debug!(
|
trace!(
|
||||||
"embedding model: {}, vector length: {:?}",
|
"embedding model: {}, vector length: {:?}",
|
||||||
embedding_response.model,
|
embedding_response.model,
|
||||||
prompt_embeddings_vector.len()
|
prompt_embeddings_vector.len()
|
||||||
|
|
@ -237,7 +295,7 @@ impl StreamContext {
|
||||||
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
||||||
|
|
||||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||||
debug!("error dispatching zero shot classification request: {}", e);
|
warn!("error dispatching zero shot classification request: {}", e);
|
||||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -247,11 +305,13 @@ impl StreamContext {
|
||||||
body: Vec<u8>,
|
body: Vec<u8>,
|
||||||
callout_context: StreamCallContext,
|
callout_context: StreamCallContext,
|
||||||
) {
|
) {
|
||||||
|
let boyd_str = String::from_utf8(body).expect("could not convert body to string");
|
||||||
|
debug!("archgw <= hallucination response: {}", boyd_str);
|
||||||
let hallucination_response: HallucinationClassificationResponse =
|
let hallucination_response: HallucinationClassificationResponse =
|
||||||
match serde_json::from_slice(&body) {
|
match serde_json::from_str(boyd_str.as_str()) {
|
||||||
Ok(hallucination_response) => hallucination_response,
|
Ok(hallucination_response) => hallucination_response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("error deserializing hallucination response: {}", e);
|
warn!("error deserializing hallucination response: {}", e);
|
||||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -291,7 +351,7 @@ impl StreamContext {
|
||||||
metadata: None,
|
metadata: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
debug!("hallucination response: {:?}", chat_completion_response);
|
trace!("hallucination response: {:?}", chat_completion_response);
|
||||||
self.send_http_response(
|
self.send_http_response(
|
||||||
StatusCode::OK.as_u16().into(),
|
StatusCode::OK.as_u16().into(),
|
||||||
vec![("Powered-By", "Katanemo")],
|
vec![("Powered-By", "Katanemo")],
|
||||||
|
|
@ -316,7 +376,7 @@ impl StreamContext {
|
||||||
match serde_json::from_slice(&body) {
|
match serde_json::from_slice(&body) {
|
||||||
Ok(zeroshot_response) => zeroshot_response,
|
Ok(zeroshot_response) => zeroshot_response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!(
|
warn!(
|
||||||
"error deserializing zero shot classification response: {}",
|
"error deserializing zero shot classification response: {}",
|
||||||
e
|
e
|
||||||
);
|
);
|
||||||
|
|
@ -324,7 +384,10 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
debug!("zeroshot intent response: {:?}", zeroshot_intent_response);
|
trace!(
|
||||||
|
"zeroshot intent response: {}",
|
||||||
|
serde_json::to_string(&zeroshot_intent_response).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
let desc_emb_similarity_map: HashMap<String, f64> = callout_context
|
let desc_emb_similarity_map: HashMap<String, f64> = callout_context
|
||||||
.similarity_scores
|
.similarity_scores
|
||||||
|
|
@ -362,7 +425,7 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
info!("no assistant message found, probably first interaction");
|
debug!("no assistant message found, probably first interaction");
|
||||||
}
|
}
|
||||||
|
|
||||||
// get prompt target similarity threshold from overrides
|
// get prompt target similarity threshold from overrides
|
||||||
|
|
@ -382,15 +445,16 @@ impl StreamContext {
|
||||||
// if arch fc responded to the user message, then we don't need to check the similarity score
|
// if arch fc responded to the user message, then we don't need to check the similarity score
|
||||||
// it may be that arch fc is handling the conversation for parameter collection
|
// it may be that arch fc is handling the conversation for parameter collection
|
||||||
if arch_assistant {
|
if arch_assistant {
|
||||||
info!("arch assistant is handling the conversation");
|
info!("arch fc is engaged in parameter collection");
|
||||||
} else {
|
} else {
|
||||||
debug!("checking for default prompt target");
|
|
||||||
if let Some(default_prompt_target) = self
|
if let Some(default_prompt_target) = self
|
||||||
.prompt_targets
|
.prompt_targets
|
||||||
.values()
|
.values()
|
||||||
.find(|pt| pt.default.unwrap_or(false))
|
.find(|pt| pt.default.unwrap_or(false))
|
||||||
{
|
{
|
||||||
debug!("default prompt target found");
|
debug!(
|
||||||
|
"default prompt target found, forwarding request to default prompt target"
|
||||||
|
);
|
||||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||||
|
|
||||||
|
|
@ -401,8 +465,6 @@ impl StreamContext {
|
||||||
callout_context.request_body.messages.clone(),
|
callout_context.request_body.messages.clone(),
|
||||||
);
|
);
|
||||||
let arch_messages_json = serde_json::to_string(&params).unwrap();
|
let arch_messages_json = serde_json::to_string(&params).unwrap();
|
||||||
debug!("no prompt target found with similarity score above threshold, using default prompt target");
|
|
||||||
|
|
||||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||||
|
|
||||||
let mut headers = vec![
|
let mut headers = vec![
|
||||||
|
|
@ -431,7 +493,7 @@ impl StreamContext {
|
||||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||||
|
|
||||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||||
debug!("error dispatching default prompt target request: {}", e);
|
warn!("error dispatching default prompt target request: {}", e);
|
||||||
return self.send_server_error(
|
return self.send_server_error(
|
||||||
ServerError::HttpDispatch(e),
|
ServerError::HttpDispatch(e),
|
||||||
Some(StatusCode::BAD_REQUEST),
|
Some(StatusCode::BAD_REQUEST),
|
||||||
|
|
@ -444,20 +506,12 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
|
let prompt_target = self
|
||||||
Some(prompt_target) => prompt_target.clone(),
|
.prompt_targets
|
||||||
None => {
|
.get(&prompt_target_name)
|
||||||
debug!("prompt target not found: {}", prompt_target_name);
|
.expect("prompt target not found")
|
||||||
return self.send_server_error(
|
.clone();
|
||||||
ServerError::LogicError(format!(
|
|
||||||
"Prompt target not found: {prompt_target_name}"
|
|
||||||
)),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
info!("prompt_target name: {:?}", prompt_target_name);
|
|
||||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||||
for pt in self.prompt_targets.values() {
|
for pt in self.prompt_targets.values() {
|
||||||
if pt.default.unwrap_or_default() {
|
if pt.default.unwrap_or_default() {
|
||||||
|
|
@ -506,7 +560,12 @@ impl StreamContext {
|
||||||
);
|
);
|
||||||
|
|
||||||
let chat_completions = ChatCompletionsRequest {
|
let chat_completions = ChatCompletionsRequest {
|
||||||
model: GPT_35_TURBO.to_string(),
|
model: self
|
||||||
|
.chat_completions_request
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.model
|
||||||
|
.clone(),
|
||||||
messages: callout_context.request_body.messages.clone(),
|
messages: callout_context.request_body.messages.clone(),
|
||||||
tools: Some(chat_completion_tools),
|
tools: Some(chat_completion_tools),
|
||||||
stream: false,
|
stream: false,
|
||||||
|
|
@ -515,12 +574,9 @@ impl StreamContext {
|
||||||
};
|
};
|
||||||
|
|
||||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||||
Ok(msg_body) => {
|
Ok(msg_body) => msg_body,
|
||||||
debug!("arch_fc request body content: {}", msg_body);
|
|
||||||
msg_body
|
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("error serializing arch_fc request body: {}", e);
|
warn!("error serializing arch_fc request body: {}", e);
|
||||||
return self.send_server_error(ServerError::Serialization(e), None);
|
return self.send_server_error(ServerError::Serialization(e), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -552,6 +608,7 @@ impl StreamContext {
|
||||||
callout_context.response_handler_type = ResponseHandlerType::ArchFC;
|
callout_context.response_handler_type = ResponseHandlerType::ArchFC;
|
||||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||||
|
|
||||||
|
debug!("archgw => archfc request: {}", msg_body);
|
||||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||||
debug!("error dispatching arch_fc request: {}", e);
|
debug!("error dispatching arch_fc request: {}", e);
|
||||||
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
|
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
|
||||||
|
|
@ -564,21 +621,28 @@ impl StreamContext {
|
||||||
mut callout_context: StreamCallContext,
|
mut callout_context: StreamCallContext,
|
||||||
) {
|
) {
|
||||||
let body_str = String::from_utf8(body).unwrap();
|
let body_str = String::from_utf8(body).unwrap();
|
||||||
debug!("arch <= app response body: {}", body_str);
|
debug!("archgw <= archfc response: {}", body_str);
|
||||||
|
|
||||||
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||||
Ok(arch_fc_response) => arch_fc_response,
|
Ok(arch_fc_response) => arch_fc_response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("error deserializing arch_fc response: {}", e);
|
warn!("error deserializing archfc response: {}", e);
|
||||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let model_resp = &arch_fc_response.choices[0];
|
arch_fc_response.choices[0]
|
||||||
|
.message
|
||||||
|
.tool_calls
|
||||||
|
.clone_into(&mut self.tool_calls);
|
||||||
|
if self.tool_calls.as_ref().unwrap().len() > 1 {
|
||||||
|
warn!(
|
||||||
|
"multiple tool calls not supported yet, tool_calls count found: {}",
|
||||||
|
self.tool_calls.as_ref().unwrap().len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if model_resp.message.tool_calls.is_none()
|
if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() {
|
||||||
|| model_resp.message.tool_calls.as_ref().unwrap().is_empty()
|
|
||||||
{
|
|
||||||
// This means that Arch FC did not have enough information to resolve the function call
|
// This means that Arch FC did not have enough information to resolve the function call
|
||||||
// Arch FC probably responded with a message asking for more information.
|
// Arch FC probably responded with a message asking for more information.
|
||||||
// Let's send the response back to the user to initialize lightweight dialog for parameter collection
|
// Let's send the response back to the user to initialize lightweight dialog for parameter collection
|
||||||
|
|
@ -592,43 +656,43 @@ impl StreamContext {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
|
|
||||||
self.tool_calls = Some(tool_calls.clone());
|
|
||||||
|
|
||||||
// TODO CO: pass nli check
|
// TODO CO: pass nli check
|
||||||
// If hallucination, pass chat template to check parameters
|
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
|
||||||
|
let prompt_target = self
|
||||||
// extract all tool names
|
.prompt_targets
|
||||||
let tool_names: Vec<String> = tool_calls
|
.get(&tools_call_name)
|
||||||
.iter()
|
.expect("prompt target not found for tool call")
|
||||||
.map(|tool_call| tool_call.function.name.clone())
|
.clone();
|
||||||
.collect();
|
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"call context similarity score: {:?}",
|
"prompt_target_name: {}, tool_name(s): {:?}",
|
||||||
callout_context.similarity_scores
|
prompt_target.name,
|
||||||
|
self.tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.iter()
|
||||||
|
.map(|tc| tc.function.name.clone())
|
||||||
|
.collect::<Vec<String>>(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// If hallucination, pass chat template to check parameters
|
||||||
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
|
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
|
||||||
let mut tool_params = tool_calls[0].function.arguments.clone();
|
|
||||||
|
let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
|
||||||
|
.function
|
||||||
|
.arguments
|
||||||
|
.clone();
|
||||||
|
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||||
|
debug!(
|
||||||
|
"tool_params (without messages history): {}",
|
||||||
|
tool_params_json_str
|
||||||
|
);
|
||||||
tool_params.insert(
|
tool_params.insert(
|
||||||
String::from(ARCH_MESSAGES_KEY),
|
String::from(ARCH_MESSAGES_KEY),
|
||||||
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
|
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let tools_call_name = tool_calls[0].function.name.clone();
|
|
||||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
|
||||||
callout_context.tool_calls = Some(tool_calls.clone());
|
|
||||||
|
|
||||||
debug!(
|
|
||||||
"prompt_target_name: {}, tool_name(s): {:?}",
|
|
||||||
prompt_target.name, tool_names
|
|
||||||
);
|
|
||||||
debug!("tool_params: {}", tool_params_json_str);
|
|
||||||
|
|
||||||
if model_resp.message.tool_calls.is_some()
|
|
||||||
&& !model_resp.message.tool_calls.as_ref().unwrap().is_empty()
|
|
||||||
{
|
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
|
let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
|
||||||
let tool_params_dict: HashMap<String, String> = match v.as_object() {
|
let tool_params_dict: HashMap<String, String> = match v.as_object() {
|
||||||
|
|
@ -653,7 +717,7 @@ impl StreamContext {
|
||||||
parameters: tool_params_dict,
|
parameters: tool_params_dict,
|
||||||
};
|
};
|
||||||
|
|
||||||
let json_data: String =
|
let hallucination_request_str: String =
|
||||||
match serde_json::to_string(&hallucination_classification_request) {
|
match serde_json::to_string(&hallucination_classification_request) {
|
||||||
Ok(json_data) => json_data,
|
Ok(json_data) => json_data,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
|
|
@ -683,30 +747,27 @@ impl StreamContext {
|
||||||
ARCH_INTERNAL_CLUSTER_NAME,
|
ARCH_INTERNAL_CLUSTER_NAME,
|
||||||
"/hallucination",
|
"/hallucination",
|
||||||
headers,
|
headers,
|
||||||
Some(json_data.as_bytes()),
|
Some(hallucination_request_str.as_bytes()),
|
||||||
vec![],
|
vec![],
|
||||||
Duration::from_secs(5),
|
Duration::from_secs(5),
|
||||||
);
|
);
|
||||||
callout_context.response_handler_type = ResponseHandlerType::HallucinationDetect;
|
callout_context.response_handler_type = ResponseHandlerType::Hallucination;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"archgw => hallucination request: {}",
|
||||||
|
hallucination_request_str
|
||||||
|
);
|
||||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
self.schedule_api_call_request(callout_context);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
|
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
|
||||||
let tools_call_name = callout_context.tool_calls.as_ref().unwrap()[0]
|
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
|
||||||
.function
|
|
||||||
.name
|
|
||||||
.clone();
|
|
||||||
|
|
||||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
||||||
|
|
||||||
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
|
let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
|
||||||
let mut tool_params = callout_context.tool_calls.as_ref().unwrap()[0]
|
|
||||||
.function
|
.function
|
||||||
.arguments
|
.arguments
|
||||||
.clone();
|
.clone();
|
||||||
|
|
@ -741,8 +802,16 @@ impl StreamContext {
|
||||||
vec![],
|
vec![],
|
||||||
Duration::from_secs(5),
|
Duration::from_secs(5),
|
||||||
);
|
);
|
||||||
callout_context.upstream_cluster = Some(endpoint.name.clone());
|
|
||||||
callout_context.upstream_cluster_path = Some(path.clone());
|
debug!(
|
||||||
|
"archgw => api call, endpoint: {}/{}, body: {}",
|
||||||
|
endpoint.name.as_str(),
|
||||||
|
path,
|
||||||
|
tool_params_json_str
|
||||||
|
);
|
||||||
|
|
||||||
|
callout_context.upstream_cluster = Some(endpoint.name.to_owned());
|
||||||
|
callout_context.upstream_cluster_path = Some(path.to_owned());
|
||||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||||
|
|
||||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||||
|
|
@ -750,14 +819,15 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn function_call_response_handler(
|
pub fn api_call_response_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||||
&mut self,
|
let http_status = self
|
||||||
body: Vec<u8>,
|
.get_http_call_response_header(":status")
|
||||||
callout_context: StreamCallContext,
|
.expect("http status code not found");
|
||||||
) {
|
|
||||||
if let Some(http_status) = self.get_http_call_response_header(":status") {
|
|
||||||
if http_status != StatusCode::OK.as_str() {
|
if http_status != StatusCode::OK.as_str() {
|
||||||
debug!("upstream error response: {}", http_status);
|
warn!(
|
||||||
|
"api server responded with non 2xx status code: {}",
|
||||||
|
http_status
|
||||||
|
);
|
||||||
return self.send_server_error(
|
return self.send_server_error(
|
||||||
ServerError::Upstream {
|
ServerError::Upstream {
|
||||||
host: callout_context.upstream_cluster.unwrap(),
|
host: callout_context.upstream_cluster.unwrap(),
|
||||||
|
|
@ -768,14 +838,10 @@ impl StreamContext {
|
||||||
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
|
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else {
|
self.tool_call_response = Some(String::from_utf8(body).unwrap());
|
||||||
warn!("http status code not found in api response");
|
|
||||||
}
|
|
||||||
let app_function_call_response_str: String = String::from_utf8(body).unwrap();
|
|
||||||
self.tool_call_response = Some(app_function_call_response_str.clone());
|
|
||||||
debug!(
|
debug!(
|
||||||
"arch <= app response body: {}",
|
"archgw <= api call response: {}",
|
||||||
app_function_call_response_str
|
self.tool_call_response.as_ref().unwrap()
|
||||||
);
|
);
|
||||||
let prompt_target_name = callout_context.prompt_target_name.unwrap();
|
let prompt_target_name = callout_context.prompt_target_name.unwrap();
|
||||||
let prompt_target = self
|
let prompt_target = self
|
||||||
|
|
@ -825,7 +891,7 @@ impl StreamContext {
|
||||||
let final_prompt = format!(
|
let final_prompt = format!(
|
||||||
"{}\ncontext: {}",
|
"{}\ncontext: {}",
|
||||||
user_message.content.unwrap(),
|
user_message.content.unwrap(),
|
||||||
app_function_call_response_str
|
self.tool_call_response.as_ref().unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
// add original user prompt
|
// add original user prompt
|
||||||
|
|
@ -848,22 +914,24 @@ impl StreamContext {
|
||||||
metadata: None,
|
metadata: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let json_string = match serde_json::to_string(&chat_completions_request) {
|
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
|
||||||
Ok(json_string) => json_string,
|
Ok(json_string) => json_string,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return self.send_server_error(ServerError::Serialization(e), None);
|
return self.send_server_error(ServerError::Serialization(e), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
debug!("arch => upstream llm request body: {}", json_string);
|
debug!("archgw => llm request: {}", llm_request_str);
|
||||||
|
|
||||||
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
|
self.set_http_request_body(0, self.request_body_size, &llm_request_str.into_bytes());
|
||||||
self.resume_http_request();
|
self.resume_http_request();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||||
debug!("response received for arch guard");
|
|
||||||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||||
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
|
debug!(
|
||||||
|
"archgw <= archguard response: {:?}",
|
||||||
|
serde_json::to_string(&prompt_guard_resp)
|
||||||
|
);
|
||||||
|
|
||||||
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
|
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
|
||||||
//TODO: handle other scenarios like forward to error target
|
//TODO: handle other scenarios like forward to error target
|
||||||
|
|
@ -871,7 +939,7 @@ impl StreamContext {
|
||||||
.prompt_guards
|
.prompt_guards
|
||||||
.jailbreak_on_exception_message()
|
.jailbreak_on_exception_message()
|
||||||
.unwrap_or("refrain from discussing jailbreaking.");
|
.unwrap_or("refrain from discussing jailbreaking.");
|
||||||
debug!("jailbreak detected: {}", msg);
|
warn!("jailbreak detected: {}", msg);
|
||||||
return self.send_server_error(
|
return self.send_server_error(
|
||||||
ServerError::Jailbreak(String::from(msg)),
|
ServerError::Jailbreak(String::from(msg)),
|
||||||
Some(StatusCode::BAD_REQUEST),
|
Some(StatusCode::BAD_REQUEST),
|
||||||
|
|
@ -881,92 +949,27 @@ impl StreamContext {
|
||||||
self.get_embeddings(callout_context);
|
self.get_embeddings(callout_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_embeddings(&mut self, callout_context: StreamCallContext) {
|
|
||||||
let user_message = callout_context.user_message.unwrap();
|
|
||||||
let get_embeddings_input = CreateEmbeddingRequest {
|
|
||||||
// Need to clone into input because user_message is used below.
|
|
||||||
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
|
|
||||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
|
||||||
encoding_format: None,
|
|
||||||
dimensions: None,
|
|
||||||
user: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
|
||||||
Ok(json_data) => json_data,
|
|
||||||
Err(error) => {
|
|
||||||
debug!("error serializing get embeddings request: {}", error);
|
|
||||||
return self.send_server_error(ServerError::Deserialization(error), None);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut headers = vec![
|
|
||||||
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
|
|
||||||
(":method", "POST"),
|
|
||||||
(":path", "/embeddings"),
|
|
||||||
(":authority", EMBEDDINGS_INTERNAL_HOST),
|
|
||||||
("content-type", "application/json"),
|
|
||||||
("x-envoy-max-retries", "3"),
|
|
||||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
|
||||||
];
|
|
||||||
if self.request_id.is_some() {
|
|
||||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
|
||||||
}
|
|
||||||
let call_args = CallArgs::new(
|
|
||||||
ARCH_INTERNAL_CLUSTER_NAME,
|
|
||||||
"/embeddings",
|
|
||||||
headers,
|
|
||||||
Some(json_data.as_bytes()),
|
|
||||||
vec![],
|
|
||||||
Duration::from_secs(5),
|
|
||||||
);
|
|
||||||
let call_context = StreamCallContext {
|
|
||||||
response_handler_type: ResponseHandlerType::GetEmbeddings,
|
|
||||||
user_message: Some(user_message),
|
|
||||||
prompt_target_name: None,
|
|
||||||
request_body: callout_context.request_body,
|
|
||||||
similarity_scores: None,
|
|
||||||
upstream_cluster: None,
|
|
||||||
upstream_cluster_path: None,
|
|
||||||
tool_calls: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(e) = self.http_call(call_args, call_context) {
|
|
||||||
debug!("error dispatching get embeddings request: {}", e);
|
|
||||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
|
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||||
let prompt_target = self
|
let prompt_target = self
|
||||||
.prompt_targets
|
.prompt_targets
|
||||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.clone();
|
.clone();
|
||||||
debug!(
|
|
||||||
"response received for default target: {}",
|
|
||||||
prompt_target.name
|
|
||||||
);
|
|
||||||
// check if the default target should be dispatched to the LLM provider
|
// check if the default target should be dispatched to the LLM provider
|
||||||
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
|
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
|
||||||
let default_target_response_str = String::from_utf8(body).unwrap();
|
let default_target_response_str = String::from_utf8(body).unwrap();
|
||||||
debug!(
|
|
||||||
"sending response back to developer: {}",
|
|
||||||
default_target_response_str
|
|
||||||
);
|
|
||||||
self.send_http_response(
|
self.send_http_response(
|
||||||
StatusCode::OK.as_u16().into(),
|
StatusCode::OK.as_u16().into(),
|
||||||
vec![("Powered-By", "Katanemo")],
|
vec![("Powered-By", "Katanemo")],
|
||||||
Some(default_target_response_str.as_bytes()),
|
Some(default_target_response_str.as_bytes()),
|
||||||
);
|
);
|
||||||
// self.resume_http_request();
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
debug!("default_target: sending api response to default llm");
|
|
||||||
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
|
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
|
||||||
Ok(chat_completions_resp) => chat_completions_resp,
|
Ok(chat_completions_resp) => chat_completions_resp,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("error deserializing default target response: {}", e);
|
warn!("error deserializing default target response: {}", e);
|
||||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -1000,7 +1003,12 @@ impl StreamContext {
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
});
|
});
|
||||||
let chat_completion_request = ChatCompletionsRequest {
|
let chat_completion_request = ChatCompletionsRequest {
|
||||||
model: GPT_35_TURBO.to_string(),
|
model: self
|
||||||
|
.chat_completions_request
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.model
|
||||||
|
.clone(),
|
||||||
messages,
|
messages,
|
||||||
tools: None,
|
tools: None,
|
||||||
stream: callout_context.request_body.stream,
|
stream: callout_context.request_body.stream,
|
||||||
|
|
@ -1008,7 +1016,7 @@ impl StreamContext {
|
||||||
metadata: None,
|
metadata: None,
|
||||||
};
|
};
|
||||||
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
|
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
|
||||||
debug!("sending response back to default llm: {}", json_resp);
|
debug!("archgw => (default target) llm request: {}", json_resp);
|
||||||
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
|
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
|
||||||
self.resume_http_request();
|
self.resume_http_request();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
||||||
.returning(Some("/v1/chat/completions"))
|
.returning(Some("/v1/chat/completions"))
|
||||||
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
|
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
|
||||||
.returning(None)
|
.returning(None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
|
||||||
.returning(None)
|
.returning(None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
|
|
@ -74,7 +74,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
.returning(Some(chat_completions_request_body))
|
.returning(Some(chat_completions_request_body))
|
||||||
// The actual call is not important in this test, we just need to grab the token_id
|
// The actual call is not important in this test, we just need to grab the token_id
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_http_call(
|
.expect_http_call(
|
||||||
Some("arch_internal"),
|
Some("arch_internal"),
|
||||||
Some(vec![
|
Some(vec![
|
||||||
|
|
@ -92,6 +92,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||||
)
|
)
|
||||||
.returning(Some(1))
|
.returning(Some(1))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_metric_increment("active_http_calls", 1)
|
.expect_metric_increment("active_http_calls", 1)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -116,6 +117,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||||
.returning(Some(&prompt_guard_response_buffer))
|
.returning(Some(&prompt_guard_response_buffer))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_http_call(
|
.expect_http_call(
|
||||||
Some("arch_internal"),
|
Some("arch_internal"),
|
||||||
Some(vec![
|
Some(vec![
|
||||||
|
|
@ -133,7 +135,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||||
)
|
)
|
||||||
.returning(Some(2))
|
.returning(Some(2))
|
||||||
.expect_metric_increment("active_http_calls", 1)
|
.expect_metric_increment("active_http_calls", 1)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.execute_and_expect(ReturnType::None)
|
.execute_and_expect(ReturnType::None)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -159,8 +160,9 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||||
.expect_metric_increment("active_http_calls", -1)
|
.expect_metric_increment("active_http_calls", -1)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||||
.returning(Some(&embeddings_response_buffer))
|
.returning(Some(&embeddings_response_buffer))
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_http_call(
|
.expect_http_call(
|
||||||
Some("arch_internal"),
|
Some("arch_internal"),
|
||||||
Some(vec![
|
Some(vec![
|
||||||
|
|
@ -178,7 +180,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||||
)
|
)
|
||||||
.returning(Some(3))
|
.returning(Some(3))
|
||||||
.expect_metric_increment("active_http_calls", 1)
|
.expect_metric_increment("active_http_calls", 1)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.execute_and_expect(ReturnType::None)
|
.execute_and_expect(ReturnType::None)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -200,9 +201,10 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||||
.expect_metric_increment("active_http_calls", -1)
|
.expect_metric_increment("active_http_calls", -1)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||||
.returning(Some(&zeroshot_intent_detection_buffer))
|
.returning(Some(&zeroshot_intent_detection_buffer))
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Info), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_http_call(
|
.expect_http_call(
|
||||||
Some("arch_internal"),
|
Some("arch_internal"),
|
||||||
Some(vec![
|
Some(vec![
|
||||||
|
|
@ -219,8 +221,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
.returning(Some(4))
|
.returning(Some(4))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.expect_metric_increment("active_http_calls", 1)
|
.expect_metric_increment("active_http_calls", 1)
|
||||||
.execute_and_expect(ReturnType::None)
|
.execute_and_expect(ReturnType::None)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -245,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
||||||
module
|
module
|
||||||
.call_proxy_on_tick(filter_context)
|
.call_proxy_on_tick(filter_context)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_http_call(
|
.expect_http_call(
|
||||||
Some("arch_internal"),
|
Some("arch_internal"),
|
||||||
Some(vec![
|
Some(vec![
|
||||||
|
|
@ -426,8 +426,9 @@ fn successful_request_to_open_ai_chat_completions() {
|
||||||
)
|
)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
.returning(Some(chat_completions_request_body))
|
.returning(Some(chat_completions_request_body))
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_http_call(Some("arch_internal"), None, None, None, None)
|
.expect_http_call(Some("arch_internal"), None, None, None, None)
|
||||||
.returning(Some(4))
|
.returning(Some(4))
|
||||||
.expect_metric_increment("active_http_calls", 1)
|
.expect_metric_increment("active_http_calls", 1)
|
||||||
|
|
@ -486,13 +487,14 @@ fn bad_request_to_open_ai_chat_completions() {
|
||||||
)
|
)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
.returning(Some(incomplete_chat_completions_request_body))
|
.returning(Some(incomplete_chat_completions_request_body))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_send_local_response(
|
.expect_send_local_response(
|
||||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -564,7 +566,7 @@ fn request_to_llm_gateway() {
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_http_call(
|
.expect_http_call(
|
||||||
Some("arch_internal"),
|
Some("arch_internal"),
|
||||||
Some(vec![
|
Some(vec![
|
||||||
|
|
@ -603,6 +605,8 @@ fn request_to_llm_gateway() {
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||||
.returning(Some(&body_text))
|
.returning(Some(&body_text))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_http_call(
|
.expect_http_call(
|
||||||
Some("arch_internal"),
|
Some("arch_internal"),
|
||||||
Some(vec![
|
Some(vec![
|
||||||
|
|
@ -628,10 +632,10 @@ fn request_to_llm_gateway() {
|
||||||
.expect_metric_increment("active_http_calls", -1)
|
.expect_metric_increment("active_http_calls", -1)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||||
.returning(Some(&body_text))
|
.returning(Some(&body_text))
|
||||||
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
|
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
|
||||||
.returning(Some("200"))
|
.returning(Some("200"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
.execute_and_expect(ReturnType::None)
|
.execute_and_expect(ReturnType::None)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -664,11 +668,11 @@ fn request_to_llm_gateway() {
|
||||||
)
|
)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
|
||||||
.returning(Some(chat_completion_response_str.as_str()))
|
.returning(Some(chat_completion_response_str.as_str()))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
|
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
|
.expect_log(Some(LogLevel::Trace), None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue