mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Send back developer error correctly (#195)
This commit is contained in:
parent
32eeddade3
commit
cb74e8ffe2
3 changed files with 38 additions and 13 deletions
|
|
@ -336,10 +336,9 @@ impl HttpContext for StreamContext {
|
|||
let chat_completions_response: ChatCompletionsResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(de) => de,
|
||||
Err(e) => {
|
||||
Err(_e) => {
|
||||
debug!("invalid response: {}", String::from_utf8_lossy(&body));
|
||||
self.send_server_error(ServerError::Deserialization(e), None);
|
||||
return Action::Pause;
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
|||
use acap::cos;
|
||||
use common::common_types::open_ai::{
|
||||
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
|
||||
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType,
|
||||
StreamOptions, ToolCall, ToolCallState, ToolType,
|
||||
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, ToolCall,
|
||||
ToolCallState, ToolType,
|
||||
};
|
||||
use common::common_types::{
|
||||
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
||||
|
|
@ -33,6 +33,7 @@ use sha2::{Digest, Sha256};
|
|||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
|
||||
use common::stats::IncrementingMetric;
|
||||
|
|
@ -70,11 +71,12 @@ pub enum ServerError {
|
|||
Serialization(serde_json::Error),
|
||||
#[error("{0}")]
|
||||
LogicError(String),
|
||||
#[error("upstream error response authority={authority}, path={path}, status={status}")]
|
||||
#[error("upstream application error host={host}, path={path}, status={status}, body={body}")]
|
||||
Upstream {
|
||||
authority: String,
|
||||
host: String,
|
||||
path: String,
|
||||
status: String,
|
||||
body: String,
|
||||
},
|
||||
#[error("jailbreak detected: {0}")]
|
||||
Jailbreak(String),
|
||||
|
|
@ -149,7 +151,6 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
|
||||
debug!("server error occurred: {}", error);
|
||||
self.send_http_response(
|
||||
override_status_code
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
|
|
@ -164,6 +165,7 @@ impl StreamContext {
|
|||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(e) => {
|
||||
debug!("error deserializing embedding response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -234,6 +236,10 @@ impl StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
debug!(
|
||||
"error serializing zero shot classification request: {}",
|
||||
error
|
||||
);
|
||||
return self.send_server_error(ServerError::Serialization(error), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -263,6 +269,7 @@ impl StreamContext {
|
|||
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
debug!("error dispatching zero shot classification request: {}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
}
|
||||
|
|
@ -276,6 +283,7 @@ impl StreamContext {
|
|||
match serde_json::from_slice(&body) {
|
||||
Ok(hallucination_response) => hallucination_response,
|
||||
Err(e) => {
|
||||
debug!("error deserializing hallucination response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -339,6 +347,10 @@ impl StreamContext {
|
|||
match serde_json::from_slice(&body) {
|
||||
Ok(zeroshot_response) => zeroshot_response,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"error deserializing zero shot classification response: {}",
|
||||
e
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -450,6 +462,7 @@ impl StreamContext {
|
|||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
debug!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
|
|
@ -465,6 +478,7 @@ impl StreamContext {
|
|||
let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
|
||||
Some(prompt_target) => prompt_target.clone(),
|
||||
None => {
|
||||
debug!("prompt target not found: {}", prompt_target_name);
|
||||
return self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Prompt target not found: {prompt_target_name}"
|
||||
|
|
@ -537,6 +551,7 @@ impl StreamContext {
|
|||
msg_body
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("error serializing arch_fc request body: {}", e);
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -569,6 +584,7 @@ impl StreamContext {
|
|||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
debug!("error dispatching arch_fc request: {}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
|
|
@ -580,6 +596,7 @@ impl StreamContext {
|
|||
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||
Ok(arch_fc_response) => arch_fc_response,
|
||||
Err(e) => {
|
||||
debug!("error deserializing arch_fc response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -693,6 +710,10 @@ impl StreamContext {
|
|||
match serde_json::to_string(&hallucination_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
debug!(
|
||||
"error serializing hallucination classification request: {}",
|
||||
error
|
||||
);
|
||||
return self.send_server_error(ServerError::Serialization(error), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -789,13 +810,15 @@ impl StreamContext {
|
|||
) {
|
||||
if let Some(http_status) = self.get_http_call_response_header(":status") {
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
debug!("upstream error response: {}", http_status);
|
||||
return self.send_server_error(
|
||||
ServerError::Upstream {
|
||||
authority: callout_context.upstream_cluster.unwrap(),
|
||||
host: callout_context.upstream_cluster.unwrap(),
|
||||
path: callout_context.upstream_cluster_path.unwrap(),
|
||||
status: http_status,
|
||||
status: http_status.clone(),
|
||||
body: String::from_utf8(body).unwrap(),
|
||||
},
|
||||
None,
|
||||
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -893,6 +916,7 @@ impl StreamContext {
|
|||
.prompt_guards
|
||||
.jailbreak_on_exception_message()
|
||||
.unwrap_or("refrain from discussing jailbreaking.");
|
||||
debug!("jailbreak detected: {}", msg);
|
||||
return self.send_server_error(
|
||||
ServerError::Jailbreak(String::from(msg)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
|
|
@ -916,6 +940,7 @@ impl StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
debug!("error serializing get embeddings request: {}", error);
|
||||
return self.send_server_error(ServerError::Deserialization(error), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -952,6 +977,7 @@ impl StreamContext {
|
|||
};
|
||||
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
debug!("error dispatching get embeddings request: {}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
}
|
||||
|
|
@ -985,6 +1011,7 @@ impl StreamContext {
|
|||
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(chat_completions_resp) => chat_completions_resp,
|
||||
Err(e) => {
|
||||
debug!("error deserializing default target response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -1071,7 +1098,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: ChatCompletionsRequest =
|
||||
let deserialized_body: ChatCompletionsRequest =
|
||||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
|
|
|
|||
|
|
@ -487,7 +487,6 @@ fn bad_request_to_open_ai_chat_completions() {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(incomplete_chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||
None,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue