diff --git a/crates/brightstaff/src/handlers/function_calling.rs b/crates/brightstaff/src/handlers/function_calling.rs
index 79492691..cad52548 100644
--- a/crates/brightstaff/src/handlers/function_calling.rs
+++ b/crates/brightstaff/src/handlers/function_calling.rs
@@ -7,17 +7,12 @@ use serde_json::{json, Value};
use std::collections::HashMap;
use thiserror::Error;
use tracing::{info, error};
-use async_openai::{Client as OpenAIClient, config::OpenAIConfig};
-use async_openai::types::{
- CreateChatCompletionRequestArgs, ChatCompletionRequestMessage,
- ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
- ChatCompletionRequestAssistantMessage,
-};
use futures::StreamExt;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
+use eventsource_stream::Eventsource;
@@ -103,17 +98,17 @@ pub struct ArchFunctionConfig {
impl Default for ArchFunctionConfig {
fn default() -> Self {
Self {
- task_prompt: String::from(
- "You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.\
- \n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\
- \n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
- ),
- format_prompt: String::from(
- "\n\nBased on your analysis, provide your response in one of the following JSON formats:\
- \n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```\
- \n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```\
- \n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```"
- ),
+ task_prompt: concat!(
+ "You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.",
+ "\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n",
+ "\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
+ ).to_string(),
+ format_prompt: concat!(
+ "\n\nBased on your analysis, provide your response in one of the following JSON formats:",
+ "\n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```",
+ "\n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```",
+ "\n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```"
+ ).to_string(),
generation_params: GenerationParams::default(),
support_data_types: vec![
"int".to_string(),
@@ -170,9 +165,9 @@ impl Default for ArchAgentConfig {
pub struct GenerationParams {
pub temperature: f32,
pub top_p: f32,
- pub top_k: i32,
+ pub top_k: u32,
pub max_tokens: u32,
- pub stop_token_ids: Vec,
+ pub stop_token_ids: Vec,
pub logprobs: Option,
pub top_logprobs: Option,
}
@@ -235,7 +230,8 @@ pub struct ArchFunctionHandler {
pub config: ArchFunctionConfig,
pub default_prefix: String,
pub clarify_prefix: String,
- pub openai_client: OpenAIClient,
+ pub endpoint_url: String,
+ pub http_client: reqwest::Client,
}
impl ArchFunctionHandler {
@@ -256,16 +252,13 @@ impl ArchFunctionHandler {
.build()
.expect("Failed to create HTTP client");
- // Configure OpenAI client to use custom endpoint
- let openai_config = OpenAIConfig::new()
- .with_api_base(endpoint_url);
-
Self {
model_name,
config,
default_prefix: "```json\n{\"".to_string(),
clarify_prefix: "```json\n{\"required_functions\":".to_string(),
- openai_client: OpenAIClient::with_config(openai_config).with_http_client(http_client),
+ endpoint_url,
+ http_client,
}
}
@@ -757,17 +750,25 @@ impl ArchFunctionHandler {
}
// Calculate from the end backwards
+ // Start with message_idx pointing past the end (will be used if no truncation needed)
let mut message_idx = messages.len();
for i in (conversation_idx..messages.len()).rev() {
if let MessageContent::Text(content) = &messages[i].content {
num_tokens += content.len() / 4;
if num_tokens >= max_tokens {
if messages[i].role == Role::User {
+ // Set message_idx to current position and break
+ // This matches Python's behavior where message_idx is set before break
+ message_idx = i;
break;
}
}
}
- message_idx = i;
+ // Only update message_idx if we haven't hit the token limit yet
+ // This ensures message_idx points to where truncation should start
+ if num_tokens < max_tokens {
+ message_idx = i;
+ }
}
// Return system message + truncated conversation
@@ -792,52 +793,99 @@ impl ArchFunctionHandler {
messages
}
- /// Converts internal Message format to async-openai's ChatCompletionRequestMessage format
- fn convert_to_openai_messages(&self, messages: &[Message]) -> Result> {
- let mut openai_messages = Vec::new();
+    /// Builds the chat-completions request for `model_name` from the configured generation
+    /// parameters, plus the vLLM-specific fields, with `stream` selecting the response mode.
+    fn create_request_with_extra_body(&self, messages: Vec, stream: bool) -> ChatCompletionsRequest {
+        let params = &self.config.generation_params;
+        // An empty stop list becomes `None` rather than `Some(vec![])`.
+        let stop_token_ids = (!params.stop_token_ids.is_empty()).then(|| params.stop_token_ids.clone());
+        ChatCompletionsRequest {
+            model: self.model_name.clone(),
+            messages,
+            stream: Some(stream),
+            temperature: Some(params.temperature),
+            top_p: Some(params.top_p),
+            top_k: Some(params.top_k),
+            max_tokens: Some(params.max_tokens),
+            logprobs: params.logprobs,
+            top_logprobs: params.top_logprobs,
+            // vLLM-specific flags; both are fixed values, not read from the config.
+            continue_final_message: Some(true),
+            add_generation_prompt: Some(false),
+            stop_token_ids,
+            ..Default::default()
+        }
+    }
- for message in messages {
- let content_str = match &message.content {
- MessageContent::Text(text) => text.clone(),
- MessageContent::Parts(_) => String::new(), // Handle parts if needed
- };
+ /// POSTs `request` as JSON to `endpoint_url` and returns a boxed, pinned stream of the SSE events' data parsed as JSON; a non-success HTTP status is returned up front as `InvalidModelResponse` carrying the status and body text
+ async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result> + Send>>> {
+ let request_body = serde_json::to_string(&request)
+ .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
- let openai_message = match message.role {
- Role::System => {
- ChatCompletionRequestMessage::System(
- ChatCompletionRequestSystemMessage {
- content: content_str.into(),
- name: message.name.clone(),
- }
- )
- },
- Role::User | Role::Tool => {
- // Convert both user and tool roles to user messages
- ChatCompletionRequestMessage::User(
- ChatCompletionRequestUserMessage {
- content: content_str.into(),
- name: message.name.clone(),
- }
- )
- },
- Role::Assistant => {
- #[allow(deprecated)]
- let msg = ChatCompletionRequestAssistantMessage {
- content: Some(content_str.into()),
- name: message.name.clone(),
- tool_calls: None,
- refusal: None,
- audio: None,
- function_call: None,
- };
- ChatCompletionRequestMessage::Assistant(msg)
- },
- };
+ let response = self.http_client
+ .post(&self.endpoint_url)
+ .header("Content-Type", "application/json")
+ .body(request_body)
+ .send()
+ .await
+ .map_err(|e| FunctionCallingError::HttpError(e))?;
- openai_messages.push(openai_message);
+ if !response.status().is_success() {
+ let status = response.status();
+ let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
+ return Err(FunctionCallingError::InvalidModelResponse(
+ format!("HTTP error {}: {}", status, error_text)
+ ));
}
- Ok(openai_messages)
+ // Decode the response body as server-sent events; per-event failures are yielded as `Err` items carrying a formatted message
+ let stream = response.bytes_stream().eventsource();
+ let parsed_stream = stream.filter_map(|event_result| async move {
+ match event_result {
+ Ok(event) => {
+ // Skip the `[DONE]` sentinel: it is not JSON and would otherwise be reported as a parse error
+ if event.data == "[DONE]" {
+ return None;
+ }
+ // Parse the event data as JSON; a malformed payload becomes an `Err` item
+ match serde_json::from_str::(&event.data) {
+ Ok(json) => Some(Ok(json)),
+ Err(e) => Some(Err(format!("JSON parse error: {}", e))),
+ }
+ }
+ Err(e) => Some(Err(format!("SSE stream error: {}", e))),
+ }
+ });
+
+ Ok(Box::pin(parsed_stream))
+ }
+
+    /// POSTs `request` as JSON to `endpoint_url` and deserializes the complete response body.
+    ///
+    /// # Errors
+    /// `InvalidModelResponse` if the request cannot be serialized or the status is not a success,
+    /// `HttpError` if sending or reading the body fails, `JsonParseError` if the body does not parse.
+    async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result {
+        let request_body = serde_json::to_string(&request)
+            .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
+        let response = self.http_client
+            .post(&self.endpoint_url)
+            .header("Content-Type", "application/json")
+            .body(request_body)
+            .send()
+            .await
+            .map_err(FunctionCallingError::HttpError)?;
+
+        let status = response.status();
+        if !status.is_success() {
+            // Best effort: an unreadable error body must not mask the HTTP status itself.
+            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
+            let detail = format!("HTTP error {}: {}", status, error_text);
+            return Err(FunctionCallingError::InvalidModelResponse(detail));
+        }
+
+        let response_text = response.text().await.map_err(FunctionCallingError::HttpError)?;
+        serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
}
pub async fn function_calling_chat(
@@ -866,42 +914,24 @@ impl ArchFunctionHandler {
.unwrap_or(false);
let prefilled_messages = self.prefill_message(messages.clone(), &self.default_prefix);
- let openai_messages = self.convert_to_openai_messages(&prefilled_messages)?;
- let mut request_args = CreateChatCompletionRequestArgs::default();
- request_args
- .model(&self.model_name)
- .messages(openai_messages)
- .stream(true)
- .temperature(self.config.generation_params.temperature)
- .top_p(self.config.generation_params.top_p)
- .max_tokens(self.config.generation_params.max_tokens);
-
- if let Some(true) = self.config.generation_params.logprobs {
- request_args.logprobs(true);
- if let Some(top_logprobs) = self.config.generation_params.top_logprobs {
- request_args.top_logprobs(top_logprobs as u8);
- }
- }
-
- let request_builder = request_args
- .build()
- .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build request: {}", e)))?;
-
- let mut stream = self.openai_client
- .chat()
- .create_stream(request_builder)
- .await
- .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream creation failed: {}", e)))?;
+ // Create request with extra_body parameters
+ let stream_request = self.create_request_with_extra_body(prefilled_messages.clone(), true);
+ let mut stream = self.make_streaming_request(stream_request).await?;
let mut model_response = String::new();
if use_agent_orchestrator {
- while let Some(chunk) = stream.next().await {
- let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?;
- if let Some(choice) = chunk.choices.first() {
- if let Some(content) = &choice.delta.content {
- model_response.push_str(content);
+ while let Some(chunk_result) = stream.next().await {
+ let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
+ // Extract content from JSON response
+ if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
+ if let Some(choice) = choices.first() {
+ if let Some(content) = choice.get("delta")
+ .and_then(|d| d.get("content"))
+ .and_then(|c| c.as_str()) {
+ model_response.push_str(content);
+ }
}
}
}
@@ -912,35 +942,39 @@ impl ArchFunctionHandler {
let mut has_tool_calls = None;
let mut has_hallucination = false;
- while let Some(chunk) = stream.next().await {
- let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?;
- if let Some(choice) = chunk.choices.first() {
- if let Some(content) = &choice.delta.content {
- let logprobs: Vec = if let Some(logprobs_data) = &choice.logprobs {
- if let Some(content_vec) = &logprobs_data.content {
- if let Some(token_logprob) = content_vec.first() {
- token_logprob.top_logprobs
- .iter()
- .map(|top| top.logprob as f64)
+ while let Some(chunk_result) = stream.next().await {
+ let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
+
+ // Extract content and logprobs from JSON response
+ if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
+ if let Some(choice) = choices.first() {
+ if let Some(content) = choice.get("delta")
+ .and_then(|d| d.get("content"))
+ .and_then(|c| c.as_str()) {
+
+ // Extract logprobs
+ let logprobs: Vec = choice.get("logprobs")
+ .and_then(|lp| lp.get("content"))
+ .and_then(|c| c.as_array())
+ .and_then(|arr| arr.first())
+ .and_then(|token| token.get("top_logprobs"))
+ .and_then(|tlp| tlp.as_array())
+ .map(|arr| {
+ arr.iter()
+ .filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
- } else {
- vec![]
- }
- } else {
- vec![]
+ })
+ .unwrap_or_default();
+
+ if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
+ has_hallucination = true;
+ break;
}
- } else {
- vec![]
- };
- if hallucination_state.append_and_check_token_hallucination(content.clone(), logprobs) {
- has_hallucination = true;
- break;
- }
-
- if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
- let collected_content = hallucination_state.tokens.join("");
- has_tool_calls = Some(collected_content.contains("tool_calls"));
+ if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
+ let collected_content = hallucination_state.tokens.join("");
+ has_tool_calls = Some(collected_content.contains("tool_calls"));
+ }
}
}
}
@@ -950,21 +984,9 @@ impl ArchFunctionHandler {
info!("[Hallucination]: {}", hallucination_state.error_message);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
- let clarify_openai_messages = self.convert_to_openai_messages(&clarify_messages)?;
+ let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
- let clarify_request = CreateChatCompletionRequestArgs::default()
- .model(&self.model_name)
- .messages(clarify_openai_messages)
- .stream(false)
- .temperature(self.config.generation_params.temperature)
- .build()
- .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build clarify request: {}", e)))?;
-
- let retry_response = self.openai_client
- .chat()
- .create(clarify_request)
- .await
- .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Clarify request failed: {}", e)))?;
+ let retry_response = self.make_non_streaming_request(clarify_request).await?;
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
@@ -975,11 +997,15 @@ impl ArchFunctionHandler {
model_response = hallucination_state.tokens.join("");
}
} else {
- while let Some(chunk) = stream.next().await {
- let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?;
- if let Some(choice) = chunk.choices.first() {
- if let Some(content) = &choice.delta.content {
- model_response.push_str(content);
+ while let Some(chunk_result) = stream.next().await {
+ let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
+ if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
+ if let Some(choice) = choices.first() {
+ if let Some(content) = choice.get("delta")
+ .and_then(|d| d.get("content"))
+ .and_then(|c| c.as_str()) {
+ model_response.push_str(content);
+ }
}
}
}
@@ -1342,6 +1368,17 @@ mod tests {
assert!(config.format_prompt.contains("JSON formats"));
assert_eq!(config.generation_params.temperature, 0.1);
assert_eq!(config.support_data_types.len(), 14); // 8 Python-style + 6 JSON Schema names
+
+ // Verify exact prompt formatting matches Python
+ // Task prompt should have actual newlines, not escaped strings
+ assert!(config.task_prompt.contains("\n\nYou are provided"));
+ assert!(config.task_prompt.contains("\n\n"));
+
+ // Format prompt should have actual newlines and proper JSON with escaped quotes
+ assert!(config.format_prompt.contains("\n\nBased on your analysis"));
+ assert!(config.format_prompt.contains(r#"{"response": "Your response text here"}"#));
+ assert!(config.format_prompt.contains(r#"{"tool_calls": [{"#));
+
}
#[test]
diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs
index 265ee5ba..87bdea36 100644
--- a/crates/brightstaff/src/main.rs
+++ b/crates/brightstaff/src/main.rs
@@ -146,7 +146,7 @@ async fn main() -> Result<(), Box> {
(&Method::POST, "/function_calling") => {
let fully_qualified_url =
- format!("{}{}", llm_provider_url, "/v1");
+ format!("{}{}", llm_provider_url, "/v1/chat/completions");
function_calling_chat_handler(req, fully_qualified_url)
.with_context(parent_cx)
.await
diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs
index 90a180ba..44b64485 100644
--- a/crates/hermesllm/src/apis/openai.rs
+++ b/crates/hermesllm/src/apis/openai.rs
@@ -101,6 +101,12 @@ pub struct ChatCompletionsRequest {
pub top_logprobs: Option,
pub user: Option,
// pub web_search: Option, // GOOD FIRST ISSUE: Future support for web search
+
+ // VLLM-specific parameters (used by Arch-Function)
+ pub top_k: Option,
+ pub stop_token_ids: Option>,
+ pub continue_final_message: Option,
+ pub add_generation_prompt: Option,
}
impl ChatCompletionsRequest {