Integrate Arch-Function-Chat (#449)

This commit is contained in:
Shuguang Chen 2025-04-15 14:39:12 -07:00 committed by GitHub
parent f31aa59fac
commit 7d4b261a68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 558 additions and 603 deletions

View file

@ -13,8 +13,11 @@ pub const MESSAGES_KEY: &str = "messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"];
pub const HEALTHZ_PATH: &str = "/healthz";
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message";
pub const X_ARCH_TOOL_CALL: &str = "x-arch-tool-call-message";
pub const X_ARCH_FC_MODEL_RESPONSE: &str = "x-arch-fc-model-response";
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function";
pub const REQUEST_ID_HEADER: &str = "x-request-id";
pub const TRACE_PARENT_HEADER: &str = "traceparent";
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";

View file

@ -50,8 +50,7 @@ pub trait Client: Context {
) -> Result<u32, ClientError> {
debug!(
"dispatching http call with args={:?} context={:?}",
call_args,
call_context
call_args, call_context
);
match self.dispatch_http_call(

View file

@ -101,9 +101,7 @@ impl RatelimitMap {
) -> Result<(), Error> {
debug!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider,
selector,
tokens_used
provider, selector, tokens_used
);
let provider_limits = match self.datastore.get(&provider) {

View file

@ -1,4 +1,4 @@
use log::{debug};
use log::debug;
#[allow(dead_code)]
pub fn token_count(model_name: &str, text: &str) -> Result<usize, String> {

View file

@ -428,7 +428,7 @@ impl HttpContext for StreamContext {
);
if self.request_body_sent_time.is_none() {
debug!("on_http_response_body: request body not sent, no doing any processing in llm filter");
debug!("on_http_response_body: request body not sent, not doing any processing in llm filter");
return Action::Continue;
}

View file

@ -4,10 +4,11 @@ use common::{
self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
},
consts::{
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER,
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE,
TRACE_PARENT_HEADER, USER_ROLE,
TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE,
X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL,
},
errors::ServerError,
http::{CallArgs, Client},
@ -125,8 +126,8 @@ impl HttpContext for StreamContext {
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
if metadata.contains_key(ARCH_STATE_HEADER) {
let arch_state_str = metadata[ARCH_STATE_HEADER].clone();
if metadata.contains_key(X_ARCH_STATE_HEADER) {
let arch_state_str = metadata[X_ARCH_STATE_HEADER].clone();
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
Some(arch_state)
} else {
@ -336,10 +337,10 @@ impl HttpContext for StreamContext {
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
self.arch_fc_response.clone(),
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_string()),
self.tool_calls.to_owned(),
None,
),
ChatCompletionStreamResponse::new(
self.tool_call_response.clone(),
@ -381,17 +382,39 @@ impl HttpContext for StreamContext {
*metadata = Value::Object(serde_json::Map::new());
}
let fc_messages = vec![
self.generate_toll_call_message(),
self.generate_api_response_message(),
];
let tool_call_message = self.generate_tool_call_message();
let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap();
metadata.as_object_mut().unwrap().insert(
X_ARCH_TOOL_CALL.to_string(),
serde_json::Value::String(tool_call_message_str),
);
let api_response_message = self.generate_api_response_message();
let api_response_message_str =
serde_json::to_string(&api_response_message).unwrap();
metadata.as_object_mut().unwrap().insert(
X_ARCH_API_RESPONSE.to_string(),
serde_json::Value::String(api_response_message_str),
);
let fc_messages = vec![tool_call_message, api_response_message];
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
X_ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
if let Some(arch_fc_response) = self.arch_fc_response.as_ref() {
metadata.as_object_mut().unwrap().insert(
X_ARCH_FC_MODEL_RESPONSE.to_string(),
serde_json::Value::String(
serde_json::to_string(arch_fc_response).unwrap(),
),
);
}
let data_serialized = serde_json::to_string(&data).unwrap();
info!("archgw <= developer: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());

View file

@ -9,6 +9,7 @@ use common::consts::{
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
X_ARCH_FC_MODEL_RESPONSE,
};
use common::errors::ServerError;
use common::http::{CallArgs, Client};
@ -64,10 +65,10 @@ pub struct StreamContext {
pub time_to_first_token: Option<u128>,
pub traceparent: Option<String>,
pub _tracing: Rc<Option<Tracing>>,
pub arch_fc_response: Option<String>,
}
impl StreamContext {
#[allow(clippy::too_many_arguments)]
pub fn new(
context_id: u32,
metrics: Rc<Metrics>,
@ -98,6 +99,7 @@ impl StreamContext {
_tracing: tracing,
start_upstream_llm_request_time: 0,
time_to_first_token: None,
arch_fc_response: None,
}
}
@ -142,15 +144,16 @@ impl StreamContext {
}
};
// intent was matched if we see function_latency in metadata
let intent_matched = model_server_response
let intent_matched = check_intent_matched(&model_server_response);
info!("intent matched: {}", intent_matched);
self.arch_fc_response = model_server_response
.metadata
.as_ref()
.and_then(|metadata| metadata.get("function_latency"))
.is_some();
.and_then(|metadata| metadata.get(X_ARCH_FC_MODEL_RESPONSE))
.cloned();
if !intent_matched {
info!("intent not matched");
// check if we have a default prompt target
if let Some(default_prompt_target) = self
.prompt_targets
@ -278,9 +281,9 @@ impl StreamContext {
let direct_response_str = if self.streaming_response {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
self.arch_fc_response.clone(),
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
Some(ARCH_FC_MODEL_NAME.to_string()),
None,
),
ChatCompletionStreamResponse::new(
@ -293,7 +296,7 @@ impl StreamContext {
.clone(),
),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())),
None,
),
];
@ -623,13 +626,24 @@ impl StreamContext {
messages
}
pub fn generate_toll_call_message(&mut self) -> Message {
Message {
role: ASSISTANT_ROLE.to_string(),
content: None,
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: self.tool_calls.clone(),
tool_call_id: None,
pub fn generate_tool_call_message(&mut self) -> Message {
if self.arch_fc_response.is_none() {
info!("arch_fc_response is none, generating tool call message");
Message {
role: ASSISTANT_ROLE.to_string(),
content: None,
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: self.tool_calls.clone(),
tool_call_id: None,
}
} else {
Message {
role: ASSISTANT_ROLE.to_string(),
content: self.arch_fc_response.as_ref().cloned(),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
}
}
}
@ -761,6 +775,23 @@ impl StreamContext {
}
}
fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool {
let content = model_server_response
.choices.first()
.and_then(|choice| choice.message.content.as_ref());
let content_has_value = content.is_some() && !content.unwrap().is_empty();
let tool_calls = model_server_response
.choices.first()
.and_then(|choice| choice.message.tool_calls.as_ref());
// intent was matched if content has some value or tool_calls is empty
content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty())
}
impl Client for StreamContext {
type CallContext = StreamCallContext;
@ -772,3 +803,77 @@ impl Client for StreamContext {
&self.metrics.active_http_calls
}
}
#[cfg(test)]
mod test {
use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall};
use crate::stream_context::check_intent_matched;
#[test]
fn test_intent_matched() {
let model_server_response = ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
content: Some("".to_string()),
tool_calls: Some(vec![]),
role: "assistant".to_string(),
model: None,
tool_call_id: None,
},
finish_reason: None,
index: None,
}],
usage: None,
model: "arch-fc".to_string(),
metadata: None,
};
assert!(!check_intent_matched(&model_server_response));
let model_server_response = ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
content: Some("hello".to_string()),
tool_calls: Some(vec![]),
role: "assistant".to_string(),
model: None,
tool_call_id: None,
},
finish_reason: None,
index: None,
}],
usage: None,
model: "arch-fc".to_string(),
metadata: None,
};
assert!(check_intent_matched(&model_server_response));
let model_server_response = ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
content: Some("".to_string()),
tool_calls: Some(vec![ToolCall {
id: "1".to_string(),
function: common::api::open_ai::FunctionCallDetail {
name: "test".to_string(),
arguments: None,
},
tool_type: common::api::open_ai::ToolType::Function,
}]),
role: "assistant".to_string(),
model: None,
tool_call_id: None,
},
finish_reason: None,
index: None,
}],
usage: None,
model: "arch-fc".to_string(),
metadata: None,
};
assert!(check_intent_matched(&model_server_response));
}
}

