diff --git a/crates/brightstaff/src/handlers/function_calling.rs b/crates/brightstaff/src/handlers/function_calling.rs index 79492691..cad52548 100644 --- a/crates/brightstaff/src/handlers/function_calling.rs +++ b/crates/brightstaff/src/handlers/function_calling.rs @@ -7,17 +7,12 @@ use serde_json::{json, Value}; use std::collections::HashMap; use thiserror::Error; use tracing::{info, error}; -use async_openai::{Client as OpenAIClient, config::OpenAIConfig}; -use async_openai::types::{ - CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, - ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, - ChatCompletionRequestAssistantMessage, -}; use futures::StreamExt; use bytes::Bytes; use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::body::Incoming; use hyper::{Request, Response, StatusCode}; +use eventsource_stream::Eventsource; @@ -103,17 +98,17 @@ pub struct ArchFunctionConfig { impl Default for ArchFunctionConfig { fn default() -> Self { Self { - task_prompt: String::from( - "You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.\ - \n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\ - \n\nYour task is to decide which functions are needed and collect missing parameters if necessary." - ), - format_prompt: String::from( - "\n\nBased on your analysis, provide your response in one of the following JSON formats:\ - \n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```\ - \n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```\ - \n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```" - ), + task_prompt: concat!( + "You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.", + "\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n", + "\n\nYour task is to decide which functions are needed and collect missing parameters if necessary." + ).to_string(), + format_prompt: concat!( + "\n\nBased on your analysis, provide your response in one of the following JSON formats:", + "\n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```", + "\n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```", + "\n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```" + ).to_string(), generation_params: GenerationParams::default(), support_data_types: vec![ "int".to_string(), @@ -170,9 +165,9 @@ impl Default for ArchAgentConfig { pub struct GenerationParams { pub temperature: f32, pub top_p: f32, - pub top_k: i32, + pub top_k: u32, pub max_tokens: u32, - pub stop_token_ids: Vec, + pub stop_token_ids: Vec, pub logprobs: Option, pub top_logprobs: Option, } @@ -235,7 +230,8 @@ pub struct ArchFunctionHandler { pub config: ArchFunctionConfig, pub default_prefix: String, pub clarify_prefix: String, - pub openai_client: OpenAIClient, + pub endpoint_url: String, + pub http_client: reqwest::Client, } impl ArchFunctionHandler { @@ -256,16 +252,13 @@ impl ArchFunctionHandler { .build() .expect("Failed to create HTTP client"); - // Configure OpenAI client to use custom endpoint - let openai_config = OpenAIConfig::new() - .with_api_base(endpoint_url); - Self { model_name, config, default_prefix: "```json\n{\"".to_string(), clarify_prefix: "```json\n{\"required_functions\":".to_string(), - openai_client: OpenAIClient::with_config(openai_config).with_http_client(http_client), + endpoint_url, + http_client, } } @@ -757,17 +750,25 @@ impl ArchFunctionHandler { } // Calculate from the end backwards + // Start with message_idx pointing past the end (will be used if no truncation needed) let mut message_idx = messages.len(); for i in (conversation_idx..messages.len()).rev() { if let MessageContent::Text(content) = &messages[i].content { num_tokens += content.len() / 4; if num_tokens >= max_tokens { if messages[i].role == Role::User { + // Set message_idx to current position and break + // This matches Python's behavior where message_idx is set before break + message_idx = i; break; } } } - message_idx = i; + // Only update message_idx if we haven't hit the token limit yet + // This ensures message_idx points to where truncation should start + if num_tokens < max_tokens { + message_idx = i; + } } // Return system message + truncated conversation @@ -792,52 +793,99 @@ impl ArchFunctionHandler { messages } - /// Converts internal Message format to async-openai's ChatCompletionRequestMessage format - fn convert_to_openai_messages(&self, messages: &[Message]) -> Result> { - let mut openai_messages = Vec::new(); + /// Helper to create a request with VLLM-specific parameters + fn create_request_with_extra_body(&self, messages: Vec, stream: bool) -> ChatCompletionsRequest { + ChatCompletionsRequest { + model: self.model_name.clone(), + messages, + temperature: Some(self.config.generation_params.temperature), + top_p: Some(self.config.generation_params.top_p), + max_tokens: Some(self.config.generation_params.max_tokens), + stream: Some(stream), + logprobs: self.config.generation_params.logprobs, + top_logprobs: self.config.generation_params.top_logprobs, + // VLLM-specific parameters + continue_final_message: Some(true), + add_generation_prompt: Some(false), + top_k: Some(self.config.generation_params.top_k), + stop_token_ids: if !self.config.generation_params.stop_token_ids.is_empty() { + Some(self.config.generation_params.stop_token_ids.clone()) + } else { + None + }, + ..Default::default() + } + } - for message in messages { - let content_str = match &message.content { - MessageContent::Text(text) => text.clone(), - MessageContent::Parts(_) => String::new(), // Handle parts if needed - }; + /// Makes a streaming request and returns the SSE event stream + async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result> + Send>>> { + let request_body = serde_json::to_string(&request) + .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?; - let openai_message = match message.role { - Role::System => { - ChatCompletionRequestMessage::System( - ChatCompletionRequestSystemMessage { - content: content_str.into(), - name: message.name.clone(), - } - ) - }, - Role::User | Role::Tool => { - // Convert both user and tool roles to user messages - ChatCompletionRequestMessage::User( - ChatCompletionRequestUserMessage { - content: content_str.into(), - name: message.name.clone(), - } - ) - }, - Role::Assistant => { - #[allow(deprecated)] - let msg = ChatCompletionRequestAssistantMessage { - content: Some(content_str.into()), - name: message.name.clone(), - tool_calls: None, - refusal: None, - audio: None, - function_call: None, - }; - ChatCompletionRequestMessage::Assistant(msg) - }, - }; + let response = self.http_client + .post(&self.endpoint_url) + .header("Content-Type", "application/json") + .body(request_body) + .send() + .await + .map_err(|e| FunctionCallingError::HttpError(e))?; - openai_messages.push(openai_message); + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); + return Err(FunctionCallingError::InvalidModelResponse( + format!("HTTP error {}: {}", status, error_text) + )); } - Ok(openai_messages) + // Parse SSE stream + let stream = response.bytes_stream().eventsource(); + let parsed_stream = stream.filter_map(|event_result| async move { + match event_result { + Ok(event) => { + // Skip [DONE] sentinel + if event.data == "[DONE]" { + return None; + } + // Parse JSON + match serde_json::from_str::(&event.data) { + Ok(json) => Some(Ok(json)), + Err(e) => Some(Err(format!("JSON parse error: {}", e))), + } + } + Err(e) => Some(Err(format!("SSE stream error: {}", e))), + } + }); + + Ok(Box::pin(parsed_stream)) + } + + /// Makes a non-streaming request and returns the response + async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result { + let request_body = serde_json::to_string(&request) + .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?; + + let response = self.http_client + .post(&self.endpoint_url) + .header("Content-Type", "application/json") + .body(request_body) + .send() + .await + .map_err(|e| FunctionCallingError::HttpError(e))?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); + return Err(FunctionCallingError::InvalidModelResponse( + format!("HTTP error {}: {}", status, error_text) + )); + } + + let response_text = response.text().await + .map_err(|e| FunctionCallingError::HttpError(e))?; + + serde_json::from_str(&response_text) + .map_err(|e| FunctionCallingError::JsonParseError(e)) } pub async fn function_calling_chat( @@ -866,42 +914,24 @@ impl ArchFunctionHandler { .unwrap_or(false); let prefilled_messages = self.prefill_message(messages.clone(), &self.default_prefix); - let openai_messages = self.convert_to_openai_messages(&prefilled_messages)?; - let mut request_args = CreateChatCompletionRequestArgs::default(); - request_args - .model(&self.model_name) - .messages(openai_messages) - .stream(true) - .temperature(self.config.generation_params.temperature) - .top_p(self.config.generation_params.top_p) - .max_tokens(self.config.generation_params.max_tokens); - - if let Some(true) = self.config.generation_params.logprobs { - request_args.logprobs(true); - if let Some(top_logprobs) = self.config.generation_params.top_logprobs { - request_args.top_logprobs(top_logprobs as u8); - } - } - - let request_builder = request_args - .build() - .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build request: {}", e)))?; - - let mut stream = self.openai_client - .chat() - .create_stream(request_builder) - .await - .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream creation failed: {}", e)))?; + // Create request with extra_body parameters + let stream_request = self.create_request_with_extra_body(prefilled_messages.clone(), true); + let mut stream = self.make_streaming_request(stream_request).await?; let mut model_response = String::new(); if use_agent_orchestrator { - while let Some(chunk) = stream.next().await { - let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?; - if let Some(choice) = chunk.choices.first() { - if let Some(content) = &choice.delta.content { - model_response.push_str(content); + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?; + // Extract content from JSON response + if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { + if let Some(choice) = choices.first() { + if let Some(content) = choice.get("delta") + .and_then(|d| d.get("content")) + .and_then(|c| c.as_str()) { + model_response.push_str(content); + } } } } @@ -912,35 +942,39 @@ impl ArchFunctionHandler { let mut has_tool_calls = None; let mut has_hallucination = false; - while let Some(chunk) = stream.next().await { - let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?; - if let Some(choice) = chunk.choices.first() { - if let Some(content) = &choice.delta.content { - let logprobs: Vec = if let Some(logprobs_data) = &choice.logprobs { - if let Some(content_vec) = &logprobs_data.content { - if let Some(token_logprob) = content_vec.first() { - token_logprob.top_logprobs - .iter() - .map(|top| top.logprob as f64) + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?; + + // Extract content and logprobs from JSON response + if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { + if let Some(choice) = choices.first() { + if let Some(content) = choice.get("delta") + .and_then(|d| d.get("content")) + .and_then(|c| c.as_str()) { + + // Extract logprobs + let logprobs: Vec = choice.get("logprobs") + .and_then(|lp| lp.get("content")) + .and_then(|c| c.as_array()) + .and_then(|arr| arr.first()) + .and_then(|token| token.get("top_logprobs")) + .and_then(|tlp| tlp.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64())) .collect() - } else { - vec![] - } - } else { - vec![] + }) + .unwrap_or_default(); + + if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) { + has_hallucination = true; + break; } - } else { - vec![] - }; - if hallucination_state.append_and_check_token_hallucination(content.clone(), logprobs) { - has_hallucination = true; - break; - } - - if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() { - let collected_content = hallucination_state.tokens.join(""); - has_tool_calls = Some(collected_content.contains("tool_calls")); + if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() { + let collected_content = hallucination_state.tokens.join(""); + has_tool_calls = Some(collected_content.contains("tool_calls")); + } } } } @@ -950,21 +984,9 @@ impl ArchFunctionHandler { info!("[Hallucination]: {}", hallucination_state.error_message); let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix); - let clarify_openai_messages = self.convert_to_openai_messages(&clarify_messages)?; + let clarify_request = self.create_request_with_extra_body(clarify_messages, false); - let clarify_request = CreateChatCompletionRequestArgs::default() - .model(&self.model_name) - .messages(clarify_openai_messages) - .stream(false) - .temperature(self.config.generation_params.temperature) - .build() - .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build clarify request: {}", e)))?; - - let retry_response = self.openai_client - .chat() - .create(clarify_request) - .await - .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Clarify request failed: {}", e)))?; + let retry_response = self.make_non_streaming_request(clarify_request).await?; if let Some(choice) = retry_response.choices.first() { if let Some(content) = &choice.message.content { @@ -975,11 +997,15 @@ impl ArchFunctionHandler { model_response = hallucination_state.tokens.join(""); } } else { - while let Some(chunk) = stream.next().await { - let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?; - if let Some(choice) = chunk.choices.first() { - if let Some(content) = &choice.delta.content { - model_response.push_str(content); + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?; + if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { + if let Some(choice) = choices.first() { + if let Some(content) = choice.get("delta") + .and_then(|d| d.get("content")) + .and_then(|c| c.as_str()) { + model_response.push_str(content); + } } } } @@ -1342,6 +1368,17 @@ mod tests { assert!(config.format_prompt.contains("JSON formats")); assert_eq!(config.generation_params.temperature, 0.1); assert_eq!(config.support_data_types.len(), 14); // 8 Python-style + 6 JSON Schema names + + // Verify exact prompt formatting matches Python + // Task prompt should have actual newlines, not escaped strings + assert!(config.task_prompt.contains("\n\nYou are provided")); + assert!(config.task_prompt.contains("\n\n")); + + // Format prompt should have actual newlines and proper JSON with escaped quotes + assert!(config.format_prompt.contains("\n\nBased on your analysis")); + assert!(config.format_prompt.contains(r#"{"response": "Your response text here"}"#)); + assert!(config.format_prompt.contains(r#"{"tool_calls": [{"#)); + } #[test] diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 265ee5ba..87bdea36 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -146,7 +146,7 @@ async fn main() -> Result<(), Box> { (&Method::POST, "/function_calling") => { let fully_qualified_url = - format!("{}{}", llm_provider_url, "/v1"); + format!("{}{}", llm_provider_url, "/v1/chat/completions"); function_calling_chat_handler(req, fully_qualified_url) .with_context(parent_cx) .await diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 90a180ba..44b64485 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -101,6 +101,12 @@ pub struct ChatCompletionsRequest { pub top_logprobs: Option, pub user: Option, // pub web_search: Option, // GOOD FIRST ISSUE: Future support for web search + + // VLLM-specific parameters (used by Arch-Function) + pub top_k: Option, + pub stop_token_ids: Option>, + pub continue_final_message: Option, + pub add_generation_prompt: Option, } impl ChatCompletionsRequest {