mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixing some edge cases with calls made to Arch-Function
This commit is contained in:
parent
00f95c93f3
commit
9922fe0cb9
3 changed files with 187 additions and 144 deletions
|
|
@ -7,17 +7,12 @@ use serde_json::{json, Value};
|
|||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tracing::{info, error};
|
||||
use async_openai::{Client as OpenAIClient, config::OpenAIConfig};
|
||||
use async_openai::types::{
|
||||
CreateChatCompletionRequestArgs, ChatCompletionRequestMessage,
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
|
||||
ChatCompletionRequestAssistantMessage,
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use eventsource_stream::Eventsource;
|
||||
|
||||
|
||||
|
||||
|
|
@ -103,17 +98,17 @@ pub struct ArchFunctionConfig {
|
|||
impl Default for ArchFunctionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
task_prompt: String::from(
|
||||
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.\
|
||||
\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>\
|
||||
\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
|
||||
),
|
||||
format_prompt: String::from(
|
||||
"\n\nBased on your analysis, provide your response in one of the following JSON formats:\
|
||||
\n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```\
|
||||
\n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```\
|
||||
\n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```"
|
||||
),
|
||||
task_prompt: concat!(
|
||||
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.",
|
||||
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>",
|
||||
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
|
||||
).to_string(),
|
||||
format_prompt: concat!(
|
||||
"\n\nBased on your analysis, provide your response in one of the following JSON formats:",
|
||||
"\n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```",
|
||||
"\n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```",
|
||||
"\n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```"
|
||||
).to_string(),
|
||||
generation_params: GenerationParams::default(),
|
||||
support_data_types: vec![
|
||||
"int".to_string(),
|
||||
|
|
@ -170,9 +165,9 @@ impl Default for ArchAgentConfig {
|
|||
pub struct GenerationParams {
|
||||
pub temperature: f32,
|
||||
pub top_p: f32,
|
||||
pub top_k: i32,
|
||||
pub top_k: u32,
|
||||
pub max_tokens: u32,
|
||||
pub stop_token_ids: Vec<i32>,
|
||||
pub stop_token_ids: Vec<u32>,
|
||||
pub logprobs: Option<bool>,
|
||||
pub top_logprobs: Option<u32>,
|
||||
}
|
||||
|
|
@ -235,7 +230,8 @@ pub struct ArchFunctionHandler {
|
|||
pub config: ArchFunctionConfig,
|
||||
pub default_prefix: String,
|
||||
pub clarify_prefix: String,
|
||||
pub openai_client: OpenAIClient<OpenAIConfig>,
|
||||
pub endpoint_url: String,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl ArchFunctionHandler {
|
||||
|
|
@ -256,16 +252,13 @@ impl ArchFunctionHandler {
|
|||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
// Configure OpenAI client to use custom endpoint
|
||||
let openai_config = OpenAIConfig::new()
|
||||
.with_api_base(endpoint_url);
|
||||
|
||||
Self {
|
||||
model_name,
|
||||
config,
|
||||
default_prefix: "```json\n{\"".to_string(),
|
||||
clarify_prefix: "```json\n{\"required_functions\":".to_string(),
|
||||
openai_client: OpenAIClient::with_config(openai_config).with_http_client(http_client),
|
||||
endpoint_url,
|
||||
http_client,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -757,17 +750,25 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
// Calculate from the end backwards
|
||||
// Start with message_idx pointing past the end (will be used if no truncation needed)
|
||||
let mut message_idx = messages.len();
|
||||
for i in (conversation_idx..messages.len()).rev() {
|
||||
if let MessageContent::Text(content) = &messages[i].content {
|
||||
num_tokens += content.len() / 4;
|
||||
if num_tokens >= max_tokens {
|
||||
if messages[i].role == Role::User {
|
||||
// Set message_idx to current position and break
|
||||
// This matches Python's behavior where message_idx is set before break
|
||||
message_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
message_idx = i;
|
||||
// Only update message_idx if we haven't hit the token limit yet
|
||||
// This ensures message_idx points to where truncation should start
|
||||
if num_tokens < max_tokens {
|
||||
message_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
// Return system message + truncated conversation
|
||||
|
|
@ -792,52 +793,99 @@ impl ArchFunctionHandler {
|
|||
messages
|
||||
}
|
||||
|
||||
/// Converts internal Message format to async-openai's ChatCompletionRequestMessage format
|
||||
fn convert_to_openai_messages(&self, messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
|
||||
let mut openai_messages = Vec::new();
|
||||
/// Helper to create a request with VLLM-specific parameters
|
||||
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest {
|
||||
ChatCompletionsRequest {
|
||||
model: self.model_name.clone(),
|
||||
messages,
|
||||
temperature: Some(self.config.generation_params.temperature),
|
||||
top_p: Some(self.config.generation_params.top_p),
|
||||
max_tokens: Some(self.config.generation_params.max_tokens),
|
||||
stream: Some(stream),
|
||||
logprobs: self.config.generation_params.logprobs,
|
||||
top_logprobs: self.config.generation_params.top_logprobs,
|
||||
// VLLM-specific parameters
|
||||
continue_final_message: Some(true),
|
||||
add_generation_prompt: Some(false),
|
||||
top_k: Some(self.config.generation_params.top_k),
|
||||
stop_token_ids: if !self.config.generation_params.stop_token_ids.is_empty() {
|
||||
Some(self.config.generation_params.stop_token_ids.clone())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
for message in messages {
|
||||
let content_str = match &message.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(_) => String::new(), // Handle parts if needed
|
||||
};
|
||||
/// Makes a streaming request and returns the SSE event stream
|
||||
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> {
|
||||
let request_body = serde_json::to_string(&request)
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
|
||||
|
||||
let openai_message = match message.role {
|
||||
Role::System => {
|
||||
ChatCompletionRequestMessage::System(
|
||||
ChatCompletionRequestSystemMessage {
|
||||
content: content_str.into(),
|
||||
name: message.name.clone(),
|
||||
}
|
||||
)
|
||||
},
|
||||
Role::User | Role::Tool => {
|
||||
// Convert both user and tool roles to user messages
|
||||
ChatCompletionRequestMessage::User(
|
||||
ChatCompletionRequestUserMessage {
|
||||
content: content_str.into(),
|
||||
name: message.name.clone(),
|
||||
}
|
||||
)
|
||||
},
|
||||
Role::Assistant => {
|
||||
#[allow(deprecated)]
|
||||
let msg = ChatCompletionRequestAssistantMessage {
|
||||
content: Some(content_str.into()),
|
||||
name: message.name.clone(),
|
||||
tool_calls: None,
|
||||
refusal: None,
|
||||
audio: None,
|
||||
function_call: None,
|
||||
};
|
||||
ChatCompletionRequestMessage::Assistant(msg)
|
||||
},
|
||||
};
|
||||
let response = self.http_client
|
||||
.post(&self.endpoint_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
|
||||
openai_messages.push(openai_message);
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(
|
||||
format!("HTTP error {}: {}", status, error_text)
|
||||
));
|
||||
}
|
||||
|
||||
Ok(openai_messages)
|
||||
// Parse SSE stream
|
||||
let stream = response.bytes_stream().eventsource();
|
||||
let parsed_stream = stream.filter_map(|event_result| async move {
|
||||
match event_result {
|
||||
Ok(event) => {
|
||||
// Skip [DONE] sentinel
|
||||
if event.data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
// Parse JSON
|
||||
match serde_json::from_str::<Value>(&event.data) {
|
||||
Ok(json) => Some(Ok(json)),
|
||||
Err(e) => Some(Err(format!("JSON parse error: {}", e))),
|
||||
}
|
||||
}
|
||||
Err(e) => Some(Err(format!("SSE stream error: {}", e))),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(parsed_stream))
|
||||
}
|
||||
|
||||
/// Makes a non-streaming request and returns the response
|
||||
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> {
|
||||
let request_body = serde_json::to_string(&request)
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
|
||||
|
||||
let response = self.http_client
|
||||
.post(&self.endpoint_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(
|
||||
format!("HTTP error {}: {}", status, error_text)
|
||||
));
|
||||
}
|
||||
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
|
||||
serde_json::from_str(&response_text)
|
||||
.map_err(|e| FunctionCallingError::JsonParseError(e))
|
||||
}
|
||||
|
||||
pub async fn function_calling_chat(
|
||||
|
|
@ -866,42 +914,24 @@ impl ArchFunctionHandler {
|
|||
.unwrap_or(false);
|
||||
|
||||
let prefilled_messages = self.prefill_message(messages.clone(), &self.default_prefix);
|
||||
let openai_messages = self.convert_to_openai_messages(&prefilled_messages)?;
|
||||
|
||||
let mut request_args = CreateChatCompletionRequestArgs::default();
|
||||
request_args
|
||||
.model(&self.model_name)
|
||||
.messages(openai_messages)
|
||||
.stream(true)
|
||||
.temperature(self.config.generation_params.temperature)
|
||||
.top_p(self.config.generation_params.top_p)
|
||||
.max_tokens(self.config.generation_params.max_tokens);
|
||||
|
||||
if let Some(true) = self.config.generation_params.logprobs {
|
||||
request_args.logprobs(true);
|
||||
if let Some(top_logprobs) = self.config.generation_params.top_logprobs {
|
||||
request_args.top_logprobs(top_logprobs as u8);
|
||||
}
|
||||
}
|
||||
|
||||
let request_builder = request_args
|
||||
.build()
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build request: {}", e)))?;
|
||||
|
||||
let mut stream = self.openai_client
|
||||
.chat()
|
||||
.create_stream(request_builder)
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream creation failed: {}", e)))?;
|
||||
// Create request with extra_body parameters
|
||||
let stream_request = self.create_request_with_extra_body(prefilled_messages.clone(), true);
|
||||
let mut stream = self.make_streaming_request(stream_request).await?;
|
||||
|
||||
let mut model_response = String::new();
|
||||
|
||||
if use_agent_orchestrator {
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?;
|
||||
if let Some(choice) = chunk.choices.first() {
|
||||
if let Some(content) = &choice.delta.content {
|
||||
model_response.push_str(content);
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
// Extract content from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
model_response.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -912,35 +942,39 @@ impl ArchFunctionHandler {
|
|||
let mut has_tool_calls = None;
|
||||
let mut has_hallucination = false;
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?;
|
||||
if let Some(choice) = chunk.choices.first() {
|
||||
if let Some(content) = &choice.delta.content {
|
||||
let logprobs: Vec<f64> = if let Some(logprobs_data) = &choice.logprobs {
|
||||
if let Some(content_vec) = &logprobs_data.content {
|
||||
if let Some(token_logprob) = content_vec.first() {
|
||||
token_logprob.top_logprobs
|
||||
.iter()
|
||||
.map(|top| top.logprob as f64)
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
|
||||
// Extract content and logprobs from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
|
||||
// Extract logprobs
|
||||
let logprobs: Vec<f64> = choice.get("logprobs")
|
||||
.and_then(|lp| lp.get("content"))
|
||||
.and_then(|c| c.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|token| token.get("top_logprobs"))
|
||||
.and_then(|tlp| tlp.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
|
||||
has_hallucination = true;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
if hallucination_state.append_and_check_token_hallucination(content.clone(), logprobs) {
|
||||
has_hallucination = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
|
||||
let collected_content = hallucination_state.tokens.join("");
|
||||
has_tool_calls = Some(collected_content.contains("tool_calls"));
|
||||
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
|
||||
let collected_content = hallucination_state.tokens.join("");
|
||||
has_tool_calls = Some(collected_content.contains("tool_calls"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -950,21 +984,9 @@ impl ArchFunctionHandler {
|
|||
info!("[Hallucination]: {}", hallucination_state.error_message);
|
||||
|
||||
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
|
||||
let clarify_openai_messages = self.convert_to_openai_messages(&clarify_messages)?;
|
||||
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
|
||||
|
||||
let clarify_request = CreateChatCompletionRequestArgs::default()
|
||||
.model(&self.model_name)
|
||||
.messages(clarify_openai_messages)
|
||||
.stream(false)
|
||||
.temperature(self.config.generation_params.temperature)
|
||||
.build()
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build clarify request: {}", e)))?;
|
||||
|
||||
let retry_response = self.openai_client
|
||||
.chat()
|
||||
.create(clarify_request)
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Clarify request failed: {}", e)))?;
|
||||
let retry_response = self.make_non_streaming_request(clarify_request).await?;
|
||||
|
||||
if let Some(choice) = retry_response.choices.first() {
|
||||
if let Some(content) = &choice.message.content {
|
||||
|
|
@ -975,11 +997,15 @@ impl ArchFunctionHandler {
|
|||
model_response = hallucination_state.tokens.join("");
|
||||
}
|
||||
} else {
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?;
|
||||
if let Some(choice) = chunk.choices.first() {
|
||||
if let Some(content) = &choice.delta.content {
|
||||
model_response.push_str(content);
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
model_response.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1342,6 +1368,17 @@ mod tests {
|
|||
assert!(config.format_prompt.contains("JSON formats"));
|
||||
assert_eq!(config.generation_params.temperature, 0.1);
|
||||
assert_eq!(config.support_data_types.len(), 14); // 8 Python-style + 6 JSON Schema names
|
||||
|
||||
// Verify exact prompt formatting matches Python
|
||||
// Task prompt should have actual newlines, not escaped strings
|
||||
assert!(config.task_prompt.contains("\n\nYou are provided"));
|
||||
assert!(config.task_prompt.contains("</tools>\n\n"));
|
||||
|
||||
// Format prompt should have actual newlines and proper JSON with escaped quotes
|
||||
assert!(config.format_prompt.contains("\n\nBased on your analysis"));
|
||||
assert!(config.format_prompt.contains(r#"{"response": "Your response text here"}"#));
|
||||
assert!(config.format_prompt.contains(r#"{"tool_calls": [{"#));
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, "/v1");
|
||||
format!("{}{}", llm_provider_url, "/v1/chat/completions");
|
||||
function_calling_chat_handler(req, fully_qualified_url)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -101,6 +101,12 @@ pub struct ChatCompletionsRequest {
|
|||
pub top_logprobs: Option<u32>,
|
||||
pub user: Option<String>,
|
||||
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
|
||||
|
||||
// VLLM-specific parameters (used by Arch-Function)
|
||||
pub top_k: Option<u32>,
|
||||
pub stop_token_ids: Option<Vec<u32>>,
|
||||
pub continue_final_message: Option<bool>,
|
||||
pub add_generation_prompt: Option<bool>,
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue