direct reps

This commit is contained in:
Adil Hafeez 2024-10-27 19:21:31 -07:00
parent a0ffabb0fd
commit 85aee8c9a8
2 changed files with 70 additions and 8 deletions

View file

@ -3,8 +3,8 @@ use crate::hallucination::extract_messages_for_hallucination;
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
ParameterType, ToolCall, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
@ -636,9 +636,10 @@ impl StreamContext {
};
arch_fc_response.choices[0]
.message
.tool_calls
.clone_into(&mut self.tool_calls);
.message
.tool_calls
.clone_into(&mut self.tool_calls);
if self.tool_calls.as_ref().unwrap().len() > 1 {
warn!(
"multiple tool calls not supported yet, tool_calls count found: {}",
@ -653,10 +654,44 @@ impl StreamContext {
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
let direct_response_str = if self.streaming_response {
let chunks = [
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
),
ChatCompletionStreamResponse::new(
Some(
arch_fc_response.choices[0]
.message
.content
.as_ref()
.unwrap()
.clone(),
),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
),
];
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
} else {
body_str
};
if self.streaming_response {}
return self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(body_str.as_bytes()),
Some(direct_response_str.as_bytes()),
);
}
@ -989,8 +1024,7 @@ impl StreamContext {
response_str.push_str("\n\n");
}
response_str
}
else {
} else {
String::from_utf8(body).unwrap()
};

View file

@ -29,6 +29,34 @@ def test_prompt_gateway(stream):
assert response_json.get("model").startswith("gpt-4o-mini")
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway_arch_direct_response(stream):
body = {
"messages": [
{
"role": "user",
"content": "how is the weather",
}
],
"stream": stream,
}
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
assert response.status_code == 200
if stream:
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
response_json = json.loads(chunks[0])
# if its streaming we return tool call and api call in first two chunks
assert response_json.get("model").startswith("Arch")
else:
response_json = response.json()
assert response_json.get("model").startswith("Arch")
choices = response_json.get("choices", [])
assert len(choices) > 0
message = choices[0]["message"]["content"]
assert "Could you provide the following details days" not in message
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway_param_gathering(stream):
body = {