mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 21:02:56 +02:00
Implement Client trait for StreamContext (#134)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
5bfccd3959
commit
c1cfbcd44d
7 changed files with 216 additions and 218 deletions
|
|
@ -5,6 +5,7 @@ use crate::consts::{
|
|||
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
||||
use crate::http::{CallArgs, Client, ClientError};
|
||||
use crate::llm_providers::LlmProviders;
|
||||
use crate::ratelimit::Header;
|
||||
use crate::stats::IncrementingMetric;
|
||||
|
|
@ -30,11 +31,13 @@ use public_types::embeddings::{
|
|||
};
|
||||
use serde_json::Value;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ResponseHandlerType {
|
||||
GetEmbeddings,
|
||||
FunctionResolver,
|
||||
|
|
@ -44,7 +47,8 @@ enum ResponseHandlerType {
|
|||
DefaultTarget,
|
||||
}
|
||||
|
||||
pub struct CallContext {
|
||||
#[derive(Debug)]
|
||||
pub struct StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType,
|
||||
user_message: Option<String>,
|
||||
prompt_target_name: Option<String>,
|
||||
|
|
@ -54,13 +58,37 @@ pub struct CallContext {
|
|||
upstream_cluster_path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ServerError {
|
||||
#[error(transparent)]
|
||||
HttpDispatch(ClientError),
|
||||
#[error(transparent)]
|
||||
Deserialization(serde_json::Error),
|
||||
#[error(transparent)]
|
||||
Serialization(serde_json::Error),
|
||||
#[error("{0}")]
|
||||
LogicError(String),
|
||||
#[error("upstream error response authority={authority}, path={path}, status={status}")]
|
||||
Upstream {
|
||||
authority: String,
|
||||
path: String,
|
||||
status: String,
|
||||
},
|
||||
#[error(transparent)]
|
||||
ExceededRatelimit(ratelimit::Error),
|
||||
#[error("jailbreak detected: {0}")]
|
||||
Jailbreak(String),
|
||||
#[error("{why}")]
|
||||
BadRequest { why: String },
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
embeddings_store: Rc<EmbeddingsStore>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
tool_call_response: Option<String>,
|
||||
arch_state: Option<Vec<ArchState>>,
|
||||
|
|
@ -91,8 +119,8 @@ impl StreamContext {
|
|||
metrics,
|
||||
prompt_targets,
|
||||
embeddings_store,
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
chat_completions_request: None,
|
||||
callouts: HashMap::new(),
|
||||
tool_calls: None,
|
||||
tool_call_response: None,
|
||||
arch_state: None,
|
||||
|
|
@ -129,11 +157,17 @@ impl StreamContext {
|
|||
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
|
||||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), String> {
|
||||
let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!(
|
||||
"No access key configured for selected LLM Provider \"{}\"",
|
||||
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
|
||||
let llm_provider_api_key_value =
|
||||
self.llm_provider()
|
||||
))?;
|
||||
.access_key
|
||||
.as_ref()
|
||||
.ok_or(ServerError::BadRequest {
|
||||
why: format!(
|
||||
"No access key configured for selected LLM Provider \"{}\"",
|
||||
self.llm_provider()
|
||||
),
|
||||
})?;
|
||||
|
||||
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
|
||||
|
||||
|
|
@ -159,7 +193,7 @@ impl StreamContext {
|
|||
});
|
||||
}
|
||||
|
||||
fn send_server_error(&self, error: String, override_status_code: Option<StatusCode>) {
|
||||
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
|
||||
debug!("server error occurred: {}", error);
|
||||
self.send_http_response(
|
||||
override_status_code
|
||||
|
|
@ -167,18 +201,15 @@ impl StreamContext {
|
|||
.as_u16()
|
||||
.into(),
|
||||
vec![],
|
||||
Some(error.as_bytes()),
|
||||
Some(format!("{error}").as_bytes()),
|
||||
);
|
||||
}
|
||||
|
||||
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
format!("Error deserializing embedding response: {:?}", e),
|
||||
None,
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -248,13 +279,13 @@ impl StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
let error = format!("Error serializing zero shot request: {}", error);
|
||||
return self.send_server_error(error, None);
|
||||
return self.send_server_error(ServerError::Serialization(error), None);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
let call_args = CallArgs::new(
|
||||
MODEL_SERVER_NAME,
|
||||
"/zeroshot",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/zeroshot"),
|
||||
|
|
@ -266,49 +297,24 @@ impl StreamContext {
|
|||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg = format!(
|
||||
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
|
||||
e
|
||||
);
|
||||
return self.send_server_error(error_msg, None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"dispatched call to model_server/zeroshot token_id={}",
|
||||
token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
||||
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
}
|
||||
|
||||
fn zero_shot_intent_detection_resp_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: CallContext,
|
||||
mut callout_context: StreamCallContext,
|
||||
) {
|
||||
let zeroshot_intent_response: ZeroShotClassificationResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(zeroshot_response) => zeroshot_response,
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
format!(
|
||||
"Error deserializing zeroshot intent detection response: {:?}",
|
||||
e
|
||||
),
|
||||
None,
|
||||
);
|
||||
return;
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -390,40 +396,34 @@ impl StreamContext {
|
|||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
debug!("no prompt target found with similarity score above threshold, using default prompt target");
|
||||
let token_id = match self.dispatch_http_call(
|
||||
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
let call_args = CallArgs::new(
|
||||
&upstream_endpoint,
|
||||
&upstream_path,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
(
|
||||
"x-envoy-upstream-rq-timeout-ms",
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
|
||||
),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
],
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg =
|
||||
format!("Error dispatching HTTP call for default-target: {:?}", e);
|
||||
return self
|
||||
.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
|
|
@ -433,7 +433,9 @@ impl StreamContext {
|
|||
Some(prompt_target) => prompt_target.clone(),
|
||||
None => {
|
||||
return self.send_server_error(
|
||||
format!("Prompt target not found: {}", prompt_target_name),
|
||||
ServerError::LogicError(format!(
|
||||
"Prompt target not found: {prompt_target_name}"
|
||||
)),
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
|
@ -499,62 +501,42 @@ impl StreamContext {
|
|||
msg_body
|
||||
}
|
||||
Err(e) => {
|
||||
return self
|
||||
.send_server_error(format!("Error serializing request_params: {:?}", e), None);
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
let call_args = CallArgs::new(
|
||||
ARC_FC_CLUSTER,
|
||||
"/v1/chat/completions",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", ARC_FC_CLUSTER),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
(
|
||||
"x-envoy-upstream-rq-timeout-ms",
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
|
||||
),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
],
|
||||
Some(msg_body.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg = format!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"dispatched call to function {} token_id={}",
|
||||
ARC_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
|
||||
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
||||
let body_str = String::from_utf8(body).unwrap();
|
||||
debug!("arch <= app response body: {}", body_str);
|
||||
|
||||
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||
Ok(arch_fc_response) => arch_fc_response,
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
format!(
|
||||
"Error deserializing function resolver response into ChatCompletion: {:?}",
|
||||
e
|
||||
),
|
||||
None,
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -607,11 +589,12 @@ impl StreamContext {
|
|||
|
||||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
let path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
let token_id = match self.dispatch_http_call(
|
||||
let call_args = CallArgs::new(
|
||||
&endpoint.name,
|
||||
&path,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", path.as_ref()),
|
||||
(":path", &path),
|
||||
(":authority", endpoint.name.as_str()),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
|
|
@ -619,39 +602,33 @@ impl StreamContext {
|
|||
Some(tool_params_json_str.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg = format!(
|
||||
"Error dispatching call to cluster: {}, path: {}, err: {:?}",
|
||||
&endpoint.name, path, e
|
||||
);
|
||||
debug!("{}", error_msg);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
);
|
||||
callout_context.upstream_cluster = Some(endpoint.name.clone());
|
||||
callout_context.upstream_cluster_path = Some(path.clone());
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
|
||||
self.tool_calls = Some(tool_calls.clone());
|
||||
callout_context.upstream_cluster = Some(endpoint.name);
|
||||
callout_context.upstream_cluster_path = Some(path);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn function_call_response_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
let headers = self.get_http_call_response_headers();
|
||||
if let Some(http_status) = headers.iter().find(|(key, _)| key == ":status") {
|
||||
if http_status.1 != StatusCode::OK.as_str() {
|
||||
let error_msg = format!(
|
||||
"Error in function call response: cluster: {}, path: {}, status code: {}",
|
||||
callout_context.upstream_cluster.unwrap(),
|
||||
callout_context.upstream_cluster_path.unwrap(),
|
||||
http_status.1
|
||||
fn function_call_response_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
callout_context: StreamCallContext,
|
||||
) {
|
||||
if let Some(http_status) = self.get_http_call_response_header(":status") {
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
return self.send_server_error(
|
||||
ServerError::Upstream {
|
||||
authority: callout_context.upstream_cluster.unwrap(),
|
||||
path: callout_context.upstream_cluster_path.unwrap(),
|
||||
status: http_status,
|
||||
},
|
||||
None,
|
||||
);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
} else {
|
||||
warn!("http status code not found in api response");
|
||||
|
|
@ -714,8 +691,7 @@ impl StreamContext {
|
|||
let json_string = match serde_json::to_string(&chat_completions_request) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
return self
|
||||
.send_server_error(format!("Error serializing request_body: {:?}", e), None);
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
debug!("arch => openai request body: {}", json_string);
|
||||
|
|
@ -733,7 +709,7 @@ impl StreamContext {
|
|||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
self.send_server_error(
|
||||
format!("Exceeded Ratelimit: {}", err),
|
||||
ServerError::ExceededRatelimit(err),
|
||||
Some(StatusCode::TOO_MANY_REQUESTS),
|
||||
);
|
||||
self.metrics.ratelimited_rq.increment(1);
|
||||
|
|
@ -747,7 +723,7 @@ impl StreamContext {
|
|||
self.resume_http_request();
|
||||
}
|
||||
|
||||
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
debug!("response received for arch guard");
|
||||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
|
||||
|
|
@ -757,14 +733,17 @@ impl StreamContext {
|
|||
let msg = self
|
||||
.prompt_guards
|
||||
.jailbreak_on_exception_message()
|
||||
.unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
|
||||
return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
|
||||
.unwrap_or("refrain from discussing jailbreaking.");
|
||||
return self.send_server_error(
|
||||
ServerError::Jailbreak(String::from(msg)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
|
||||
self.get_embeddings(callout_context);
|
||||
}
|
||||
|
||||
fn get_embeddings(&mut self, callout_context: CallContext) {
|
||||
fn get_embeddings(&mut self, callout_context: StreamCallContext) {
|
||||
let user_message = callout_context.user_message.unwrap();
|
||||
let get_embeddings_input = CreateEmbeddingRequest {
|
||||
// Need to clone into input because user_message is used below.
|
||||
|
|
@ -778,13 +757,13 @@ impl StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
let error_msg = format!("Error serializing embeddings input: {}", error);
|
||||
return self.send_server_error(error_msg, None);
|
||||
return self.send_server_error(ServerError::Deserialization(error), None);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
let call_args = CallArgs::new(
|
||||
MODEL_SERVER_NAME,
|
||||
"/embeddings",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
|
|
@ -796,19 +775,8 @@ impl StreamContext {
|
|||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg = format!("dispatched call to model_server/embeddings: {:?}", e);
|
||||
return self.send_server_error(error_msg, None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"dispatched call to model_server/embeddings token_id={}",
|
||||
token_id
|
||||
);
|
||||
|
||||
let call_context = CallContext {
|
||||
let call_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::GetEmbeddings,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
|
|
@ -817,17 +785,13 @@ impl StreamContext {
|
|||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
}
|
||||
|
||||
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
|
||||
fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
|
|
@ -856,10 +820,7 @@ impl StreamContext {
|
|||
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(chat_completions_resp) => chat_completions_resp,
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
format!("Error deserializing default target response: {:?}", e),
|
||||
None,
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
let api_resp = chat_completions_resp.choices[0]
|
||||
|
|
@ -948,9 +909,9 @@ impl HttpContext for StreamContext {
|
|||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
format!("Failed to deserialize: {}", msg),
|
||||
ServerError::Deserialization(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
|
|
@ -958,10 +919,10 @@ impl HttpContext for StreamContext {
|
|||
},
|
||||
None => {
|
||||
self.send_server_error(
|
||||
format!(
|
||||
ServerError::LogicError(format!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
),
|
||||
)),
|
||||
None,
|
||||
);
|
||||
return Action::Pause;
|
||||
|
|
@ -1018,7 +979,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
if !prompt_guard_jailbreak_task {
|
||||
debug!("Missing input guard. Making inline call to retrieve");
|
||||
let callout_context = CallContext {
|
||||
let callout_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: user_message_str.clone(),
|
||||
prompt_target_name: None,
|
||||
|
|
@ -1046,14 +1007,14 @@ impl HttpContext for StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
let error_msg = format!("Error serializing prompt guard request: {}", error);
|
||||
self.send_server_error(error_msg, None);
|
||||
self.send_server_error(ServerError::Serialization(error), None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
let call_args = CallArgs::new(
|
||||
MODEL_SERVER_NAME,
|
||||
"/guard",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/guard"),
|
||||
|
|
@ -1065,21 +1026,8 @@ impl HttpContext for StreamContext {
|
|||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
let error_msg = format!(
|
||||
"Error dispatching embedding server HTTP call for prompt-guard: {:?}",
|
||||
e
|
||||
);
|
||||
self.send_server_error(error_msg, None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
debug!("dispatched HTTP call to arch_guard token_id={}", token_id);
|
||||
|
||||
let call_context = CallContext {
|
||||
);
|
||||
let call_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
|
||||
prompt_target_name: None,
|
||||
|
|
@ -1088,14 +1036,10 @@ impl HttpContext for StreamContext {
|
|||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
|
||||
Action::Pause
|
||||
}
|
||||
|
|
@ -1130,7 +1074,10 @@ impl HttpContext for StreamContext {
|
|||
let chat_completions_data = match body_str.split_once("data: ") {
|
||||
Some((_, chat_completions_data)) => chat_completions_data,
|
||||
None => {
|
||||
self.send_server_error(String::from("parsing error in streaming data"), None);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(String::from("parsing error in streaming data")),
|
||||
None,
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
|
@ -1141,7 +1088,9 @@ impl HttpContext for StreamContext {
|
|||
Err(_) => {
|
||||
if chat_completions_data != "[NONE]" {
|
||||
self.send_server_error(
|
||||
String::from("error in streaming response"),
|
||||
ServerError::LogicError(String::from(
|
||||
"error in streaming response",
|
||||
)),
|
||||
None,
|
||||
);
|
||||
return Action::Continue;
|
||||
|
|
@ -1168,14 +1117,7 @@ impl HttpContext for StreamContext {
|
|||
match serde_json::from_slice(&body) {
|
||||
Ok(de) => de,
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
format!(
|
||||
"error in non-streaming response: {}\n response was={}",
|
||||
e,
|
||||
String::from_utf8(body).unwrap()
|
||||
),
|
||||
None,
|
||||
);
|
||||
self.send_server_error(ServerError::Deserialization(e), None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
|
@ -1260,7 +1202,11 @@ impl Context for StreamContext {
|
|||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_context = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
let callout_context = self
|
||||
.callouts
|
||||
.get_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
|
|
@ -1284,9 +1230,21 @@ impl Context for StreamContext {
|
|||
}
|
||||
} else {
|
||||
self.send_server_error(
|
||||
String::from("No response body in inline HTTP request"),
|
||||
ServerError::LogicError(String::from("No response body in inline HTTP request")),
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for StreamContext {
|
||||
type CallContext = StreamCallContext;
|
||||
|
||||
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
|
||||
&self.callouts
|
||||
}
|
||||
|
||||
fn active_http_calls(&self) -> &crate::stats::Gauge {
|
||||
&self.metrics.active_http_calls
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue