mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fix merge
This commit is contained in:
commit
c62f763070
15 changed files with 265 additions and 120 deletions
|
|
@ -13,8 +13,11 @@ pub const MESSAGES_KEY: &str = "messages";
|
|||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const HEALTHZ_PATH: &str = "/healthz";
|
||||
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
|
||||
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message";
|
||||
pub const X_ARCH_TOOL_CALL: &str = "x-arch-tool-call-message";
|
||||
pub const X_ARCH_FC_MODEL_RESPONSE: &str = "x-arch-fc-model-response";
|
||||
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function";
|
||||
pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
pub const TRACE_PARENT_HEADER: &str = "traceparent";
|
||||
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
|
||||
|
|
|
|||
|
|
@ -411,7 +411,7 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
|
||||
if self.request_body_sent_time.is_none() {
|
||||
debug!("on_http_response_body: request body not sent, no doing any processing in llm filter");
|
||||
debug!("on_http_response_body: request body not sent, not doing any processing in llm filter");
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,11 @@ use common::{
|
|||
self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
},
|
||||
consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER,
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
|
||||
MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE,
|
||||
TRACE_PARENT_HEADER, USER_ROLE,
|
||||
TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE,
|
||||
X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL,
|
||||
},
|
||||
errors::ServerError,
|
||||
http::{CallArgs, Client},
|
||||
|
|
@ -125,8 +126,8 @@ impl HttpContext for StreamContext {
|
|||
|
||||
self.arch_state = match deserialized_body.metadata {
|
||||
Some(ref metadata) => {
|
||||
if metadata.contains_key(ARCH_STATE_HEADER) {
|
||||
let arch_state_str = metadata[ARCH_STATE_HEADER].clone();
|
||||
if metadata.contains_key(X_ARCH_STATE_HEADER) {
|
||||
let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone();
|
||||
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
|
||||
Some(arch_state)
|
||||
} else {
|
||||
|
|
@ -336,10 +337,10 @@ impl HttpContext for StreamContext {
|
|||
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
self.arch_fc_response.clone(),
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
self.tool_calls.to_owned(),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
self.tool_call_response.clone(),
|
||||
|
|
@ -381,17 +382,39 @@ impl HttpContext for StreamContext {
|
|||
*metadata = Value::Object(serde_json::Map::new());
|
||||
}
|
||||
|
||||
let fc_messages = vec![
|
||||
self.generate_toll_call_message(),
|
||||
self.generate_api_response_message(),
|
||||
];
|
||||
let tool_call_message = self.generate_toll_call_message();
|
||||
let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_TOOL_CALL.to_string(),
|
||||
serde_json::Value::String(tool_call_message_str),
|
||||
);
|
||||
|
||||
let api_response_message = self.generate_api_response_message();
|
||||
let api_response_message_str =
|
||||
serde_json::to_string(&api_response_message).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_API_RESPONSE.to_string(),
|
||||
serde_json::Value::String(api_response_message_str),
|
||||
);
|
||||
|
||||
let fc_messages = vec![tool_call_message, api_response_message];
|
||||
|
||||
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
|
||||
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
|
||||
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
X_ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::Value::String(arch_state_str),
|
||||
);
|
||||
|
||||
if let Some(arch_fc_response) = self.arch_fc_response.as_ref() {
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_FC_MODEL_RESPONSE.to_string(),
|
||||
serde_json::Value::String(
|
||||
serde_json::to_string(arch_fc_response).unwrap(),
|
||||
),
|
||||
);
|
||||
}
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
info!("archgw <= developer: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ use common::consts::{
|
|||
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
|
||||
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
|
||||
X_ARCH_FC_MODEL_RESPONSE,
|
||||
};
|
||||
use common::errors::ServerError;
|
||||
use common::http::{CallArgs, Client};
|
||||
|
|
@ -64,10 +65,10 @@ pub struct StreamContext {
|
|||
pub time_to_first_token: Option<u128>,
|
||||
pub traceparent: Option<String>,
|
||||
pub _tracing: Rc<Option<Tracing>>,
|
||||
pub arch_fc_response: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
context_id: u32,
|
||||
metrics: Rc<Metrics>,
|
||||
|
|
@ -98,6 +99,7 @@ impl StreamContext {
|
|||
_tracing: tracing,
|
||||
start_upstream_llm_request_time: 0,
|
||||
time_to_first_token: None,
|
||||
arch_fc_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -142,15 +144,17 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// intent was matched if we see function_latency in metadata
|
||||
let intent_matched = model_server_response
|
||||
let intent_matched = check_intent_matched(&model_server_response);
|
||||
info!("intent matched: {}", intent_matched);
|
||||
|
||||
self.arch_fc_response = model_server_response
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.get("function_latency"))
|
||||
.is_some();
|
||||
.and_then(|metadata| metadata.get(X_ARCH_FC_MODEL_RESPONSE))
|
||||
.cloned();
|
||||
|
||||
|
||||
if !intent_matched {
|
||||
info!("intent not matched");
|
||||
// check if we have a default prompt target
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
|
|
@ -278,9 +282,9 @@ impl StreamContext {
|
|||
let direct_response_str = if self.streaming_response {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
self.arch_fc_response.clone(),
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
|
|
@ -293,7 +297,7 @@ impl StreamContext {
|
|||
.clone(),
|
||||
),
|
||||
None,
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
|
@ -624,12 +628,23 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
pub fn generate_toll_call_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
if self.arch_fc_response.is_none() {
|
||||
info!("arch_fc_response is none, generating tool call message");
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
}
|
||||
} else {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: self.arch_fc_response.as_ref().cloned(),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -761,6 +776,26 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool {
|
||||
let content = model_server_response
|
||||
.choices
|
||||
.get(0)
|
||||
.and_then(|choice| choice.message.content.as_ref());
|
||||
|
||||
let content_has_value = content.is_some() && !content.unwrap().is_empty();
|
||||
|
||||
let tool_calls = model_server_response
|
||||
.choices
|
||||
.get(0)
|
||||
.and_then(|choice| choice.message.tool_calls.as_ref());
|
||||
|
||||
// intent was matched if content has some value or tool_calls is empty
|
||||
let intent_matched =
|
||||
content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty());
|
||||
|
||||
return intent_matched;
|
||||
}
|
||||
|
||||
impl Client for StreamContext {
|
||||
type CallContext = StreamCallContext;
|
||||
|
||||
|
|
@ -772,3 +807,77 @@ impl Client for StreamContext {
|
|||
&self.metrics.active_http_calls
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall};
|
||||
|
||||
use crate::stream_context::check_intent_matched;
|
||||
|
||||
#[test]
|
||||
fn test_intent_matched() {
|
||||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("".to_string()),
|
||||
tool_calls: Some(vec![]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
index: None,
|
||||
}],
|
||||
usage: None,
|
||||
model: "arch-fc".to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
assert_eq!(check_intent_matched(&model_server_response), false);
|
||||
|
||||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("hello".to_string()),
|
||||
tool_calls: Some(vec![]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
index: None,
|
||||
}],
|
||||
usage: None,
|
||||
model: "arch-fc".to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
assert_eq!(check_intent_matched(&model_server_response), true);
|
||||
|
||||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("".to_string()),
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: "1".to_string(),
|
||||
function: common::api::open_ai::FunctionCallDetail {
|
||||
name: "test".to_string(),
|
||||
arguments: None,
|
||||
},
|
||||
tool_type: common::api::open_ai::ToolType::Function,
|
||||
}]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
index: None,
|
||||
}],
|
||||
usage: None,
|
||||
model: "arch-fc".to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
assert_eq!(check_intent_matched(&model_server_response), true);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -380,6 +380,7 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
|
|
@ -453,6 +454,7 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -493,19 +495,9 @@ fn prompt_gateway_request_no_intent_match() {
|
|||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: Some(HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)])),
|
||||
},
|
||||
}]),
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
|
|
@ -523,7 +515,7 @@ fn prompt_gateway_request_no_intent_match() {
|
|||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("intent not matched"))
|
||||
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
|
||||
.expect_log(
|
||||
Some(LogLevel::Info),
|
||||
Some("no default prompt target found, forwarding request to upstream llm"),
|
||||
|
|
@ -651,17 +643,7 @@ fn prompt_gateway_request_no_intent_match_default_target() {
|
|||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: Some(HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)])),
|
||||
},
|
||||
}]),
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
|
|
@ -679,7 +661,7 @@ fn prompt_gateway_request_no_intent_match_default_target() {
|
|||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("intent not matched"))
|
||||
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
|
||||
.expect_log(
|
||||
Some(LogLevel::Info),
|
||||
Some("default prompt target found, forwarding request to default prompt target"),
|
||||
|
|
|
|||
|
|
@ -120,8 +120,11 @@ def process_stream_chunk(chunk, history):
|
|||
|
||||
if delta.content:
|
||||
# append content to the last history item
|
||||
history[-1]["content"] = history[-1].get("content", "") + delta.content
|
||||
if history[-1]["model"] != "Arch-Function-Chat":
|
||||
history[-1]["content"] = history[-1].get("content", "") + delta.content
|
||||
# yield content if it is from assistant
|
||||
if history[-1]["model"] == "Arch-Function":
|
||||
return None
|
||||
if history[-1]["role"] == "assistant":
|
||||
return delta.content
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ logger = get_model_server_logger()
|
|||
|
||||
|
||||
# Define the client
|
||||
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
|
||||
# ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
|
||||
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "http://35.225.55.128:8000/v1")
|
||||
ARCH_API_KEY = "EMPTY"
|
||||
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
|
||||
ARCH_AGENT_CLIENT = ARCH_CLIENT
|
||||
|
|
|
|||
|
|
@ -27,16 +27,15 @@ logger = utils.get_model_server_logger()
|
|||
class ArchFunctionConfig:
|
||||
TASK_PROMPT = (
|
||||
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed."
|
||||
"\nToday's date: {today_date}"
|
||||
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}\n</tools>"
|
||||
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary.\n\n"
|
||||
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>"
|
||||
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
|
||||
)
|
||||
|
||||
FORMAT_PROMPT = (
|
||||
"Based on your analysis, provide your response in one of the following JSON formats:"
|
||||
'\n1. If no functions are needed:\n```\n{"response": "Your response text here"}\n```'
|
||||
'\n2. If functions are needed but some required parameters are missing:\n```\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```'
|
||||
'\n3. If functions are needed and all required parameters are available:\n```\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```'
|
||||
"\n\nBased on your analysis, provide your response in one of the following JSON formats:"
|
||||
'\n1. If no functions are needed:\n```json\n{"response": "Your response text here"}\n```'
|
||||
'\n2. If functions are needed but some required parameters are missing:\n```json\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```'
|
||||
'\n3. If functions are needed and all required parameters are available:\n```json\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```'
|
||||
)
|
||||
|
||||
GENERATION_PARAMS = {
|
||||
|
|
@ -193,16 +192,21 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
}
|
||||
|
||||
try:
|
||||
if content.startswith("```") and content.endswith("```"):
|
||||
content = content.strip("```").strip()
|
||||
if content.startswith("json"):
|
||||
content = content[4:].strip()
|
||||
|
||||
model_response = json.loads(self._fix_json_string(content))
|
||||
|
||||
response_dict["response"] = model_response.get("response", "")
|
||||
response_dict["required_functions"] = model_response.get(
|
||||
"required_functions", ""
|
||||
"required_functions", []
|
||||
)
|
||||
response_dict["clarification"] = model_response.get("clarification", "")
|
||||
|
||||
for tool_call in model_response.get("tool_calls", []):
|
||||
response_dict["tool_call"].append(
|
||||
response_dict["tool_calls"].append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
|
|
@ -413,8 +417,8 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallucination_state.tokens) > 2 and has_tool_calls is None:
|
||||
content = ''.join(self.hallucination_state.tokens)
|
||||
if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None:
|
||||
content = "".join(self.hallucination_state.tokens)
|
||||
if "tool_calls" in content:
|
||||
has_tool_calls = True
|
||||
else:
|
||||
|
|
@ -448,6 +452,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
# if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
|
||||
# model_response += chunk.choices[0].delta.content
|
||||
|
||||
logger.info(f"[arch-fc]: raw model response: {model_response}")
|
||||
# Extract tool calls from model response
|
||||
response_dict = self._parse_model_resonse(model_response)
|
||||
|
||||
|
|
@ -499,10 +504,15 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
model_message = Message(content="", tool_calls=[])
|
||||
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_message)], model=self.model_name
|
||||
choices=[Choice(message=model_message)],
|
||||
model=self.model_name,
|
||||
metadata={"x-arch-fc-model-response": model_response},
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")
|
||||
logger.info(
|
||||
f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump())}"
|
||||
)
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
|
|
|
|||
|
|
@ -104,10 +104,10 @@ class ArchBaseHandler:
|
|||
"""
|
||||
|
||||
today_date = utils.get_today_date()
|
||||
tool_text = self._convert_tools(tools)
|
||||
tools = self._convert_tools(tools)
|
||||
|
||||
system_prompt = (
|
||||
self.task_prompt.format(today_date=today_date, tool_text=tool_text)
|
||||
self.task_prompt.format(today_date=today_date, tools=tools)
|
||||
+ self.format_prompt
|
||||
)
|
||||
|
||||
|
|
@ -142,7 +142,7 @@ class ArchBaseHandler:
|
|||
{"role": "system", "content": self._format_system_prompt(tools)}
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
for idx, message in enumerate(messages):
|
||||
role, content, tool_calls = (
|
||||
message.role,
|
||||
message.content,
|
||||
|
|
@ -158,9 +158,17 @@ class ArchBaseHandler:
|
|||
if metadata.get("optimize_context_window", "false").lower() == "true":
|
||||
content = f"<tool_response>\n\n</tool_response>"
|
||||
else:
|
||||
content = (
|
||||
f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
|
||||
)
|
||||
# sample response below
|
||||
# "content": "<tool_response>\n{'name': 'get_stock_price', 'result': '$196.66'}\n</tool_response>"
|
||||
# msg[idx-1] contains tool call = '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}'
|
||||
func_name = json.loads(messages[idx - 1].content)["tool_calls"][
|
||||
0
|
||||
].get("name", "no_name")
|
||||
tool_response = {
|
||||
"name": func_name,
|
||||
"result": content,
|
||||
}
|
||||
content = f"<tool_response>\n{json.dumps(tool_response)}\n</tool_response>"
|
||||
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
|
|
|
|||
|
|
@ -87,16 +87,15 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
final_response = await model_handler.chat_completion(req)
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
if not final_response.metadata:
|
||||
final_response.metadata = {}
|
||||
|
||||
# Parameter gathering for detected intents
|
||||
if final_response.choices[0].message.content:
|
||||
final_response.metadata = {
|
||||
"function_latency": str(round(latency * 1000, 3)),
|
||||
}
|
||||
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
|
||||
# Function Calling
|
||||
elif final_response.choices[0].message.tool_calls:
|
||||
final_response.metadata = {
|
||||
"function_latency": str(round(latency * 1000, 3)),
|
||||
}
|
||||
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
|
||||
|
||||
# *********************************************************************************************
|
||||
# TODO: Put the following code back when hallucination check is ready
|
||||
|
|
@ -107,9 +106,7 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
)
|
||||
# No intent detected
|
||||
else:
|
||||
final_response.metadata = {
|
||||
"intent_latency": str(round(latency * 1000, 3)),
|
||||
}
|
||||
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
|
||||
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
|
||||
|
|
|
|||
|
|
@ -123,35 +123,35 @@ def get_greeting_data():
|
|||
return req, False, False, False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"get_data_func",
|
||||
[
|
||||
get_hallucination_data_complex,
|
||||
get_complete_data,
|
||||
get_irrelevant_data,
|
||||
get_complete_data_2,
|
||||
],
|
||||
)
|
||||
async def test_function_calling(get_data_func):
|
||||
req, intent, hallucination, parameter_gathering = get_data_func()
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# "get_data_func",
|
||||
# [
|
||||
# get_hallucination_data_complex,
|
||||
# get_complete_data,
|
||||
# get_irrelevant_data,
|
||||
# get_complete_data_2,
|
||||
# ],
|
||||
# )
|
||||
# async def test_function_calling(get_data_func):
|
||||
# req, intent, hallucination, parameter_gathering = get_data_func()
|
||||
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
# intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
|
||||
assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
|
||||
# assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
|
||||
|
||||
if intent:
|
||||
function_calling_response = await handler_map["Arch-Function"].chat_completion(
|
||||
req
|
||||
)
|
||||
assert (
|
||||
handler_map["Arch-Function"].hallucination_state.hallucination
|
||||
== hallucination
|
||||
)
|
||||
response_txt = function_calling_response.choices[0].message.content
|
||||
# if intent:
|
||||
# function_calling_response = await handler_map["Arch-Function"].chat_completion(
|
||||
# req
|
||||
# )
|
||||
# assert (
|
||||
# handler_map["Arch-Function"].hallucination_state.hallucination
|
||||
# == hallucination
|
||||
# )
|
||||
# response_txt = function_calling_response.choices[0].message.content
|
||||
|
||||
if parameter_gathering:
|
||||
prefill_prefix = handler_map["Arch-Function"].prefill_prefix
|
||||
assert any(
|
||||
response_txt.startswith(prefix) for prefix in prefill_prefix
|
||||
), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
|
||||
# if parameter_gathering:
|
||||
# prefill_prefix = handler_map["Arch-Function"].prefill_prefix
|
||||
# assert any(
|
||||
# response_txt.startswith(prefix) for prefix in prefill_prefix
|
||||
# ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
|
||||
|
|
|
|||
|
|
@ -47,14 +47,11 @@ TEST_CASE_FIXTURES = {
|
|||
"tool_call_id": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_6009",
|
||||
"id": "call_2925",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": {
|
||||
"location": "Seattle, WA",
|
||||
"days": "2",
|
||||
},
|
||||
"arguments": {"location": "Seattle", "days": "2"},
|
||||
},
|
||||
}
|
||||
],
|
||||
|
|
@ -63,7 +60,11 @@ TEST_CASE_FIXTURES = {
|
|||
}
|
||||
],
|
||||
"model": "Arch-Function",
|
||||
"metadata": {"intent_latency": "455.092", "function_latency": "312.744"},
|
||||
"metadata": {
|
||||
"x-arch-fc-model-response": '{"tool_calls": [{"name": "get_current_weather", "arguments": {"location": "Seattle", "days": "2"}}]}',
|
||||
"function_latency": "361.841",
|
||||
"intent_latency": "361.841",
|
||||
},
|
||||
},
|
||||
"api_server_response": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -42,9 +42,11 @@ def test_prompt_gateway(stream):
|
|||
assert "role" in choices[0]["delta"]
|
||||
role = choices[0]["delta"]["role"]
|
||||
assert role == "assistant"
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
print(f"choices: {choices}")
|
||||
tool_call_str = choices[0].get("delta", {}).get("content", "")
|
||||
tool_calls = json.loads(tool_call_str).get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
tool_call = tool_calls[0]
|
||||
location = tool_call["arguments"]["location"]
|
||||
assert expected_tool_call["arguments"]["location"] in location.lower()
|
||||
del expected_tool_call["arguments"]["location"]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,9 @@ import requests
|
|||
import logging
|
||||
import yaml
|
||||
|
||||
pytestmark = pytest.mark.skip(
|
||||
reason="Skipping entire test file as hallucination is not enabled for archfc 1.1 yet"
|
||||
)
|
||||
|
||||
MODEL_SERVER_ENDPOINT = os.getenv(
|
||||
"MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@ import yaml
|
|||
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
pytestmark = pytest.mark.skip(
|
||||
reason="Skipping entire test file as this these tests are heavily dependent on model output"
|
||||
)
|
||||
|
||||
MODEL_SERVER_ENDPOINT = os.getenv(
|
||||
"MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue