add more tests

This commit is contained in:
Adil Hafeez 2024-10-27 15:02:43 -07:00
parent f5ecf733ff
commit e35560623c
7 changed files with 171 additions and 49 deletions

View file

@ -42,6 +42,8 @@ pub mod open_ai {
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsRequest {
#[serde(default)]
@ -242,6 +244,28 @@ pub mod open_ai {
pub metadata: Option<HashMap<String, String>>,
}
// create constructor for ChatCompletionsResponse
impl ChatCompletionsResponse {
pub fn new(message: String) -> Self {
ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
role: ASSISTANT_ROLE.to_string(),
content: Some(message),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
},
index: 0,
finish_reason: "done".to_string(),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
metadata: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub completion_tokens: usize,
@ -254,6 +278,24 @@ pub mod open_ai {
pub choices: Vec<ChunkChoice>,
}
impl ChatCompletionStreamResponse {
pub fn new(response: Option<String>, role: Option<String>) -> Self {
ChatCompletionStreamResponse {
model: Some(ARCH_FC_MODEL_NAME.to_string()),
choices: vec![ChunkChoice {
delta: Delta {
role,
content: response,
tool_calls: None,
model: None,
tool_call_id: None,
},
finish_reason: None,
}],
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ChatCompletionChunkResponseError {
#[error("failed to deserialize")]

View file

@ -27,4 +27,4 @@ pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
pub const ARCH_MODEL_PREFIX: &str = "Arch";
pub const HALLUCINATION_TEMPLATE: &str =
"It seems Im missing some information. Could you provide the following details ";
"It seems I'm missing some information. Could you provide the following details ";

View file

@ -2,9 +2,9 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::hallucination::extract_messages_for_hallucination;
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, ToolCall,
ToolType,
ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, ChunkChoice, Delta, FunctionDefinition, FunctionParameter,
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
@ -303,16 +303,16 @@ impl StreamContext {
body: Vec<u8>,
callout_context: StreamCallContext,
) {
let boyd_str = String::from_utf8(body).expect("could not convert body to string");
debug!("archgw <= hallucination response: {}", boyd_str);
let body_str = String::from_utf8(body).expect("could not convert body to string");
debug!("archgw <= hallucination response: {}", body_str);
let hallucination_response: HallucinationClassificationResponse =
match serde_json::from_str(boyd_str.as_str()) {
match serde_json::from_str(body_str.as_str()) {
Ok(hallucination_response) => hallucination_response,
Err(e) => {
warn!(
"error deserializing hallucination response: {}, body: {}",
e,
boyd_str.as_str()
body_str.as_str()
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
@ -331,34 +331,31 @@ impl StreamContext {
if !keys_with_low_score.is_empty() {
let response =
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
let message = Message {
role: ASSISTANT_ROLE.to_string(),
content: Some(response),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
};
let chat_completion_response = ChatCompletionsResponse {
choices: vec![Choice {
message,
index: 0,
finish_reason: "done".to_string(),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
metadata: None,
};
let response_str = if self.streaming_response {
let chunks = [
ChatCompletionStreamResponse::new(None, Some(ASSISTANT_ROLE.to_string())),
ChatCompletionStreamResponse::new(Some(response), None),
];
trace!("hallucination response: {:?}", chat_completion_response);
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
} else {
let chat_completion_response = ChatCompletionsResponse::new(response);
serde_json::to_string(&chat_completion_response).unwrap()
};
debug!("hallucination response: {:?}", response_str);
// make sure on_http_response_body does not attach tool calls and tool response to the response
self.tool_calls = None;
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(
serde_json::to_string(&chat_completion_response)
.unwrap()
.as_bytes(),
),
Some(response_str.as_bytes()),
);
} else {
// not a hallucination, resume the flow
@ -948,7 +945,7 @@ impl StreamContext {
self.get_embeddings(callout_context);
}
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
pub fn default_target_handler(&self, body: Vec<u8>, mut callout_context: StreamCallContext) {
let prompt_target = self
.prompt_targets
.get(callout_context.prompt_target_name.as_ref().unwrap())
@ -956,8 +953,48 @@ impl StreamContext {
.clone();
// check if the default target should be dispatched to the LLM provider
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
let default_target_response_str = String::from_utf8(body).unwrap();
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or_default()
{
let default_target_response_str = if self.streaming_response {
let chat_completion_response =
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
let chunk_role_message = ChatCompletionStreamResponse {
model: Some(chat_completion_response.model.clone()),
choices: vec![ChunkChoice {
delta: Delta {
role: Some(USER_ROLE.to_string()),
content: None,
tool_calls: None,
model: None,
tool_call_id: None,
},
finish_reason: None,
}],
};
let chat_completion_stream_response = ChatCompletionStreamResponse {
model: Some(chat_completion_response.model),
choices: vec![ChunkChoice {
delta: Delta {
role: None,
content: chat_completion_response.choices[0].message.content.clone(),
tool_calls: None,
model: None,
tool_call_id: None,
},
finish_reason: None,
}],
};
let chunk_role = serde_json::to_string(&chunk_role_message).unwrap();
let chunk_data = serde_json::to_string(&chat_completion_stream_response).unwrap();
format!("data: {}\n\ndata: {}\n\n", chunk_role, chunk_data)
} else {
String::from_utf8(body).unwrap()
};
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
@ -965,20 +1002,20 @@ impl StreamContext {
);
return;
}
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
Ok(chat_completions_resp) => chat_completions_resp,
Err(e) => {
warn!("error deserializing default target response: {}", e);
warn!(
"error deserializing default target response: {}, body str: {}",
e,
String::from_utf8(body).unwrap()
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let api_resp = chat_completions_resp.choices[0]
.message
.content
.as_ref()
.unwrap();
let mut messages = callout_context.request_body.messages;
let mut messages = Vec::new();
// add system prompt
match prompt_target.system_prompt.as_ref() {
None => {}
@ -994,13 +1031,24 @@ impl StreamContext {
}
}
messages.append(&mut callout_context.request_body.messages);
let api_resp = chat_completions_resp.choices[0]
.message
.content
.as_ref()
.unwrap();
let user_message = messages.pop().unwrap();
let message = format!("{}\ncontext: {}", user_message.content.unwrap(), api_resp);
messages.push(Message {
role: USER_ROLE.to_string(),
content: Some(api_resp.clone()),
content: Some(message),
model: None,
tool_calls: None,
tool_call_id: None,
});
let chat_completion_request = ChatCompletionsRequest {
model: self
.chat_completions_request
@ -1014,6 +1062,7 @@ impl StreamContext {
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("archgw => (default target) llm request: {}", json_resp);
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());

View file

@ -66,18 +66,18 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon
class DefaultTargetRequest(BaseModel):
arch_messages: list
messages: list
@app.post("/default_target")
async def default_target(req: DefaultTargetRequest, res: Response):
logger.info(f"Received arch_messages: {req.arch_messages}")
logger.info(f"Received arch_messages: {req.messages}")
resp = {
"choices": [
{
"message": {
"role": "assistant",
"content": "hello world from api server",
"role": "user",
"content": "I can help you with weather forecast or insurance claim details",
},
"finish_reason": "completed",
"index": 0,

View file

@ -82,10 +82,10 @@ prompt_targets:
name: api_server
path: /default_target
system_prompt: |
You are a helpful assistant. Use the information that is provided to you.
You are a helpful assistant! Summarize the user's request and provide a helpful response.
# if it is set to false arch will send response that it received from this prompt target to the user
# if true arch will forward the response to the default LLM
auto_llm_dispatch_on_response: true
auto_llm_dispatch_on_response: false
tracing:
random_sampling: 100

View file

@ -12,6 +12,7 @@ LLM_GATEWAY_ENDPOINT = os.getenv(
def get_data_chunks(stream, n=1):
chunks = []
for chunk in stream.iter_lines():
print(chunk)
if chunk:
chunk = chunk.decode("utf-8")
chunk_data_id = chunk[0:6]

View file

@ -43,7 +43,7 @@ def test_prompt_gateway_param_gathering(stream):
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
assert response.status_code == 200
if stream:
chunks = get_data_chunks(response)
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
response_json = json.loads(chunks[0])
# if its streaming we return tool call and api call in first two chunks
@ -112,3 +112,33 @@ def test_prompt_gateway_param_tool_call(stream):
else:
response_json = response.json()
assert response_json.get("model").startswith("gpt-4o-mini")
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway_default_target(stream):
body = {
"messages": [
{
"role": "user",
"content": "hello, what can you do for me?",
},
],
"stream": stream,
}
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
assert response.status_code == 200
if stream:
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
response_json = json.loads(chunks[0])
assert response_json.get("model").startswith("api_server")
response_json = json.loads(chunks[1])
choices = response_json.get("choices", [])
assert len(choices) > 0
content = choices[0]["delta"]["content"]
assert (
content == "I can help you with weather forecast or insurance claim details"
)
else:
response_json = response.json()
assert response_json.get("model").startswith("api_server")