mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Integrate Arch-Function-Chat (#449)
This commit is contained in:
parent
f31aa59fac
commit
7d4b261a68
26 changed files with 558 additions and 603 deletions
|
|
@ -13,8 +13,11 @@ pub const MESSAGES_KEY: &str = "messages";
|
|||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"];
|
||||
pub const HEALTHZ_PATH: &str = "/healthz";
|
||||
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
|
||||
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message";
|
||||
pub const X_ARCH_TOOL_CALL: &str = "x-arch-tool-call-message";
|
||||
pub const X_ARCH_FC_MODEL_RESPONSE: &str = "x-arch-fc-model-response";
|
||||
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function";
|
||||
pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
pub const TRACE_PARENT_HEADER: &str = "traceparent";
|
||||
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
|
||||
|
|
|
|||
|
|
@ -50,8 +50,7 @@ pub trait Client: Context {
|
|||
) -> Result<u32, ClientError> {
|
||||
debug!(
|
||||
"dispatching http call with args={:?} context={:?}",
|
||||
call_args,
|
||||
call_context
|
||||
call_args, call_context
|
||||
);
|
||||
|
||||
match self.dispatch_http_call(
|
||||
|
|
|
|||
|
|
@ -101,9 +101,7 @@ impl RatelimitMap {
|
|||
) -> Result<(), Error> {
|
||||
debug!(
|
||||
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
|
||||
provider,
|
||||
selector,
|
||||
tokens_used
|
||||
provider, selector, tokens_used
|
||||
);
|
||||
|
||||
let provider_limits = match self.datastore.get(&provider) {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use log::{debug};
|
||||
use log::debug;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn token_count(model_name: &str, text: &str) -> Result<usize, String> {
|
||||
|
|
|
|||
|
|
@ -428,7 +428,7 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
|
||||
if self.request_body_sent_time.is_none() {
|
||||
debug!("on_http_response_body: request body not sent, no doing any processing in llm filter");
|
||||
debug!("on_http_response_body: request body not sent, not doing any processing in llm filter");
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,11 @@ use common::{
|
|||
self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
},
|
||||
consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER,
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
|
||||
MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE,
|
||||
TRACE_PARENT_HEADER, USER_ROLE,
|
||||
TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE,
|
||||
X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL,
|
||||
},
|
||||
errors::ServerError,
|
||||
http::{CallArgs, Client},
|
||||
|
|
@ -125,8 +126,8 @@ impl HttpContext for StreamContext {
|
|||
|
||||
self.arch_state = match deserialized_body.metadata {
|
||||
Some(ref metadata) => {
|
||||
if metadata.contains_key(ARCH_STATE_HEADER) {
|
||||
let arch_state_str = metadata[ARCH_STATE_HEADER].clone();
|
||||
if metadata.contains_key(X_ARCH_STATE_HEADER) {
|
||||
let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone();
|
||||
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
|
||||
Some(arch_state)
|
||||
} else {
|
||||
|
|
@ -336,10 +337,10 @@ impl HttpContext for StreamContext {
|
|||
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
self.arch_fc_response.clone(),
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
self.tool_calls.to_owned(),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
self.tool_call_response.clone(),
|
||||
|
|
@ -381,17 +382,39 @@ impl HttpContext for StreamContext {
|
|||
*metadata = Value::Object(serde_json::Map::new());
|
||||
}
|
||||
|
||||
let fc_messages = vec![
|
||||
self.generate_toll_call_message(),
|
||||
self.generate_api_response_message(),
|
||||
];
|
||||
let tool_call_message = self.generate_tool_call_message();
|
||||
let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_TOOL_CALL.to_string(),
|
||||
serde_json::Value::String(tool_call_message_str),
|
||||
);
|
||||
|
||||
let api_response_message = self.generate_api_response_message();
|
||||
let api_response_message_str =
|
||||
serde_json::to_string(&api_response_message).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_API_RESPONSE.to_string(),
|
||||
serde_json::Value::String(api_response_message_str),
|
||||
);
|
||||
|
||||
let fc_messages = vec![tool_call_message, api_response_message];
|
||||
|
||||
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
|
||||
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
|
||||
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
X_ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::Value::String(arch_state_str),
|
||||
);
|
||||
|
||||
if let Some(arch_fc_response) = self.arch_fc_response.as_ref() {
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_FC_MODEL_RESPONSE.to_string(),
|
||||
serde_json::Value::String(
|
||||
serde_json::to_string(arch_fc_response).unwrap(),
|
||||
),
|
||||
);
|
||||
}
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
info!("archgw <= developer: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ use common::consts::{
|
|||
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
|
||||
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
|
||||
X_ARCH_FC_MODEL_RESPONSE,
|
||||
};
|
||||
use common::errors::ServerError;
|
||||
use common::http::{CallArgs, Client};
|
||||
|
|
@ -64,10 +65,10 @@ pub struct StreamContext {
|
|||
pub time_to_first_token: Option<u128>,
|
||||
pub traceparent: Option<String>,
|
||||
pub _tracing: Rc<Option<Tracing>>,
|
||||
pub arch_fc_response: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
context_id: u32,
|
||||
metrics: Rc<Metrics>,
|
||||
|
|
@ -98,6 +99,7 @@ impl StreamContext {
|
|||
_tracing: tracing,
|
||||
start_upstream_llm_request_time: 0,
|
||||
time_to_first_token: None,
|
||||
arch_fc_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -142,15 +144,16 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// intent was matched if we see function_latency in metadata
|
||||
let intent_matched = model_server_response
|
||||
let intent_matched = check_intent_matched(&model_server_response);
|
||||
info!("intent matched: {}", intent_matched);
|
||||
|
||||
self.arch_fc_response = model_server_response
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.get("function_latency"))
|
||||
.is_some();
|
||||
.and_then(|metadata| metadata.get(X_ARCH_FC_MODEL_RESPONSE))
|
||||
.cloned();
|
||||
|
||||
if !intent_matched {
|
||||
info!("intent not matched");
|
||||
// check if we have a default prompt target
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
|
|
@ -278,9 +281,9 @@ impl StreamContext {
|
|||
let direct_response_str = if self.streaming_response {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
self.arch_fc_response.clone(),
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
|
|
@ -293,7 +296,7 @@ impl StreamContext {
|
|||
.clone(),
|
||||
),
|
||||
None,
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
|
@ -623,13 +626,24 @@ impl StreamContext {
|
|||
messages
|
||||
}
|
||||
|
||||
pub fn generate_toll_call_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
pub fn generate_tool_call_message(&mut self) -> Message {
|
||||
if self.arch_fc_response.is_none() {
|
||||
info!("arch_fc_response is none, generating tool call message");
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
}
|
||||
} else {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: self.arch_fc_response.as_ref().cloned(),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -761,6 +775,23 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool {
|
||||
let content = model_server_response
|
||||
.choices.first()
|
||||
.and_then(|choice| choice.message.content.as_ref());
|
||||
|
||||
let content_has_value = content.is_some() && !content.unwrap().is_empty();
|
||||
|
||||
let tool_calls = model_server_response
|
||||
.choices.first()
|
||||
.and_then(|choice| choice.message.tool_calls.as_ref());
|
||||
|
||||
// intent was matched if content has some value or tool_calls is empty
|
||||
|
||||
|
||||
content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty())
|
||||
}
|
||||
|
||||
impl Client for StreamContext {
|
||||
type CallContext = StreamCallContext;
|
||||
|
||||
|
|
@ -772,3 +803,77 @@ impl Client for StreamContext {
|
|||
&self.metrics.active_http_calls
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall};
|
||||
|
||||
use crate::stream_context::check_intent_matched;
|
||||
|
||||
#[test]
|
||||
fn test_intent_matched() {
|
||||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("".to_string()),
|
||||
tool_calls: Some(vec![]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
index: None,
|
||||
}],
|
||||
usage: None,
|
||||
model: "arch-fc".to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
assert!(!check_intent_matched(&model_server_response));
|
||||
|
||||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("hello".to_string()),
|
||||
tool_calls: Some(vec![]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
index: None,
|
||||
}],
|
||||
usage: None,
|
||||
model: "arch-fc".to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
assert!(check_intent_matched(&model_server_response));
|
||||
|
||||
let model_server_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
content: Some("".to_string()),
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: "1".to_string(),
|
||||
function: common::api::open_ai::FunctionCallDetail {
|
||||
name: "test".to_string(),
|
||||
arguments: None,
|
||||
},
|
||||
tool_type: common::api::open_ai::ToolType::Function,
|
||||
}]),
|
||||
role: "assistant".to_string(),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
index: None,
|
||||
}],
|
||||
usage: None,
|
||||
model: "arch-fc".to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
assert!(check_intent_matched(&model_server_response));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -380,6 +380,7 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
|
|
@ -453,6 +454,7 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -493,19 +495,9 @@ fn prompt_gateway_request_no_intent_match() {
|
|||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: Some(HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)])),
|
||||
},
|
||||
}]),
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
|
|
@ -523,7 +515,7 @@ fn prompt_gateway_request_no_intent_match() {
|
|||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("intent not matched"))
|
||||
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
|
||||
.expect_log(
|
||||
Some(LogLevel::Info),
|
||||
Some("no default prompt target found, forwarding request to upstream llm"),
|
||||
|
|
@ -651,17 +643,7 @@ fn prompt_gateway_request_no_intent_match_default_target() {
|
|||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: Some(HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)])),
|
||||
},
|
||||
}]),
|
||||
tool_calls: None,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
|
|
@ -679,7 +661,7 @@ fn prompt_gateway_request_no_intent_match_default_target() {
|
|||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("intent not matched"))
|
||||
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
|
||||
.expect_log(
|
||||
Some(LogLevel::Info),
|
||||
Some("default prompt target found, forwarding request to default prompt target"),
|
||||
|
|
|
|||
|
|
@ -22,12 +22,12 @@ llm_providers:
|
|||
provider_interface: openai
|
||||
model: llama-3.2-3b-preview
|
||||
base_url: https://api.groq.com
|
||||
default: true
|
||||
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider_interface: openai
|
||||
model: gpt-4o
|
||||
default: true
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ async def weather(req: WeatherRequest, res: Response):
|
|||
|
||||
|
||||
class DefaultTargetRequest(BaseModel):
|
||||
messages: list
|
||||
messages: list = []
|
||||
|
||||
|
||||
@app.post("/default_target")
|
||||
|
|
@ -86,12 +86,9 @@ async def default_target(req: DefaultTargetRequest, res: Response):
|
|||
"role": "assistant",
|
||||
"content": "I can help you with weather forecast",
|
||||
},
|
||||
"finish_reason": "completed",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"model": "api_server",
|
||||
"usage": {"completion_tokens": 0},
|
||||
}
|
||||
logger.info(f"sending response: {json.dumps(resp)}")
|
||||
return resp
|
||||
|
|
|
|||
4
demos/shared/chatbot_ui/.vscode/launch.json
vendored
4
demos/shared/chatbot_ui/.vscode/launch.json
vendored
|
|
@ -15,7 +15,7 @@
|
|||
"LLM": "1",
|
||||
"CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1",
|
||||
"STREAMING": "True",
|
||||
"ARCH_CONFIG": "../../weather_forecast/arch_config.yaml"
|
||||
"ARCH_CONFIG": "../../samples_python/weather_forecast/arch_config.yaml"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
|
@ -29,7 +29,7 @@
|
|||
"LLM": "1",
|
||||
"CHAT_COMPLETION_ENDPOINT": "http://localhost:12000/v1",
|
||||
"STREAMING": "True",
|
||||
"ARCH_CONFIG": "../../llm_routing/arch_config.yaml"
|
||||
"ARCH_CONFIG": "../../samples_python/weather_forecast/arch_config.yaml"
|
||||
}
|
||||
},
|
||||
]
|
||||
|
|
|
|||
|
|
@ -120,8 +120,11 @@ def process_stream_chunk(chunk, history):
|
|||
|
||||
if delta.content:
|
||||
# append content to the last history item
|
||||
history[-1]["content"] = history[-1].get("content", "") + delta.content
|
||||
if history[-1]["model"] != "Arch-Function-Chat":
|
||||
history[-1]["content"] = history[-1].get("content", "") + delta.content
|
||||
# yield content if it is from assistant
|
||||
if history[-1]["model"] == "Arch-Function":
|
||||
return None
|
||||
if history[-1]["role"] == "assistant":
|
||||
return delta.content
|
||||
|
||||
|
|
|
|||
|
|
@ -88,6 +88,22 @@ def chat(
|
|||
|
||||
yield "", conversation, history, debug_output, model_selector
|
||||
|
||||
# update assistant response to have correct format
|
||||
# arch-fc 1.1 expects following format:
|
||||
# {
|
||||
# "response": "<assistant response>",
|
||||
# }
|
||||
# and this entire block needs to be encoded in ```json\n{json_encoded_content}\n```
|
||||
|
||||
if not history[-1]["model"].startswith("Arch"):
|
||||
assistant_response = {
|
||||
"response": history[-1]["content"],
|
||||
}
|
||||
history[-1]["content"] = "```json\n{}\n```".format(
|
||||
json.dumps(assistant_response)
|
||||
)
|
||||
log.info("history: {}".format(json.dumps(history)))
|
||||
|
||||
|
||||
def main():
|
||||
with gr.Blocks(
|
||||
|
|
|
|||
|
|
@ -30,5 +30,11 @@ llm_providers:
|
|||
model: deepseek-reasoner
|
||||
base_url: https://api.deepseek.com/
|
||||
|
||||
- name: groq
|
||||
access_key: $GROQ_API_KEY
|
||||
provider_interface: openai
|
||||
model: llama-3.1-8b-instant
|
||||
base_url: https://api.groq.com
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
|
|
|
|||
|
|
@ -5,8 +5,6 @@ from src.core.guardrails import get_guardrail_handler
|
|||
from src.core.function_calling import (
|
||||
ArchAgentConfig,
|
||||
ArchAgentHandler,
|
||||
ArchIntentConfig,
|
||||
ArchIntentHandler,
|
||||
ArchFunctionConfig,
|
||||
ArchFunctionHandler,
|
||||
)
|
||||
|
|
@ -17,7 +15,10 @@ logger = get_model_server_logger()
|
|||
|
||||
|
||||
# Define the client
|
||||
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
|
||||
# ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
|
||||
# use temporary endpoint until we deprecate archfc-v1.0 from archfc.katanemo.dev
|
||||
# and officially release archfc-v1.1 on archfc.katanemo.dev
|
||||
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "http://34.72.123.163:8000/v1")
|
||||
ARCH_API_KEY = "EMPTY"
|
||||
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
|
||||
ARCH_AGENT_CLIENT = ARCH_CLIENT
|
||||
|
|
@ -30,9 +31,6 @@ ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard"
|
|||
|
||||
# Define model handlers
|
||||
handler_map = {
|
||||
"Arch-Intent": ArchIntentHandler(
|
||||
ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
|
||||
),
|
||||
"Arch-Function": ArchFunctionHandler(
|
||||
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
|
||||
),
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import copy
|
|||
import json
|
||||
import random
|
||||
import builtins
|
||||
import textwrap
|
||||
import src.commons.utils as utils
|
||||
|
||||
from openai import OpenAI
|
||||
|
|
@ -22,179 +21,25 @@ from src.core.utils.model_utils import (
|
|||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
class ArchIntentConfig:
|
||||
TASK_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
|
||||
"""
|
||||
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
|
||||
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
"""
|
||||
).strip()
|
||||
|
||||
FORMAT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
|
||||
- First line must read 'Yes' or 'No'.
|
||||
- If yes, a second line must include a comma-separated list of tool indexes.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
EXTRA_INSTRUCTION = "Are there any tools can help?"
|
||||
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.01,
|
||||
"max_tokens": 1,
|
||||
"stop_token_ids": [151645],
|
||||
}
|
||||
|
||||
|
||||
class ArchIntentHandler(ArchBaseHandler):
|
||||
def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
|
||||
"""
|
||||
Initializes the intent handler.
|
||||
|
||||
Args:
|
||||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
config (ArchIntentConfig): The configuration for Arch-Intent.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
client,
|
||||
model_name,
|
||||
config.TASK_PROMPT,
|
||||
config.TOOL_PROMPT_TEMPLATE,
|
||||
config.FORMAT_PROMPT,
|
||||
config.GENERATION_PARAMS,
|
||||
)
|
||||
|
||||
self.extra_instruction = config.EXTRA_INSTRUCTION
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Converts a list of tools into a JSON-like format with indexed keys.
|
||||
|
||||
Args:
|
||||
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
|
||||
|
||||
Returns:
|
||||
str: A string representation of converted tools.
|
||||
"""
|
||||
|
||||
converted = [
|
||||
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
|
||||
]
|
||||
return "\n".join(converted)
|
||||
|
||||
def detect_intent(self, content: str) -> bool:
|
||||
"""
|
||||
Detect if any intent match with prompts
|
||||
|
||||
Args:
|
||||
content: str: Model response that contains intent detection results
|
||||
|
||||
Returns:
|
||||
bool: A boolean value to indicate if any intent match with prompts or not
|
||||
"""
|
||||
if hasattr(content.choices[0].message, "content"):
|
||||
return content.choices[0].message.content == "Yes"
|
||||
else:
|
||||
return False
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
"""
|
||||
Generates a chat completion for a given request.
|
||||
|
||||
Args:
|
||||
req (ChatMessage): A chat message request object.
|
||||
|
||||
Returns:
|
||||
ChatCompletionResponse: The model's response to the chat request.
|
||||
|
||||
Note:
|
||||
Currently only support vllm inference
|
||||
"""
|
||||
logger.info("[Arch-Intent] - ChatCompletion")
|
||||
|
||||
# In the case that no tools are available, simply return `No` to avoid making a call
|
||||
if len(req.tools) == 0:
|
||||
model_response = Message(content="No", tool_calls=[])
|
||||
logger.info("No tools found, return `No` as the model response.")
|
||||
else:
|
||||
messages = self._process_messages(
|
||||
req.messages, req.tools, self.extra_instruction
|
||||
)
|
||||
|
||||
logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}")
|
||||
|
||||
model_response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
stream=False,
|
||||
extra_body=self.generation_params,
|
||||
)
|
||||
|
||||
logger.info(f"[response]: {json.dumps(model_response.model_dump())}")
|
||||
|
||||
model_response = Message(
|
||||
content=model_response.choices[0].message.content, tool_calls=[]
|
||||
)
|
||||
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_response)], model=self.model_name
|
||||
)
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
|
||||
# =============================================================================================================
|
||||
# ==============================================================================================================================================
|
||||
|
||||
|
||||
class ArchFunctionConfig:
|
||||
TASK_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant.
|
||||
TASK_PROMPT = (
|
||||
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed."
|
||||
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>"
|
||||
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
|
||||
)
|
||||
|
||||
Today's date: {}
|
||||
""".format(
|
||||
utils.get_today_date()
|
||||
)
|
||||
).strip()
|
||||
|
||||
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
|
||||
"""
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
"""
|
||||
).strip()
|
||||
|
||||
FORMAT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
"""
|
||||
).strip()
|
||||
FORMAT_PROMPT = (
|
||||
"\n\nBased on your analysis, provide your response in one of the following JSON formats:"
|
||||
'\n1. If no functions are needed:\n```json\n{"response": "Your response text here"}\n```'
|
||||
'\n2. If functions are needed but some required parameters are missing:\n```json\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```'
|
||||
'\n3. If functions are needed and all required parameters are available:\n```json\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```'
|
||||
)
|
||||
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.6,
|
||||
"temperature": 0.1,
|
||||
"top_p": 1.0,
|
||||
"top_k": 10,
|
||||
"max_tokens": 1024,
|
||||
|
|
@ -203,34 +48,9 @@ class ArchFunctionConfig:
|
|||
"top_logprobs": 10,
|
||||
}
|
||||
|
||||
PREFILL_CONFIG = {
|
||||
"prefill_params": {
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
},
|
||||
"prefill_prefix": [
|
||||
"May",
|
||||
"Could",
|
||||
"Sure",
|
||||
"Definitely",
|
||||
"Certainly",
|
||||
"Of course",
|
||||
"Can",
|
||||
],
|
||||
}
|
||||
|
||||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
||||
|
||||
class ArchAgentConfig(ArchFunctionConfig):
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.01,
|
||||
"stop_token_ids": [151645],
|
||||
"logprobs": True,
|
||||
"top_logprobs": 10,
|
||||
}
|
||||
|
||||
|
||||
class ArchFunctionHandler(ArchBaseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -251,13 +71,17 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
client,
|
||||
model_name,
|
||||
config.TASK_PROMPT,
|
||||
config.TOOL_PROMPT_TEMPLATE,
|
||||
config.FORMAT_PROMPT,
|
||||
config.GENERATION_PARAMS,
|
||||
)
|
||||
|
||||
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
|
||||
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
|
||||
self.generation_params = self.generation_params | {
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
}
|
||||
|
||||
self.default_prefix = '```json\n{"'
|
||||
self.clarify_prefix = '```json\n{"required_functions":'
|
||||
|
||||
self.hallucination_state = None
|
||||
|
||||
|
|
@ -280,7 +104,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
str: A string representation of converted tools.
|
||||
"""
|
||||
|
||||
converted = [json.dumps(tool) for tool in tools]
|
||||
converted = [json.dumps(tool["function"], ensure_ascii=False) for tool in tools]
|
||||
return "\n".join(converted)
|
||||
|
||||
def _fix_json_string(self, json_str: str) -> str:
|
||||
|
|
@ -328,10 +152,14 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
unmatched_opening = stack.pop()
|
||||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str.replace("'", '"')
|
||||
try:
|
||||
fixed_str = json.loads(fixed_str)
|
||||
except Exception:
|
||||
fixed_str = json.loads(fixed_str.replace("'", '"'))
|
||||
|
||||
def _extract_tool_calls(self, content: str) -> Dict[str, any]:
|
||||
return json.dumps(fixed_str)
|
||||
|
||||
def _parse_model_response(self, content: str) -> Dict[str, any]:
|
||||
"""
|
||||
Extracts tool call information from a given string.
|
||||
|
||||
|
|
@ -340,49 +168,55 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
Returns:
|
||||
Dict: A dictionary of extraction, including:
|
||||
- "result": A list of tool call dictionaries.
|
||||
- "status": A boolean indicating if the extraction was valid.
|
||||
- "message": An error message or exception if extraction failed.
|
||||
- "required_functions": A list of detected intents.
|
||||
- "clarification": Text to collect missing parameters
|
||||
- "tool_calls": A list of tool call dictionaries.
|
||||
- "is_valid": A boolean indicating if the extraction was valid.
|
||||
- "error_message": An error message or exception if parsing failed.
|
||||
"""
|
||||
|
||||
tool_calls, is_valid, error_message = [], True, ""
|
||||
response_dict = {
|
||||
"raw_response": [],
|
||||
"response": [],
|
||||
"required_functions": [],
|
||||
"clarification": "",
|
||||
"tool_calls": [],
|
||||
"is_valid": True,
|
||||
"error_message": "",
|
||||
}
|
||||
|
||||
flag = False
|
||||
for line in content.split("\n"):
|
||||
if not is_valid:
|
||||
break
|
||||
try:
|
||||
if content.startswith("```") and content.endswith("```"):
|
||||
content = content.strip("```").strip()
|
||||
if content.startswith("json"):
|
||||
content = content[4:].strip()
|
||||
|
||||
if "<tool_call>" == line:
|
||||
flag = True
|
||||
elif "</tool_call>" == line:
|
||||
flag = False
|
||||
else:
|
||||
if flag:
|
||||
try:
|
||||
tool_content = json.loads(line)
|
||||
except Exception as e:
|
||||
fixed_content = self._fix_json_string(line)
|
||||
try:
|
||||
tool_content = json.loads(fixed_content)
|
||||
except Exception:
|
||||
is_valid, error_message = False, e
|
||||
break
|
||||
content = self._fix_json_string(content)
|
||||
response_dict["raw_response"] = f"```json\n{content}\n```"
|
||||
|
||||
tool = {
|
||||
model_response = json.loads(content)
|
||||
response_dict["response"] = model_response.get("response", "")
|
||||
response_dict["required_functions"] = model_response.get(
|
||||
"required_functions", []
|
||||
)
|
||||
response_dict["clarification"] = model_response.get("clarification", "")
|
||||
|
||||
for tool_call in model_response.get("tool_calls", []):
|
||||
response_dict["tool_calls"].append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_content["name"],
|
||||
"name": tool_call.get("name", ""),
|
||||
"arguments": tool_call.get("arguments", {}),
|
||||
},
|
||||
}
|
||||
if "arguments" in tool_content:
|
||||
tool["function"]["arguments"] = tool_content["arguments"]
|
||||
)
|
||||
except Exception as e:
|
||||
response_dict["is_valid"] = False
|
||||
response_dict["error_message"] = f"Fail to parse model responses: {e}"
|
||||
|
||||
tool_calls.append(tool)
|
||||
|
||||
flag = False
|
||||
|
||||
return {"result": tool_calls, "status": is_valid, "message": error_message}
|
||||
return response_dict
|
||||
|
||||
def _convert_data_type(self, value: str, target_type: str):
|
||||
# TODO: Add more conversion rules as needed
|
||||
|
|
@ -414,36 +248,37 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
- "message": An error message.
|
||||
"""
|
||||
|
||||
is_valid, invalid_tool_call, error_message = True, None, ""
|
||||
verification_dict = {
|
||||
"is_valid": True,
|
||||
"invalid_tool_call": {},
|
||||
"error_message": "",
|
||||
}
|
||||
|
||||
functions = {}
|
||||
for tool in tools:
|
||||
if tool["type"] == "function":
|
||||
functions[tool["function"]["name"]] = tool["function"]["parameters"]
|
||||
functions[tool["function"]["name"]] = tool["function"]["parameters"]
|
||||
|
||||
for tool_call in tool_calls:
|
||||
if not is_valid:
|
||||
if not verification_dict["is_valid"]:
|
||||
break
|
||||
|
||||
func_name = tool_call["function"]["name"]
|
||||
func_args = tool_call["function"].get("arguments")
|
||||
if not func_args:
|
||||
func_args = {}
|
||||
func_args = tool_call["function"]["arguments"]
|
||||
|
||||
# Check whether the function is available or not
|
||||
if func_name not in functions:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"{func_name} is not defined!"
|
||||
break
|
||||
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict["error_message"] = f"{func_name} is not available!"
|
||||
else:
|
||||
# Check if all the requried parameters can be found in the tool calls
|
||||
for required_param in functions[func_name].get("required", []):
|
||||
if required_param not in func_args:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
break
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
|
|
@ -453,9 +288,11 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
logger.info(func_args)
|
||||
for param_name in func_args:
|
||||
if param_name not in function_properties:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||
break
|
||||
else:
|
||||
param_value = func_args[param_name]
|
||||
|
|
@ -469,22 +306,22 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
param_value, data_type
|
||||
)
|
||||
if not isinstance(param_value, data_type):
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||
break
|
||||
else:
|
||||
error_message = (
|
||||
f"Data type `{target_type}` is not supported."
|
||||
)
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Data type `{target_type}` is not supported."
|
||||
|
||||
return {
|
||||
"status": is_valid,
|
||||
"invalid_tool_call": invalid_tool_call,
|
||||
"message": error_message,
|
||||
}
|
||||
return verification_dict
|
||||
|
||||
def _add_prefill_message(self, messages: List[Dict[str, str]]):
|
||||
def _prefill_message(self, messages: List[Dict[str, str]], prefill_message):
|
||||
"""
|
||||
Update messages and generation params for prompt prefilling
|
||||
|
||||
|
|
@ -494,29 +331,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
Returns:
|
||||
prefill_messages (List[Dict[str, str]]): A list of messages.
|
||||
"""
|
||||
|
||||
return messages + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": random.choice(self.prefill_prefix),
|
||||
}
|
||||
]
|
||||
|
||||
def _engage_parameter_gathering(self, messages: List[Dict[str, str]]):
|
||||
"""
|
||||
Engage parameter gathering for tool calls
|
||||
"""
|
||||
|
||||
# TODO: log enaging parameter gathering
|
||||
prefill_response = self.client.chat.completions.create(
|
||||
messages=self._add_prefill_message(messages),
|
||||
model=self.model_name,
|
||||
extra_body={
|
||||
**self.generation_params,
|
||||
**self.prefill_params,
|
||||
},
|
||||
)
|
||||
return prefill_response
|
||||
return messages + [{"role": "assistant", "content": prefill_message}]
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
|
|
@ -544,7 +359,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
# always enable `stream=True` to collect model responses
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
messages=self._prefill_message(messages, self.default_prefix),
|
||||
model=self.model_name,
|
||||
stream=True,
|
||||
extra_body=self.generation_params,
|
||||
|
|
@ -565,72 +380,114 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
|
||||
if self.hallucination_state.tokens[0] == "<tool_call>":
|
||||
# check if moodel response starts with tool calls, we do it after 5 tokens because we only check the first part of the response.
|
||||
if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None:
|
||||
content = "".join(self.hallucination_state.tokens)
|
||||
if "tool_calls" in content:
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
break
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
if self.hallucination_state.hallucination is True:
|
||||
has_hallucination = True
|
||||
break
|
||||
|
||||
if has_tool_calls:
|
||||
if has_hallucination:
|
||||
# start prompt prefilling if hallcuination is found in tool calls
|
||||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
if has_tool_calls and has_hallucination:
|
||||
# start prompt prefilling if hallcuination is found in tool calls
|
||||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
messages=self._prefill_message(messages, self.clarify_prefix),
|
||||
model=self.model_name,
|
||||
stream=False,
|
||||
extra_body=self.generation_params,
|
||||
)
|
||||
model_response = response.choices[0].message.content
|
||||
else:
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
|
||||
# Extract tool calls from model response
|
||||
extracted = self._extract_tool_calls(model_response)
|
||||
response_dict = self._parse_model_response(model_response)
|
||||
logger.info(f"[arch-fc]: raw model response: {response_dict['raw_response']}")
|
||||
|
||||
if extracted["status"]:
|
||||
# Response with tool calls
|
||||
if len(extracted["result"]):
|
||||
verified = {}
|
||||
if use_agent_orchestrator:
|
||||
# skip tool call verification if using agent orchestrator
|
||||
verified = {"status": True, "message": ""}
|
||||
else:
|
||||
verified = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=extracted["result"]
|
||||
)
|
||||
|
||||
if verified["status"]:
|
||||
logger.info(
|
||||
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
|
||||
)
|
||||
model_response = Message(content="", tool_calls=extracted["result"])
|
||||
else:
|
||||
logger.error(f"Invalid tool call - {verified['message']}")
|
||||
# Response without tool calls
|
||||
# General model response
|
||||
if response_dict.get("response", ""):
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
# Parameter gathering
|
||||
elif response_dict.get("required_functions", []):
|
||||
if not use_agent_orchestrator:
|
||||
clarification = response_dict.get("clarification", "")
|
||||
model_message = Message(content=clarification, tool_calls=[])
|
||||
else:
|
||||
model_response = Message(content=model_response, tool_calls=[])
|
||||
# Response with tool calls but contain errors
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
# Function Calling
|
||||
elif response_dict.get("tool_calls", []):
|
||||
if response_dict["is_valid"]:
|
||||
if not use_agent_orchestrator:
|
||||
verification_dict = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=response_dict["tool_calls"]
|
||||
)
|
||||
|
||||
if verification_dict["is_valid"]:
|
||||
logger.info(
|
||||
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in response_dict['tool_calls']])}"
|
||||
)
|
||||
model_message = Message(
|
||||
content="", tool_calls=response_dict["tool_calls"]
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Invalid tool call - {verification_dict['error_message']}"
|
||||
)
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
else:
|
||||
# skip tool call verification if using agent orchestrator
|
||||
logger.info(
|
||||
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in response_dict['tool_calls']])}"
|
||||
)
|
||||
model_message = Message(
|
||||
content="", tool_calls=response_dict["tool_calls"]
|
||||
)
|
||||
|
||||
else:
|
||||
# Response with tool calls but invalid
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
# Response not in the desired format
|
||||
else:
|
||||
logger.error(f"Tool call extraction error - {extracted['message']}")
|
||||
logger.error(f"Invalid model response - {model_response}")
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_response)], model=self.model_name
|
||||
choices=[Choice(message=model_message)],
|
||||
model=self.model_name,
|
||||
metadata={"x-arch-fc-model-response": response_dict["raw_response"]},
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")
|
||||
logger.info(
|
||||
f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump(exclude_none=True))}"
|
||||
)
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
|
||||
# ==============================================================================================================================================
|
||||
|
||||
|
||||
class ArchAgentConfig(ArchFunctionConfig):
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.01,
|
||||
"top_p": 1.0,
|
||||
"top_k": 10,
|
||||
"max_tokens": 1024,
|
||||
"stop_token_ids": [151645],
|
||||
"logprobs": True,
|
||||
"top_logprobs": 10,
|
||||
}
|
||||
|
||||
|
||||
class ArchAgentHandler(ArchFunctionHandler):
|
||||
def __init__(self, client: OpenAI, model_name: str, config: ArchAgentConfig):
|
||||
super().__init__(client, model_name, config)
|
||||
|
|
@ -657,7 +514,7 @@ class ArchAgentHandler(ArchFunctionHandler):
|
|||
):
|
||||
tool_copy = copy.deepcopy(tool)
|
||||
del tool_copy["function"]["parameters"]
|
||||
converted.append(json.dumps(tool_copy))
|
||||
converted.append(json.dumps(tool_copy["function"], ensure_ascii=False))
|
||||
else:
|
||||
converted.append(json.dumps(tool))
|
||||
converted.append(json.dumps(tool["function"], ensure_ascii=False))
|
||||
return "\n".join(converted)
|
||||
|
|
|
|||
|
|
@ -13,16 +13,15 @@ from src.commons.utils import get_model_server_logger
|
|||
logger = get_model_server_logger()
|
||||
|
||||
# constants
|
||||
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
|
||||
FUNC_NAME_START_PATTERN = ('{"name":"', "{'name':'")
|
||||
FUNC_NAME_END_TOKEN = ('",', "',")
|
||||
TOOL_CALL_TOKEN = "<tool_call>"
|
||||
END_TOOL_CALL_TOKEN = "</tool_call>"
|
||||
END_TOOL_CALL_TOKEN = "}}"
|
||||
|
||||
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
|
||||
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
|
||||
PARAMETER_NAME_START_PATTERN = (',"', ",'")
|
||||
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'", '":"', "':'")
|
||||
PARAMETER_NAME_START_PATTERN = ('","', "','")
|
||||
PARAMETER_VALUE_START_PATTERN = ('":', "':")
|
||||
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
|
||||
PARAMETER_VALUE_END_TOKEN = ('",', '"}')
|
||||
|
||||
BRACKETS = {"(": ")", "{": "}", "[": "]"}
|
||||
|
||||
|
|
@ -37,16 +36,9 @@ class MaskToken(Enum):
|
|||
|
||||
|
||||
HALLUCINATION_THRESHOLD_DICT = {
|
||||
MaskToken.TOOL_CALL.value: {
|
||||
"entropy": 0.35,
|
||||
"varentropy": 1.7,
|
||||
"probability": 0.8,
|
||||
},
|
||||
MaskToken.PARAMETER_VALUE.value: {
|
||||
"entropy": 0.28,
|
||||
"varentropy": 1.4,
|
||||
"probability": 0.8,
|
||||
},
|
||||
"entropy": 0.0001,
|
||||
"varentropy": 0.0001,
|
||||
"probability": 0.8,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -160,6 +152,7 @@ class HallucinationState:
|
|||
self._process_function(function)
|
||||
self.open_bracket = False
|
||||
self.bracket = None
|
||||
self.function_name = ""
|
||||
self.check_parameter_name = {}
|
||||
self.HALLUCINATION_THRESHOLD_DICT = HALLUCINATION_THRESHOLD_DICT
|
||||
|
||||
|
|
@ -208,22 +201,20 @@ class HallucinationState:
|
|||
r = next(self.response_iterator)
|
||||
if hasattr(r.choices[0].delta, "content"):
|
||||
token_content = r.choices[0].delta.content
|
||||
if token_content:
|
||||
if token_content != "":
|
||||
try:
|
||||
logprobs = [
|
||||
p.logprob
|
||||
for p in r.choices[0].logprobs.content[0].top_logprobs
|
||||
]
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error extracting logprobs from response: {e}"
|
||||
)
|
||||
if token_content == END_TOOL_CALL_TOKEN:
|
||||
self._reset_parameters()
|
||||
else:
|
||||
self.append_and_check_token_hallucination(
|
||||
token_content, logprobs
|
||||
)
|
||||
except Exception as e:
|
||||
self.append_and_check_token_hallucination(
|
||||
token_content, [None]
|
||||
)
|
||||
|
||||
return token_content
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
|
@ -234,12 +225,12 @@ class HallucinationState:
|
|||
Detects hallucinations based on the token type and log probabilities.
|
||||
"""
|
||||
content = "".join(self.tokens).replace(" ", "")
|
||||
if self.tokens[-1] == TOOL_CALL_TOKEN:
|
||||
self.mask.append(MaskToken.TOOL_CALL)
|
||||
self._check_logprob()
|
||||
|
||||
# Function name extraction logic
|
||||
# If the state is function name and the token is not an end token, add to the mask
|
||||
if content.endswith(END_TOOL_CALL_TOKEN):
|
||||
self._reset_parameters()
|
||||
|
||||
if self.state == "function_name":
|
||||
if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
|
||||
self.mask.append(MaskToken.FUNCTION_NAME)
|
||||
|
|
@ -359,7 +350,7 @@ class HallucinationState:
|
|||
if check_threshold(
|
||||
entropy,
|
||||
varentropy,
|
||||
self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value],
|
||||
self.HALLUCINATION_THRESHOLD_DICT,
|
||||
):
|
||||
self.hallucination = True
|
||||
self.error_message = f"token '{self.tokens[-1]}' is uncertain. Generated response:\n{''.join(self.tokens)}"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import src.commons.utils as utils
|
||||
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -56,7 +57,6 @@ class ArchBaseHandler:
|
|||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt_template: str,
|
||||
format_prompt: str,
|
||||
generation_params: Dict,
|
||||
):
|
||||
|
|
@ -67,7 +67,6 @@ class ArchBaseHandler:
|
|||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt (str): A prompt to describe tools.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
"""
|
||||
|
|
@ -75,7 +74,6 @@ class ArchBaseHandler:
|
|||
self.model_name = model_name
|
||||
|
||||
self.task_prompt = task_prompt
|
||||
self.tool_prompt_template = tool_prompt_template
|
||||
self.format_prompt = format_prompt
|
||||
|
||||
self.generation_params = generation_params
|
||||
|
|
@ -105,13 +103,11 @@ class ArchBaseHandler:
|
|||
str: A formatted system prompt.
|
||||
"""
|
||||
|
||||
tool_text = self._convert_tools(tools)
|
||||
today_date = utils.get_today_date()
|
||||
tools = self._convert_tools(tools)
|
||||
|
||||
system_prompt = (
|
||||
self.task_prompt
|
||||
+ "\n\n"
|
||||
+ self.tool_prompt_template.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
self.task_prompt.format(today_date=today_date, tools=tools)
|
||||
+ self.format_prompt
|
||||
)
|
||||
|
||||
|
|
@ -146,7 +142,7 @@ class ArchBaseHandler:
|
|||
{"role": "system", "content": self._format_system_prompt(tools)}
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
for idx, message in enumerate(messages):
|
||||
role, content, tool_calls = (
|
||||
message.role,
|
||||
message.content,
|
||||
|
|
@ -162,9 +158,24 @@ class ArchBaseHandler:
|
|||
if metadata.get("optimize_context_window", "false").lower() == "true":
|
||||
content = f"<tool_response>\n\n</tool_response>"
|
||||
else:
|
||||
content = (
|
||||
f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
|
||||
# sample response below
|
||||
# "content": "<tool_response>\n{'name': 'get_stock_price', 'result': '$196.66'}\n</tool_response>"
|
||||
# msg[idx-1] contains tool call = '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}'
|
||||
tool_call_msg = messages[idx - 1].content
|
||||
if tool_call_msg.startswith("```") and tool_call_msg.endswith(
|
||||
"```"
|
||||
):
|
||||
tool_call_msg = tool_call_msg.strip("```").strip()
|
||||
if tool_call_msg.startswith("json"):
|
||||
tool_call_msg = tool_call_msg[4:].strip()
|
||||
func_name = json.loads(tool_call_msg)["tool_calls"][0].get(
|
||||
"name", "no_name"
|
||||
)
|
||||
tool_response = {
|
||||
"name": func_name,
|
||||
"result": content,
|
||||
}
|
||||
content = f"<tool_response>\n{json.dumps(tool_response)}\n</tool_response>"
|
||||
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
|
|
|
|||
|
|
@ -71,67 +71,58 @@ async def models():
|
|||
@app.post("/function_calling")
|
||||
async def function_calling(req: ChatMessage, res: Response):
|
||||
logger.info("[Endpoint: /function_calling]")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
|
||||
|
||||
final_response: ChatCompletionResponse = None
|
||||
error_messages = None
|
||||
|
||||
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
|
||||
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
|
||||
|
||||
try:
|
||||
intent_detected = False
|
||||
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
|
||||
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
|
||||
if not use_agent_orchestrator:
|
||||
intent_start_time = time.perf_counter()
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
intent_latency = time.perf_counter() - intent_start_time
|
||||
intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)
|
||||
handler_name = "Arch-Agent" if use_agent_orchestrator else "Arch-Function"
|
||||
model_handler: ArchFunctionHandler = handler_map[handler_name]
|
||||
|
||||
if use_agent_orchestrator or intent_detected:
|
||||
# TODO: measure agreement between intent detection and function calling
|
||||
try:
|
||||
function_start_time = time.perf_counter()
|
||||
handler_name = (
|
||||
"Arch-Agent" if use_agent_orchestrator else "Arch-Function"
|
||||
start_time = time.perf_counter()
|
||||
final_response = await model_handler.chat_completion(req)
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
if not final_response.metadata:
|
||||
final_response.metadata = {}
|
||||
|
||||
# Parameter gathering for detected intents
|
||||
if final_response.choices[0].message.content:
|
||||
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
|
||||
# Function Calling
|
||||
elif final_response.choices[0].message.tool_calls:
|
||||
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
|
||||
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["hallucination"] = str(
|
||||
model_handler.hallucination_state.hallucination
|
||||
)
|
||||
function_calling_handler: ArchFunctionHandler = handler_map[
|
||||
handler_name
|
||||
]
|
||||
final_response = await function_calling_handler.chat_completion(req)
|
||||
function_latency = time.perf_counter() - function_start_time
|
||||
|
||||
final_response.metadata = {
|
||||
"function_latency": str(round(function_latency * 1000, 3)),
|
||||
}
|
||||
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["intent_latency"] = str(
|
||||
round(intent_latency * 1000, 3)
|
||||
)
|
||||
final_response.metadata["hallucination"] = str(
|
||||
function_calling_handler.hallucination_state.hallucination
|
||||
)
|
||||
except ValueError as e:
|
||||
res.statuscode = 503
|
||||
error_messages = (
|
||||
f"[{handler_name}] - Error in tool call extraction: {e}"
|
||||
)
|
||||
except StopIteration as e:
|
||||
res.statuscode = 500
|
||||
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
|
||||
except Exception as e:
|
||||
res.status_code = 500
|
||||
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
|
||||
raise
|
||||
# No intent detected
|
||||
else:
|
||||
# no intent matched
|
||||
intent_response.metadata = {
|
||||
"intent_latency": str(round(intent_latency * 1000, 3)),
|
||||
}
|
||||
final_response = intent_response
|
||||
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
|
||||
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
|
||||
|
||||
final_response.metadata["hallucination"] = str(
|
||||
model_handler.hallucination_state.hallucination
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
res.statuscode = 503
|
||||
error_messages = f"[{handler_name}] - Error in tool call extraction: {e}"
|
||||
raise
|
||||
except StopIteration as e:
|
||||
res.statuscode = 500
|
||||
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
|
||||
raise
|
||||
except Exception as e:
|
||||
res.status_code = 500
|
||||
error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
|
||||
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
|
||||
raise
|
||||
|
||||
if error_messages is not None:
|
||||
|
|
@ -144,7 +135,7 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
@app.post("/guardrails")
|
||||
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
|
||||
logger.info("[Endpoint: /guardrails] - Gateway")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
|
||||
|
||||
final_response: GuardResponse = None
|
||||
error_messages = None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
|
||||
import time
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.utils.model_utils import ChatMessage, Message
|
||||
|
||||
|
|
@ -37,26 +37,9 @@ get_weather_api = {
|
|||
# get_data class return request, intent, hallucination, parameter_gathering
|
||||
|
||||
|
||||
def get_hallucination_data_complex():
|
||||
def get_hallucination_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
||||
message2 = Message(
|
||||
role="assistant", content="Can you specify the unit you want the weather in?"
|
||||
)
|
||||
message3 = Message(role="user", content="In celcius please!")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1, message2, message3], tools=tools)
|
||||
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_medium():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in?")
|
||||
message1 = Message(role="user", content="How is the weather in Seattle in days?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
|
@ -65,26 +48,10 @@ def get_hallucination_data_medium():
|
|||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# first token will not be tool call
|
||||
return req, True, True, True
|
||||
return req, False, True
|
||||
|
||||
|
||||
def get_complete_data_2():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(
|
||||
role="user",
|
||||
content="what is the weather forecast for seattle in the next 10 days?",
|
||||
)
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
|
||||
|
||||
def get_complete_data():
|
||||
def get_success_tool_call_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle in 7 days?")
|
||||
|
||||
|
|
@ -94,7 +61,7 @@ def get_complete_data():
|
|||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
return req, True, False
|
||||
|
||||
|
||||
def get_irrelevant_data():
|
||||
|
|
@ -107,7 +74,7 @@ def get_irrelevant_data():
|
|||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
return req, False, False
|
||||
|
||||
|
||||
def get_greeting_data():
|
||||
|
|
@ -120,38 +87,29 @@ def get_greeting_data():
|
|||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
return req, False, False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"get_data_func",
|
||||
[
|
||||
get_hallucination_data_complex,
|
||||
get_complete_data,
|
||||
get_hallucination_data,
|
||||
get_greeting_data,
|
||||
get_irrelevant_data,
|
||||
get_complete_data_2,
|
||||
get_success_tool_call_data,
|
||||
],
|
||||
)
|
||||
async def test_function_calling(get_data_func):
|
||||
req, intent, hallucination, parameter_gathering = get_data_func()
|
||||
req, intent, hallucination = get_data_func()
|
||||
handler_name = "Arch-Function"
|
||||
use_agent_orchestrator = False
|
||||
model_handler: ArchFunctionHandler = handler_map[handler_name]
|
||||
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
start_time = time.perf_counter()
|
||||
final_response = await model_handler.chat_completion(req)
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
|
||||
assert intent == (len(final_response.choices[0].message.tool_calls) >= 1)
|
||||
|
||||
if intent:
|
||||
function_calling_response = await handler_map["Arch-Function"].chat_completion(
|
||||
req
|
||||
)
|
||||
assert (
|
||||
handler_map["Arch-Function"].hallucination_state.hallucination
|
||||
== hallucination
|
||||
)
|
||||
response_txt = function_calling_response.choices[0].message.content
|
||||
|
||||
if parameter_gathering:
|
||||
prefill_prefix = handler_map["Arch-Function"].prefill_prefix
|
||||
assert any(
|
||||
response_txt.startswith(prefix) for prefix in prefill_prefix
|
||||
), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
|
||||
assert hallucination == model_handler.hallucination_state.hallucination
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from src.commons.globals import handler_map
|
||||
from src.core.function_calling import Message
|
||||
from src.core.function_calling import ArchFunctionHandler, Message
|
||||
|
||||
|
||||
test_input_history = [
|
||||
|
|
@ -7,34 +7,19 @@ test_input_history = [
|
|||
{
|
||||
"role": "assistant",
|
||||
"model": "Arch-Function",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3394",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
"content": '```json\n{"tool_calls": [{"name": "get_current_weather", "arguments": {"days": 5, "location": "Chicago, Illinois"}}]}\n```',
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"model": "Arch-Function",
|
||||
"content": '{"location":"Chicago%2C%20Illinois","temperature":[{"date":"2025-04-14","temperature":{"min":53,"max":65},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432817+00:00"},{"date":"2025-04-15","temperature":{"min":85,"max":97},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432830+00:00"},{"date":"2025-04-16","temperature":{"min":62,"max":78},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432835+00:00"},{"date":"2025-04-17","temperature":{"min":89,"max":101},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432839+00:00"},{"date":"2025-04-18","temperature":{"min":86,"max":104},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432843+00:00"}],"units":"Farenheit"}',
|
||||
},
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_3394"},
|
||||
{"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
|
||||
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_5306",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
"model": "gpt-4o-2024-08-06",
|
||||
"content": '{"response": "Based on the forecast data you provided, here is the weather for the next 5 days in Chicago:\\n\\n- **April 14, 2025**: The temperature will range between 53\\u00b0F and 65\\u00b0F. \\n- **April 15, 2025**: The temperature will range between 85\\u00b0F and 97\\u00b0F.\\n- **April 16, 2025**: The temperature will range between 62\\u00b0F and 78\\u00b0F.\\n- **April 17, 2025**: The temperature will range between 89\\u00b0F and 101\\u00b0F.\\n- **April 18, 2025**: The temperature will range between 86\\u00b0F and 104\\u00b0F.\\n\\nPlease note that the temperatures are given in Fahrenheit."}',
|
||||
},
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_5306"},
|
||||
{"role": "user", "content": "what about seattle?"},
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -44,7 +29,8 @@ def test_update_fc_history():
|
|||
for h in test_input_history:
|
||||
message_history.append(Message(**h))
|
||||
|
||||
updated_history = handler_map["Arch-Function"]._process_messages(message_history)
|
||||
assert len(updated_history) == 7
|
||||
handler: ArchFunctionHandler = handler_map["Arch-Function"]
|
||||
updated_history = handler._process_messages(message_history)
|
||||
assert len(updated_history) == 5
|
||||
# ensure that tool role does not exist anymore
|
||||
assert all([h["role"] != "tool" for h in updated_history])
|
||||
|
|
|
|||
|
|
@ -47,14 +47,11 @@ TEST_CASE_FIXTURES = {
|
|||
"tool_call_id": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_6009",
|
||||
"id": "call_2925",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": {
|
||||
"location": "Seattle, WA",
|
||||
"days": "2",
|
||||
},
|
||||
"arguments": {"location": "Seattle", "days": "2"},
|
||||
},
|
||||
}
|
||||
],
|
||||
|
|
@ -63,7 +60,11 @@ TEST_CASE_FIXTURES = {
|
|||
}
|
||||
],
|
||||
"model": "Arch-Function",
|
||||
"metadata": {"intent_latency": "455.092", "function_latency": "312.744"},
|
||||
"metadata": {
|
||||
"x-arch-fc-model-response": '{"tool_calls": [{"name": "get_current_weather", "arguments": {"location": "Seattle", "days": "2"}}]}',
|
||||
"function_latency": "361.841",
|
||||
"intent_latency": "361.841",
|
||||
},
|
||||
},
|
||||
"api_server_response": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
import pytest
|
||||
import requests
|
||||
from deepdiff import DeepDiff
|
||||
import re
|
||||
|
||||
from common import (
|
||||
PROMPT_GATEWAY_ENDPOINT,
|
||||
|
|
@ -11,6 +12,15 @@ from common import (
|
|||
)
|
||||
|
||||
|
||||
def cleanup_tool_call(tool_call):
|
||||
pattern = r"```json\n(.*?)\n```"
|
||||
match = re.search(pattern, tool_call, re.DOTALL)
|
||||
if match:
|
||||
tool_call = match.group(1)
|
||||
|
||||
return tool_call.strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_prompt_gateway(stream):
|
||||
expected_tool_call = {
|
||||
|
|
@ -42,9 +52,14 @@ def test_prompt_gateway(stream):
|
|||
assert "role" in choices[0]["delta"]
|
||||
role = choices[0]["delta"]["role"]
|
||||
assert role == "assistant"
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
print(f"choices: {choices}")
|
||||
tool_call_str = choices[0].get("delta", {}).get("content", "")
|
||||
print("tool_call_str: ", tool_call_str)
|
||||
cleaned_tool_call_str = cleanup_tool_call(tool_call_str)
|
||||
print("cleaned_tool_call_str: ", cleaned_tool_call_str)
|
||||
tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
tool_call = tool_calls[0]
|
||||
location = tool_call["arguments"]["location"]
|
||||
assert expected_tool_call["arguments"]["location"] in location.lower()
|
||||
del expected_tool_call["arguments"]["location"]
|
||||
|
|
@ -62,7 +77,7 @@ def test_prompt_gateway(stream):
|
|||
|
||||
# third..end chunk is summarization (role = assistant)
|
||||
response_json = json.loads(chunks[2])
|
||||
assert response_json.get("model").startswith("llama-3.2-3b-preview")
|
||||
assert response_json.get("model").startswith("gpt-4o")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
|
|
@ -71,18 +86,24 @@ def test_prompt_gateway(stream):
|
|||
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("llama-3.2-3b-preview")
|
||||
assert response_json.get("model").startswith("gpt-4o")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["message"]
|
||||
assert choices[0]["message"]["role"] == "assistant"
|
||||
# now verify arch_messages (tool call and api response) that are sent as response metadata
|
||||
arch_messages = get_arch_messages(response_json)
|
||||
print("arch_messages: ", json.dumps(arch_messages))
|
||||
assert len(arch_messages) == 2
|
||||
tool_calls_message = arch_messages[0]
|
||||
tool_calls = tool_calls_message.get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
print("tool_calls_message: ", tool_calls_message)
|
||||
tool_calls = tool_calls_message.get("content", [])
|
||||
cleaned_tool_call_str = cleanup_tool_call(tool_calls)
|
||||
cleaned_tool_call_json = json.loads(cleaned_tool_call_str)
|
||||
print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json))
|
||||
tool_calls_list = cleaned_tool_call_json.get("tool_calls", [])
|
||||
assert len(tool_calls_list) > 0
|
||||
tool_call = tool_calls_list[0]
|
||||
location = tool_call["arguments"]["location"]
|
||||
assert expected_tool_call["arguments"]["location"] in location.lower()
|
||||
del expected_tool_call["arguments"]["location"]
|
||||
|
|
@ -231,7 +252,7 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
|
||||
# third..end chunk is summarization (role = assistant)
|
||||
response_json = json.loads(chunks[2])
|
||||
assert response_json.get("model").startswith("llama-3.2-3b-preview")
|
||||
assert response_json.get("model").startswith("gpt-4o")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
|
|
@ -240,7 +261,7 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("llama-3.2-3b-preview")
|
||||
assert response_json.get("model").startswith("gpt-4o")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["message"]
|
||||
|
|
@ -262,7 +283,7 @@ def test_prompt_gateway_default_target(stream):
|
|||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello, what can you do for me?",
|
||||
"content": "hello",
|
||||
},
|
||||
],
|
||||
"stream": stream,
|
||||
|
|
@ -273,17 +294,20 @@ def test_prompt_gateway_default_target(stream):
|
|||
chunks = get_data_chunks(response, n=3)
|
||||
assert len(chunks) > 0
|
||||
response_json = json.loads(chunks[0])
|
||||
print("response_json chunks[0]: ", response_json)
|
||||
assert response_json.get("model").startswith("api_server")
|
||||
assert len(response_json.get("choices", [])) > 0
|
||||
assert response_json.get("choices")[0]["delta"]["role"] == "assistant"
|
||||
|
||||
response_json = json.loads(chunks[1])
|
||||
print("response_json chunks[1]: ", response_json)
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
content = choices[0]["delta"]["content"]
|
||||
assert content == "I can help you with weather forecast"
|
||||
else:
|
||||
response_json = response.json()
|
||||
print("response_json: ", response_json)
|
||||
assert response_json.get("model").startswith("api_server")
|
||||
assert len(response_json.get("choices")) > 0
|
||||
assert response_json.get("choices")[0]["message"]["role"] == "assistant"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,9 @@ import requests
|
|||
import logging
|
||||
import yaml
|
||||
|
||||
pytestmark = pytest.mark.skip(
|
||||
reason="Skipping entire test file as hallucination is not enabled for archfc 1.1 yet"
|
||||
)
|
||||
|
||||
MODEL_SERVER_ENDPOINT = os.getenv(
|
||||
"MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@ import yaml
|
|||
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
pytestmark = pytest.mark.skip(
|
||||
reason="Skipping entire test file as this these tests are heavily dependent on model output"
|
||||
)
|
||||
|
||||
MODEL_SERVER_ENDPOINT = os.getenv(
|
||||
"MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ Content-Type: application/json
|
|||
{
|
||||
"role": "assistant",
|
||||
"content": "It seems I'm missing some information. Could you provide the following details days ?",
|
||||
"model": "Arch-Function-1.5b"
|
||||
"model": "Arch-Function"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -91,7 +91,7 @@ Content-Type: application/json
|
|||
{
|
||||
"role": "assistant",
|
||||
"content": "It seems I'm missing some information. Could you provide the following details days ?",
|
||||
"model": "Arch-Function-1.5b"
|
||||
"model": "Arch-Function"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue