mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add more tests
This commit is contained in:
parent
f5ecf733ff
commit
e35560623c
7 changed files with 171 additions and 49 deletions
|
|
@ -42,6 +42,8 @@ pub mod open_ai {
|
|||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
|
||||
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
#[serde(default)]
|
||||
|
|
@ -242,6 +244,28 @@ pub mod open_ai {
|
|||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
// create constructor for ChatCompletionsResponse
|
||||
impl ChatCompletionsResponse {
|
||||
pub fn new(message: String) -> Self {
|
||||
ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: Some(message),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
index: 0,
|
||||
finish_reason: "done".to_string(),
|
||||
}],
|
||||
usage: None,
|
||||
model: ARCH_FC_MODEL_NAME.to_string(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub completion_tokens: usize,
|
||||
|
|
@ -254,6 +278,24 @@ pub mod open_ai {
|
|||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
impl ChatCompletionStreamResponse {
|
||||
pub fn new(response: Option<String>, role: Option<String>) -> Self {
|
||||
ChatCompletionStreamResponse {
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role,
|
||||
content: response,
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChatCompletionChunkResponseError {
|
||||
#[error("failed to deserialize")]
|
||||
|
|
|
|||
|
|
@ -27,4 +27,4 @@ pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
|
|||
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
|
||||
pub const ARCH_MODEL_PREFIX: &str = "Arch";
|
||||
pub const HALLUCINATION_TEMPLATE: &str =
|
||||
"It seems I’m missing some information. Could you provide the following details ";
|
||||
"It seems I'm missing some information. Could you provide the following details ";
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
|||
use crate::hallucination::extract_messages_for_hallucination;
|
||||
use acap::cos;
|
||||
use common::common_types::open_ai::{
|
||||
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
|
||||
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, ToolCall,
|
||||
ToolType,
|
||||
ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, ChunkChoice, Delta, FunctionDefinition, FunctionParameter,
|
||||
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
|
||||
};
|
||||
use common::common_types::{
|
||||
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
||||
|
|
@ -303,16 +303,16 @@ impl StreamContext {
|
|||
body: Vec<u8>,
|
||||
callout_context: StreamCallContext,
|
||||
) {
|
||||
let boyd_str = String::from_utf8(body).expect("could not convert body to string");
|
||||
debug!("archgw <= hallucination response: {}", boyd_str);
|
||||
let body_str = String::from_utf8(body).expect("could not convert body to string");
|
||||
debug!("archgw <= hallucination response: {}", body_str);
|
||||
let hallucination_response: HallucinationClassificationResponse =
|
||||
match serde_json::from_str(boyd_str.as_str()) {
|
||||
match serde_json::from_str(body_str.as_str()) {
|
||||
Ok(hallucination_response) => hallucination_response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"error deserializing hallucination response: {}, body: {}",
|
||||
e,
|
||||
boyd_str.as_str()
|
||||
body_str.as_str()
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
|
|
@ -331,34 +331,31 @@ impl StreamContext {
|
|||
if !keys_with_low_score.is_empty() {
|
||||
let response =
|
||||
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
|
||||
let message = Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: Some(response),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
|
||||
let chat_completion_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message,
|
||||
index: 0,
|
||||
finish_reason: "done".to_string(),
|
||||
}],
|
||||
usage: None,
|
||||
model: ARCH_FC_MODEL_NAME.to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
let response_str = if self.streaming_response {
|
||||
let chunks = [
|
||||
ChatCompletionStreamResponse::new(None, Some(ASSISTANT_ROLE.to_string())),
|
||||
ChatCompletionStreamResponse::new(Some(response), None),
|
||||
];
|
||||
|
||||
trace!("hallucination response: {:?}", chat_completion_response);
|
||||
let mut response_str = String::new();
|
||||
for chunk in chunks.iter() {
|
||||
response_str.push_str("data: ");
|
||||
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
||||
response_str.push_str("\n\n");
|
||||
}
|
||||
response_str
|
||||
} else {
|
||||
let chat_completion_response = ChatCompletionsResponse::new(response);
|
||||
serde_json::to_string(&chat_completion_response).unwrap()
|
||||
};
|
||||
debug!("hallucination response: {:?}", response_str);
|
||||
// make sure on_http_response_body does not attach tool calls and tool response to the response
|
||||
self.tool_calls = None;
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(
|
||||
serde_json::to_string(&chat_completion_response)
|
||||
.unwrap()
|
||||
.as_bytes(),
|
||||
),
|
||||
Some(response_str.as_bytes()),
|
||||
);
|
||||
} else {
|
||||
// not a hallucination, resume the flow
|
||||
|
|
@ -948,7 +945,7 @@ impl StreamContext {
|
|||
self.get_embeddings(callout_context);
|
||||
}
|
||||
|
||||
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
pub fn default_target_handler(&self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
|
|
@ -956,8 +953,48 @@ impl StreamContext {
|
|||
.clone();
|
||||
|
||||
// check if the default target should be dispatched to the LLM provider
|
||||
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
|
||||
let default_target_response_str = String::from_utf8(body).unwrap();
|
||||
if !prompt_target
|
||||
.auto_llm_dispatch_on_response
|
||||
.unwrap_or_default()
|
||||
{
|
||||
let default_target_response_str = if self.streaming_response {
|
||||
let chat_completion_response =
|
||||
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
|
||||
|
||||
let chunk_role_message = ChatCompletionStreamResponse {
|
||||
model: Some(chat_completion_response.model.clone()),
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role: Some(USER_ROLE.to_string()),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
|
||||
let chat_completion_stream_response = ChatCompletionStreamResponse {
|
||||
model: Some(chat_completion_response.model),
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role: None,
|
||||
content: chat_completion_response.choices[0].message.content.clone(),
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
let chunk_role = serde_json::to_string(&chunk_role_message).unwrap();
|
||||
let chunk_data = serde_json::to_string(&chat_completion_stream_response).unwrap();
|
||||
format!("data: {}\n\ndata: {}\n\n", chunk_role, chunk_data)
|
||||
} else {
|
||||
String::from_utf8(body).unwrap()
|
||||
};
|
||||
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
|
|
@ -965,20 +1002,20 @@ impl StreamContext {
|
|||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(chat_completions_resp) => chat_completions_resp,
|
||||
Err(e) => {
|
||||
warn!("error deserializing default target response: {}", e);
|
||||
warn!(
|
||||
"error deserializing default target response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8(body).unwrap()
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
let api_resp = chat_completions_resp.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
let mut messages = callout_context.request_body.messages;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
// add system prompt
|
||||
match prompt_target.system_prompt.as_ref() {
|
||||
None => {}
|
||||
|
|
@ -994,13 +1031,24 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
messages.append(&mut callout_context.request_body.messages);
|
||||
|
||||
let api_resp = chat_completions_resp.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
|
||||
let user_message = messages.pop().unwrap();
|
||||
let message = format!("{}\ncontext: {}", user_message.content.unwrap(), api_resp);
|
||||
messages.push(Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(api_resp.clone()),
|
||||
content: Some(message),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: self
|
||||
.chat_completions_request
|
||||
|
|
@ -1014,6 +1062,7 @@ impl StreamContext {
|
|||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
|
||||
debug!("archgw => (default target) llm request: {}", json_resp);
|
||||
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
|
||||
|
|
|
|||
|
|
@ -66,18 +66,18 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon
|
|||
|
||||
|
||||
class DefaultTargetRequest(BaseModel):
|
||||
arch_messages: list
|
||||
messages: list
|
||||
|
||||
|
||||
@app.post("/default_target")
|
||||
async def default_target(req: DefaultTargetRequest, res: Response):
|
||||
logger.info(f"Received arch_messages: {req.arch_messages}")
|
||||
logger.info(f"Received arch_messages: {req.messages}")
|
||||
resp = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "hello world from api server",
|
||||
"role": "user",
|
||||
"content": "I can help you with weather forecast or insurance claim details",
|
||||
},
|
||||
"finish_reason": "completed",
|
||||
"index": 0,
|
||||
|
|
|
|||
|
|
@ -82,10 +82,10 @@ prompt_targets:
|
|||
name: api_server
|
||||
path: /default_target
|
||||
system_prompt: |
|
||||
You are a helpful assistant. Use the information that is provided to you.
|
||||
You are a helpful assistant! Summarize the user's request and provide a helpful response.
|
||||
# if it is set to false arch will send response that it received from this prompt target to the user
|
||||
# if true arch will forward the response to the default LLM
|
||||
auto_llm_dispatch_on_response: true
|
||||
auto_llm_dispatch_on_response: false
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ LLM_GATEWAY_ENDPOINT = os.getenv(
|
|||
def get_data_chunks(stream, n=1):
|
||||
chunks = []
|
||||
for chunk in stream.iter_lines():
|
||||
print(chunk)
|
||||
if chunk:
|
||||
chunk = chunk.decode("utf-8")
|
||||
chunk_data_id = chunk[0:6]
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def test_prompt_gateway_param_gathering(stream):
|
|||
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
|
||||
assert response.status_code == 200
|
||||
if stream:
|
||||
chunks = get_data_chunks(response)
|
||||
chunks = get_data_chunks(response, n=3)
|
||||
assert len(chunks) > 0
|
||||
response_json = json.loads(chunks[0])
|
||||
# if its streaming we return tool call and api call in first two chunks
|
||||
|
|
@ -112,3 +112,33 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("gpt-4o-mini")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_prompt_gateway_default_target(stream):
|
||||
body = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello, what can you do for me?",
|
||||
},
|
||||
],
|
||||
"stream": stream,
|
||||
}
|
||||
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
|
||||
assert response.status_code == 200
|
||||
if stream:
|
||||
chunks = get_data_chunks(response, n=3)
|
||||
assert len(chunks) > 0
|
||||
response_json = json.loads(chunks[0])
|
||||
assert response_json.get("model").startswith("api_server")
|
||||
response_json = json.loads(chunks[1])
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
content = choices[0]["delta"]["content"]
|
||||
assert (
|
||||
content == "I can help you with weather forecast or insurance claim details"
|
||||
)
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("api_server")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue