fixing some edge cases with calls made to Arch-Function

This commit is contained in:
Salman Paracha 2025-11-20 03:58:13 -08:00
parent 00f95c93f3
commit 9922fe0cb9
3 changed files with 187 additions and 144 deletions

View file

@ -7,17 +7,12 @@ use serde_json::{json, Value};
use std::collections::HashMap;
use thiserror::Error;
use tracing::{info, error};
use async_openai::{Client as OpenAIClient, config::OpenAIConfig};
use async_openai::types::{
CreateChatCompletionRequestArgs, ChatCompletionRequestMessage,
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestAssistantMessage,
};
use futures::StreamExt;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use eventsource_stream::Eventsource;
@ -103,17 +98,17 @@ pub struct ArchFunctionConfig {
impl Default for ArchFunctionConfig {
fn default() -> Self {
Self {
task_prompt: String::from(
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.\
\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>\
\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
),
format_prompt: String::from(
"\n\nBased on your analysis, provide your response in one of the following JSON formats:\
\n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```\
\n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```\
\n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```"
),
task_prompt: concat!(
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.",
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>",
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
).to_string(),
format_prompt: concat!(
"\n\nBased on your analysis, provide your response in one of the following JSON formats:",
"\n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```",
"\n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```",
"\n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```"
).to_string(),
generation_params: GenerationParams::default(),
support_data_types: vec![
"int".to_string(),
@ -170,9 +165,9 @@ impl Default for ArchAgentConfig {
pub struct GenerationParams {
pub temperature: f32,
pub top_p: f32,
pub top_k: i32,
pub top_k: u32,
pub max_tokens: u32,
pub stop_token_ids: Vec<i32>,
pub stop_token_ids: Vec<u32>,
pub logprobs: Option<bool>,
pub top_logprobs: Option<u32>,
}
@ -235,7 +230,8 @@ pub struct ArchFunctionHandler {
pub config: ArchFunctionConfig,
pub default_prefix: String,
pub clarify_prefix: String,
pub openai_client: OpenAIClient<OpenAIConfig>,
pub endpoint_url: String,
pub http_client: reqwest::Client,
}
impl ArchFunctionHandler {
@ -256,16 +252,13 @@ impl ArchFunctionHandler {
.build()
.expect("Failed to create HTTP client");
// Configure OpenAI client to use custom endpoint
let openai_config = OpenAIConfig::new()
.with_api_base(endpoint_url);
Self {
model_name,
config,
default_prefix: "```json\n{\"".to_string(),
clarify_prefix: "```json\n{\"required_functions\":".to_string(),
openai_client: OpenAIClient::with_config(openai_config).with_http_client(http_client),
endpoint_url,
http_client,
}
}
@ -757,17 +750,25 @@ impl ArchFunctionHandler {
}
// Calculate from the end backwards
// Start with message_idx pointing past the end (will be used if no truncation needed)
let mut message_idx = messages.len();
for i in (conversation_idx..messages.len()).rev() {
if let MessageContent::Text(content) = &messages[i].content {
num_tokens += content.len() / 4;
if num_tokens >= max_tokens {
if messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
}
}
message_idx = i;
// Only update message_idx if we haven't hit the token limit yet
// This ensures message_idx points to where truncation should start
if num_tokens < max_tokens {
message_idx = i;
}
}
// Return system message + truncated conversation
@ -792,52 +793,99 @@ impl ArchFunctionHandler {
messages
}
/// Converts internal Message format to async-openai's ChatCompletionRequestMessage format
fn convert_to_openai_messages(&self, messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
let mut openai_messages = Vec::new();
/// Builds a chat-completions request from the handler's model name and
/// configured generation parameters, plus the vLLM-specific fields.
///
/// `messages` becomes the request conversation and `stream` is forwarded as the
/// request's `stream` flag. `continue_final_message` is always set to `true`
/// and `add_generation_prompt` to `false`. `stop_token_ids` is populated only
/// when the configured list is non-empty; every field not set here keeps its
/// `Default` value.
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest {
    let params = &self.config.generation_params;

    // An empty stop list is represented as `None` rather than `Some(vec![])`.
    let stop_token_ids = (!params.stop_token_ids.is_empty())
        .then(|| params.stop_token_ids.clone());

    ChatCompletionsRequest {
        model: self.model_name.clone(),
        messages,
        stream: Some(stream),
        temperature: Some(params.temperature),
        top_p: Some(params.top_p),
        max_tokens: Some(params.max_tokens),
        logprobs: params.logprobs,
        top_logprobs: params.top_logprobs,
        // VLLM-specific parameters
        top_k: Some(params.top_k),
        stop_token_ids,
        continue_final_message: Some(true),
        add_generation_prompt: Some(false),
        ..Default::default()
    }
}
for message in messages {
let content_str = match &message.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(_) => String::new(), // Handle parts if needed
};
/// Makes a streaming request and returns the SSE event stream
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> {
let request_body = serde_json::to_string(&request)
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
let openai_message = match message.role {
Role::System => {
ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: content_str.into(),
name: message.name.clone(),
}
)
},
Role::User | Role::Tool => {
// Convert both user and tool roles to user messages
ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: content_str.into(),
name: message.name.clone(),
}
)
},
Role::Assistant => {
#[allow(deprecated)]
let msg = ChatCompletionRequestAssistantMessage {
content: Some(content_str.into()),
name: message.name.clone(),
tool_calls: None,
refusal: None,
audio: None,
function_call: None,
};
ChatCompletionRequestMessage::Assistant(msg)
},
};
let response = self.http_client
.post(&self.endpoint_url)
.header("Content-Type", "application/json")
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
openai_messages.push(openai_message);
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(
format!("HTTP error {}: {}", status, error_text)
));
}
Ok(openai_messages)
// Parse SSE stream
let stream = response.bytes_stream().eventsource();
let parsed_stream = stream.filter_map(|event_result| async move {
match event_result {
Ok(event) => {
// Skip [DONE] sentinel
if event.data == "[DONE]" {
return None;
}
// Parse JSON
match serde_json::from_str::<Value>(&event.data) {
Ok(json) => Some(Ok(json)),
Err(e) => Some(Err(format!("JSON parse error: {}", e))),
}
}
Err(e) => Some(Err(format!("SSE stream error: {}", e))),
}
});
Ok(Box::pin(parsed_stream))
}
/// Sends `request` to the configured endpoint as one JSON POST and decodes the reply.
///
/// # Errors
/// - `InvalidModelResponse` if the request cannot be serialized, or if the
///   endpoint answers with a non-success status (the message carries the
///   status and the response body).
/// - `HttpError` if sending the request or reading the response body fails.
/// - `JsonParseError` if the body does not deserialize into a
///   `ChatCompletionsResponse`.
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> {
    let payload = serde_json::to_string(&request).map_err(|err| {
        FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", err))
    })?;

    let response = self
        .http_client
        .post(&self.endpoint_url)
        .header("Content-Type", "application/json")
        .body(payload)
        .send()
        .await
        .map_err(FunctionCallingError::HttpError)?;

    let status = response.status();
    if !status.is_success() {
        // Best effort: if the error body cannot be read, report a placeholder instead.
        let detail = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());
        return Err(FunctionCallingError::InvalidModelResponse(format!(
            "HTTP error {}: {}",
            status, detail
        )));
    }

    let raw_body = response
        .text()
        .await
        .map_err(FunctionCallingError::HttpError)?;
    serde_json::from_str(&raw_body).map_err(FunctionCallingError::JsonParseError)
}
pub async fn function_calling_chat(
@ -866,42 +914,24 @@ impl ArchFunctionHandler {
.unwrap_or(false);
let prefilled_messages = self.prefill_message(messages.clone(), &self.default_prefix);
let openai_messages = self.convert_to_openai_messages(&prefilled_messages)?;
let mut request_args = CreateChatCompletionRequestArgs::default();
request_args
.model(&self.model_name)
.messages(openai_messages)
.stream(true)
.temperature(self.config.generation_params.temperature)
.top_p(self.config.generation_params.top_p)
.max_tokens(self.config.generation_params.max_tokens);
if let Some(true) = self.config.generation_params.logprobs {
request_args.logprobs(true);
if let Some(top_logprobs) = self.config.generation_params.top_logprobs {
request_args.top_logprobs(top_logprobs as u8);
}
}
let request_builder = request_args
.build()
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build request: {}", e)))?;
let mut stream = self.openai_client
.chat()
.create_stream(request_builder)
.await
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream creation failed: {}", e)))?;
// Create request with extra_body parameters
let stream_request = self.create_request_with_extra_body(prefilled_messages.clone(), true);
let mut stream = self.make_streaming_request(stream_request).await?;
let mut model_response = String::new();
if use_agent_orchestrator {
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?;
if let Some(choice) = chunk.choices.first() {
if let Some(content) = &choice.delta.content {
model_response.push_str(content);
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
// Extract content from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
model_response.push_str(content);
}
}
}
}
@ -912,35 +942,39 @@ impl ArchFunctionHandler {
let mut has_tool_calls = None;
let mut has_hallucination = false;
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?;
if let Some(choice) = chunk.choices.first() {
if let Some(content) = &choice.delta.content {
let logprobs: Vec<f64> = if let Some(logprobs_data) = &choice.logprobs {
if let Some(content_vec) = &logprobs_data.content {
if let Some(token_logprob) = content_vec.first() {
token_logprob.top_logprobs
.iter()
.map(|top| top.logprob as f64)
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
// Extract logprobs
let logprobs: Vec<f64> = choice.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
} else {
vec![]
}
} else {
vec![]
})
.unwrap_or_default();
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
has_hallucination = true;
break;
}
} else {
vec![]
};
if hallucination_state.append_and_check_token_hallucination(content.clone(), logprobs) {
has_hallucination = true;
break;
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
}
}
}
@ -950,21 +984,9 @@ impl ArchFunctionHandler {
info!("[Hallucination]: {}", hallucination_state.error_message);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_openai_messages = self.convert_to_openai_messages(&clarify_messages)?;
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let clarify_request = CreateChatCompletionRequestArgs::default()
.model(&self.model_name)
.messages(clarify_openai_messages)
.stream(false)
.temperature(self.config.generation_params.temperature)
.build()
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build clarify request: {}", e)))?;
let retry_response = self.openai_client
.chat()
.create(clarify_request)
.await
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Clarify request failed: {}", e)))?;
let retry_response = self.make_non_streaming_request(clarify_request).await?;
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
@ -975,11 +997,15 @@ impl ArchFunctionHandler {
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?;
if let Some(choice) = chunk.choices.first() {
if let Some(content) = &choice.delta.content {
model_response.push_str(content);
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
model_response.push_str(content);
}
}
}
}
@ -1342,6 +1368,17 @@ mod tests {
assert!(config.format_prompt.contains("JSON formats"));
assert_eq!(config.generation_params.temperature, 0.1);
assert_eq!(config.support_data_types.len(), 14); // 8 Python-style + 6 JSON Schema names
// Verify exact prompt formatting matches Python
// Task prompt should have actual newlines, not escaped strings
assert!(config.task_prompt.contains("\n\nYou are provided"));
assert!(config.task_prompt.contains("</tools>\n\n"));
// Format prompt should have actual newlines and proper JSON with escaped quotes
assert!(config.format_prompt.contains("\n\nBased on your analysis"));
assert!(config.format_prompt.contains(r#"{"response": "Your response text here"}"#));
assert!(config.format_prompt.contains(r#"{"tool_calls": [{"#));
}
#[test]

View file

@ -146,7 +146,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
(&Method::POST, "/function_calling") => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, "/v1");
format!("{}{}", llm_provider_url, "/v1/chat/completions");
function_calling_chat_handler(req, fully_qualified_url)
.with_context(parent_cx)
.await

View file

@ -101,6 +101,12 @@ pub struct ChatCompletionsRequest {
pub top_logprobs: Option<u32>,
pub user: Option<String>,
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
// VLLM-specific parameters (used by Arch-Function)
pub top_k: Option<u32>,
pub stop_token_ids: Option<Vec<u32>>,
pub continue_final_message: Option<bool>,
pub add_generation_prompt: Option<bool>,
}
impl ChatCompletionsRequest {