View file

@ -380,6 +380,7 @@ fn prompt_gateway_request_to_llm_gateway() {
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
@ -453,6 +454,7 @@ fn prompt_gateway_request_to_llm_gateway() {
.expect_log(Some(LogLevel::Info), None)
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
@ -493,19 +495,9 @@ fn prompt_gateway_request_no_intent_match() {
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "system".to_string(),
role: "assistant".to_string(),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: Some(HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)])),
},
}]),
tool_calls: None,
model: None,
tool_call_id: None,
},
@ -523,7 +515,7 @@ fn prompt_gateway_request_no_intent_match() {
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), Some("intent not matched"))
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
.expect_log(
Some(LogLevel::Info),
Some("no default prompt target found, forwarding request to upstream llm"),
@ -651,17 +643,7 @@ fn prompt_gateway_request_no_intent_match_default_target() {
message: Message {
role: "system".to_string(),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: Some(HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)])),
},
}]),
tool_calls: None,
model: None,
tool_call_id: None,
},
@ -679,7 +661,7 @@ fn prompt_gateway_request_no_intent_match_default_target() {
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), Some("intent not matched"))
.expect_log(Some(LogLevel::Info), Some("intent matched: false"))
.expect_log(
Some(LogLevel::Info),
Some("default prompt target found, forwarding request to default prompt target"),