diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index ead0a351..6d63f1ed 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -304,6 +304,14 @@ def validate_and_render_schema(): } ) + # Always add arch-function model provider + updated_model_providers.append( + { + "name": "arch-function", + "provider_interface": "arch", + "model": "Arch-Function", + } + ) config_yaml["model_providers"] = deepcopy(updated_model_providers) listeners_with_provider = 0 diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 0115151e..5797d5a2 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -78,6 +78,43 @@ dependencies = [ "serde_json", ] +[[package]] +name = "async-openai" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bf39a15c8d613eb61892dc9a287c02277639ebead41ee611ad23aaa613f1a82" +dependencies = [ + "async-openai-macros", + "backoff", + "base64 0.22.1", + "bytes", + "derive_builder", + "eventsource-stream", + "futures", + "rand 0.9.2", + "reqwest", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + +[[package]] +name = "async-openai-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -130,6 +167,20 @@ dependencies = [ "time", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.16", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -201,7 +252,9 @@ dependencies = [ name = "brightstaff" version = "0.1.0" dependencies = [ + "async-openai", "bytes", + "chrono", "common", "eventsource-client", "eventsource-stream", @@ -219,6 +272,7 @@ dependencies = [ "opentelemetry-stdout", "opentelemetry_sdk", "pretty_assertions", + "rand 0.9.2", "reqwest", "serde", "serde_json", @@ -231,6 +285,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "uuid", ] [[package]] @@ -281,6 +336,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.41" @@ -289,8 +350,10 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", + "js-sys", "num-traits", "serde", + "wasm-bindgen", "windows-link", ] @@ -336,6 +399,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -426,6 +499,37 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.101", +] + [[package]] name = "diff" version = "0.1.13" @@ -650,6 +754,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -685,8 +795,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -696,9 +808,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasi 0.14.2+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -934,7 +1048,7 @@ dependencies = [ "hyper 0.14.32", "log", "rustls 0.21.12", - "rustls-native-certs", + "rustls-native-certs 0.6.3", "tokio", "tokio-rustls 0.24.1", ] @@ -949,6 +1063,7 @@ dependencies = [ "hyper 1.6.0", "hyper-util", "rustls 0.23.27", + "rustls-native-certs 0.8.2", "rustls-pki-types", "tokio", "tokio-rustls 0.26.2", @@ -1181,6 +1296,15 @@ dependencies = [ "serde", ] +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -1285,6 +1409,12 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "matchers" version = "0.1.0" @@ -1312,6 +1442,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1354,7 +1494,7 @@ dependencies = [ "hyper 1.6.0", "hyper-util", "log", - "rand 0.9.1", + "rand 0.9.2", "regex", "serde_json", "serde_urlencoded", @@ -1374,7 +1514,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework", + "security-framework 2.11.1", "security-framework-sys", "tempfile", ] @@ -1581,7 +1721,7 @@ dependencies = [ "glob", "opentelemetry", "percent-encoding", - "rand 0.9.1", + "rand 0.9.2", "serde_json", "thiserror 2.0.12", "tracing", @@ -1770,6 +1910,61 @@ dependencies = [ "log", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash 2.1.1", + "rustls 0.23.27", + "socket2", + "thiserror 2.0.12", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.3", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash 2.1.1", + "rustls 0.23.27", + "rustls-pki-types", + "slab", + "thiserror 2.0.12", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.59.0", +] + [[package]] name = "quote" version = "1.0.40" @@ -1798,9 +1993,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", @@ -1941,10 +2136,14 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", "pin-project-lite", + "quinn", + "rustls 0.23.27", + "rustls-native-certs 0.8.2", "rustls-pki-types", "serde", "serde_json", @@ -1952,6 +2151,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-rustls 0.26.2", "tokio-util", "tower 0.5.2", "tower-http", @@ -1963,6 +2163,22 @@ dependencies = [ "web-sys", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror 1.0.69", +] + [[package]] name = "ring" version = "0.17.14" @@ -1989,6 +2205,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustix" version = "1.0.7" @@ -2021,6 +2243,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ "once_cell", + "ring", "rustls-pki-types", "rustls-webpki 0.103.3", "subtle", @@ -2036,7 +2259,19 @@ dependencies = [ "openssl-probe", "rustls-pemfile", "schannel", - "security-framework", + "security-framework 2.11.1", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9980d917ebb0c0536119ba501e90834767bffc3d60641457fd84a1f3fd337923" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework 3.5.1", ] [[package]] @@ -2054,6 +2289,7 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" dependencies = [ + "web-time", "zeroize", ] @@ -2142,6 +2378,16 @@ version = "3.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21" +[[package]] +name = "secrecy" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -2149,7 +2395,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" +dependencies = [ + "bitflags", + "core-foundation 0.10.1", "core-foundation-sys", "libc", "security-framework-sys", @@ -2420,7 +2679,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -2509,7 +2768,7 @@ dependencies = [ "fancy-regex", "lazy_static", "parking_lot", - "rustc-hash", + "rustc-hash 1.1.0", ] [[package]] @@ -2553,6 +2812,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.45.1" @@ -2829,6 +3103,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -2870,6 +3150,18 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "uuid" +version = "1.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +dependencies = [ + "getrandom 0.3.3", + "js-sys", + "serde", + "wasm-bindgen", +] + [[package]] name = "valuable" version = "0.1.1" diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index d424b0e6..3dfd1abe 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -4,7 +4,9 @@ version = "0.1.0" edition = "2021" [dependencies] +async-openai = "0.30.1" bytes = "1.10.1" +chrono = "0.4" common = { version = "0.1.0", path = "../common" } eventsource-client = "0.15.0" eventsource-stream = "0.2.3" @@ -21,6 +23,7 @@ opentelemetry-otlp = {version="0.29.0", features=["trace", "tonic", "grpc-tonic" opentelemetry-stdout = "0.29.0" opentelemetry_sdk = "0.29.0" pretty_assertions = "1.4.1" +rand = "0.9.2" reqwest = { version = "0.12.15", features = ["stream"] } serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" @@ -32,6 +35,7 @@ tokio-stream = "0.1" time = { version = "0.3", features = ["formatting", "macros"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +uuid = { version = "1.0", features = ["v4", "serde"] } [dev-dependencies] mockito = "1.0" diff --git a/crates/brightstaff/src/handlers/function_calling.rs b/crates/brightstaff/src/handlers/function_calling.rs new file mode 100644 index 00000000..6c853ad4 --- /dev/null +++ b/crates/brightstaff/src/handlers/function_calling.rs @@ -0,0 +1,1888 @@ +use hermesllm::apis::openai::{ + ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message, + MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use thiserror::Error; +use tracing::{info, error}; +use async_openai::{Client as OpenAIClient, config::OpenAIConfig}; +use async_openai::types::{ + CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, +}; +use futures::StreamExt; +use bytes::Bytes; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use hyper::body::Incoming; +use hyper::{Request, Response, StatusCode}; + +// ============================================================================ +// CONSTANTS FOR HALLUCINATION DETECTION +// ============================================================================ + +const FUNC_NAME_START_PATTERN: &[&str] = &[r#"{"name":""#, r#"{'name':'"#]; +const FUNC_NAME_END_TOKEN: &[&str] = &["\",", "',"]; +const END_TOOL_CALL_TOKEN: &str = "}}"; + +const FIRST_PARAM_NAME_START_PATTERN: &[&str] = &[r#""arguments":{"#, r#"'arguments':{'"#]; +const PARAMETER_NAME_END_TOKENS: &[&str] = &["\":", ":\"", "':", ":'", "\":\"", "':'"]; +const PARAMETER_NAME_START_PATTERN: &[&str] = &["\",\"", "','"]; +const PARAMETER_VALUE_START_PATTERN: &[&str] = &["\":", "':"]; +const PARAMETER_VALUE_END_TOKEN: &[&str] = &["\",", "\"}"]; + +/// Default hallucination detection thresholds +#[derive(Debug, Clone)] +pub struct HallucinationThresholds { + pub entropy: f64, + pub varentropy: f64, + pub probability: f64, +} + +impl Default for HallucinationThresholds { + fn default() -> Self { + Self { + entropy: 0.0001, + varentropy: 0.0001, + probability: 0.8, + } + } +} + +// ============================================================================ +// ERROR TYPES +// ============================================================================ + +#[derive(Debug, Error)] +pub enum FunctionCallingError { + #[error("Failed to parse JSON: {0}")] + JsonParseError(#[from] serde_json::Error), + + #[error("Failed to fix malformed JSON: {0}")] + JsonFixError(String), + + #[error("Invalid model response: {0}")] + InvalidModelResponse(String), + + #[error("Tool call verification failed: {0}")] + ToolCallVerificationError(String), + + #[error("Data type conversion error: {0}")] + DataTypeConversionError(String), + + #[error("Unsupported data type: {0}")] + UnsupportedDataType(String), + + #[error("HTTP request error: {0}")] + HttpError(#[from] reqwest::Error), + + #[error("Invalid tool call: {0}")] + InvalidToolCall(String), +} + +pub type Result = std::result::Result; + +// ============================================================================ +// CONFIGURATION STRUCTURES +// ============================================================================ + +/// Configuration for Arch Function Calling +#[derive(Debug, Clone)] +pub struct ArchFunctionConfig { + pub task_prompt: String, + pub format_prompt: String, + pub generation_params: GenerationParams, + pub support_data_types: Vec, +} + +impl Default for ArchFunctionConfig { + fn default() -> Self { + Self { + task_prompt: String::from( + "You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.\ + \n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\ + \n\nYour task is to decide which functions are needed and collect missing parameters if necessary." + ), + format_prompt: String::from( + "\n\nBased on your analysis, provide your response in one of the following JSON formats:\ + \n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```\ + \n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```\ + \n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```" + ), + generation_params: GenerationParams::default(), + support_data_types: vec![ + "int".to_string(), + "float".to_string(), + "bool".to_string(), + "str".to_string(), + "list".to_string(), + "tuple".to_string(), + "set".to_string(), + "dict".to_string(), + // JSON Schema names (standard) + "integer".to_string(), + "number".to_string(), + "boolean".to_string(), + "string".to_string(), + "array".to_string(), + "object".to_string(), + ], + } + } +} + +/// Configuration for Arch Agent (extends ArchFunctionConfig with different generation params) +#[derive(Debug, Clone)] +pub struct ArchAgentConfig { + pub task_prompt: String, + pub format_prompt: String, + pub generation_params: GenerationParams, + pub support_data_types: Vec, +} + +impl Default for ArchAgentConfig { + fn default() -> Self { + let base = ArchFunctionConfig::default(); + Self { + task_prompt: base.task_prompt, + format_prompt: base.format_prompt, + generation_params: GenerationParams { + temperature: 0.01, + top_p: 1.0, + top_k: 10, + max_tokens: 1024, + stop_token_ids: vec![151645], + logprobs: Some(true), + top_logprobs: Some(10), + }, + support_data_types: base.support_data_types, + } + } +} + +/// Generation parameters for LLM +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GenerationParams { + pub temperature: f32, + pub top_p: f32, + pub top_k: i32, + pub max_tokens: u32, + pub stop_token_ids: Vec, + pub logprobs: Option, + pub top_logprobs: Option, +} + +impl Default for GenerationParams { + fn default() -> Self { + Self { + temperature: 0.1, + top_p: 1.0, + top_k: 10, + max_tokens: 1024, + stop_token_ids: vec![151645], + logprobs: Some(true), + top_logprobs: Some(10), + } + } +} + +// ============================================================================ +// PARSED MODEL RESPONSE +// ============================================================================ + +/// Parsed response from the model +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ParsedModelResponse { + pub raw_response: String, + pub response: Option, + pub required_functions: Vec, + pub clarification: String, + pub tool_calls: Vec, + pub is_valid: bool, + pub error_message: String, +} + +// ============================================================================ +// TOOL CALL VERIFICATION RESULT +// ============================================================================ + +/// Result of tool call verification +#[derive(Debug, Clone)] +pub struct ToolCallVerification { + pub is_valid: bool, + pub invalid_tool_call: Option, + pub error_message: String, +} + +impl Default for ToolCallVerification { + fn default() -> Self { + Self { + is_valid: true, + invalid_tool_call: None, + error_message: String::new(), + } + } +} + +/// Main handler for Arch Function Calling +pub struct ArchFunctionHandler { + pub model_name: String, + pub config: ArchFunctionConfig, + pub default_prefix: String, + pub clarify_prefix: String, + pub openai_client: OpenAIClient, +} + +impl ArchFunctionHandler { + /// Creates a new ArchFunctionHandler + pub fn new(model_name: String, config: ArchFunctionConfig, endpoint_url: String) -> Self { + use common::consts::ARCH_PROVIDER_HINT_HEADER; + use reqwest::header; + + // Create custom HTTP client with Arch provider hint header + let mut headers = header::HeaderMap::new(); + headers.insert( + header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), + header::HeaderValue::from_str(&model_name).unwrap(), + ); + + let http_client = reqwest::ClientBuilder::new() + .default_headers(headers) + .build() + .expect("Failed to create HTTP client"); + + // Configure OpenAI client to use custom endpoint + let openai_config = OpenAIConfig::new() + .with_api_base(endpoint_url); + + Self { + model_name, + config, + default_prefix: "```json\n{\"".to_string(), + clarify_prefix: "```json\n{\"required_functions\":".to_string(), + openai_client: OpenAIClient::with_config(openai_config).with_http_client(http_client), + } + } + + /// Converts a list of tools into JSON format string + pub fn convert_tools(&self, tools: &[Tool]) -> Result { + let converted: std::result::Result, serde_json::Error> = tools + .iter() + .map(|tool| serde_json::to_string(&tool.function)) + .collect(); + + converted + .map(|v| v.join("\n")) + .map_err(FunctionCallingError::from) + } + + /// Fixes malformed JSON strings by ensuring proper bracket matching + pub fn fix_json_string(&self, json_str: &str) -> Result { + let json_str = json_str.trim(); + let mut stack: Vec = Vec::new(); + let mut fixed_str = String::new(); + + let matching_bracket: HashMap = + [(')', '('), ('}', '{'), (']', '[')] + .iter() + .cloned() + .collect(); + + let opening_bracket: HashMap = matching_bracket + .iter() + .map(|(k, v)| (*v, *k)) + .collect(); + + for ch in json_str.chars() { + if ch == '{' || ch == '[' || ch == '(' { + stack.push(ch); + fixed_str.push(ch); + } else if ch == '}' || ch == ']' || ch == ')' { + if let Some(&last) = stack.last() { + if matching_bracket.get(&ch) == Some(&last) { + stack.pop(); + fixed_str.push(ch); + } + // Ignore unmatched closing brackets + } + } else { + fixed_str.push(ch); + } + } + + // Add corresponding closing brackets for unmatched opening brackets + while let Some(unmatched_opening) = stack.pop() { + if let Some(&closing) = opening_bracket.get(&unmatched_opening) { + fixed_str.push(closing); + } + } + + // Try to parse the fixed JSON + match serde_json::from_str::(&fixed_str) { + Ok(val) => serde_json::to_string(&val).map_err(FunctionCallingError::from), + Err(_) => { + // Try replacing single quotes with double quotes + let fixed_str = fixed_str.replace('\'', "\""); + match serde_json::from_str::(&fixed_str) { + Ok(val) => serde_json::to_string(&val).map_err(FunctionCallingError::from), + Err(e) => Err(FunctionCallingError::JsonFixError(format!( + "Failed to fix JSON: {}", + e + ))), + } + } + } + } + + /// Parses the model response and extracts tool call information + pub fn parse_model_response(&self, content: &str) -> ParsedModelResponse { + let mut response_dict = ParsedModelResponse::default(); + + // Remove markdown code blocks + let mut content = content.trim().to_string(); + if content.starts_with("```") && content.ends_with("```") { + content = content.trim_start_matches("```").trim_end_matches("```").trim().to_string(); + if content.starts_with("json") { + content = content.trim_start_matches("json").trim().to_string(); + } + } + + // Try to fix JSON if needed + let fixed_content = match self.fix_json_string(&content) { + Ok(fixed) => { + response_dict.raw_response = format!("```json\n{}\n```", fixed); + fixed + } + Err(e) => { + response_dict.is_valid = false; + response_dict.error_message = format!("Failed to fix JSON: {}", e); + return response_dict; + } + }; + + // Parse the JSON + match serde_json::from_str::(&fixed_content) { + Ok(model_response) => { + // Successfully parsed - mark as valid + response_dict.is_valid = true; + + // Extract response field + if let Some(resp) = model_response.get("response") { + if let Some(resp_str) = resp.as_str() { + response_dict.response = Some(resp_str.to_string()); + } + } + + // Extract required_functions + if let Some(funcs) = model_response.get("required_functions") { + if let Some(funcs_arr) = funcs.as_array() { + response_dict.required_functions = funcs_arr + .iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect(); + } + } + + // Extract clarification + if let Some(clarif) = model_response.get("clarification") { + if let Some(clarif_str) = clarif.as_str() { + response_dict.clarification = clarif_str.to_string(); + } + } + + // Extract tool_calls + if let Some(tool_calls) = model_response.get("tool_calls") { + if let Some(tool_calls_arr) = tool_calls.as_array() { + for tool_call_val in tool_calls_arr { + let id = format!("call_{}", rand::random::() % 10000 + 1000); + + let name = tool_call_val + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let arguments = tool_call_val + .get("arguments") + .map(|v| serde_json::to_string(v).unwrap_or_default()) + .unwrap_or_default(); + + response_dict.tool_calls.push(ToolCall { + id, + call_type: "function".to_string(), + function: FunctionCall { name, arguments }, + }); + } + } + } + } + Err(e) => { + response_dict.is_valid = false; + response_dict.error_message = format!("Failed to parse model response: {}", e); + } + } + + response_dict + } + + /// Converts data type from one type to another + pub fn convert_data_type(&self, value: &Value, target_type: &str) -> Result { + match target_type { + // Handle float/number conversions + "float" | "number" => { + if let Some(int_val) = value.as_i64() { + return Ok(json!(int_val as f64)); + } + } + // Handle list/array conversions + "list" | "array" => { + if let Some(str_val) = value.as_str() { + // Try to parse as JSON array + if let Ok(arr) = serde_json::from_str::>(str_val) { + return Ok(json!(arr)); + } + } + } + // Handle str/string conversions + "str" | "string" => { + if !value.is_string() { + return Ok(json!(value.to_string())); + } + } + _ => {} + } + Ok(value.clone()) + } + + /// Helper method to check if a value matches the expected type + fn check_value_type(&self, value: &Value, target_type: &str) -> bool { + match target_type { + "int" | "integer" => value.is_i64() || value.is_u64(), + "float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(), + "bool" | "boolean" => value.is_boolean(), + "str" | "string" => value.is_string(), + "list" | "array" => value.is_array(), + "dict" | "object" => value.is_object(), + _ => true, + } + } + + /// Helper method to validate and potentially convert a parameter value to match the target type + /// Returns Ok(true) if the value is valid (either originally or after conversion) + /// Returns Ok(false) if the value cannot be converted to the target type + fn validate_or_convert_parameter( + &self, + param_value: &Value, + target_type: &str, + ) -> Result { + // First check: Is it already the correct type? + if self.check_value_type(param_value, target_type) { + return Ok(true); + } + + // Try to convert + let converted = self.convert_data_type(param_value, target_type)?; + + // Second check: Is it the correct type after conversion? + Ok(self.check_value_type(&converted, target_type)) + } + + /// Verifies the validity of extracted tool calls against the provided tools + pub fn verify_tool_calls( + &self, + tools: &[Tool], + tool_calls: &[ToolCall], + ) -> ToolCallVerification { + let mut verification = ToolCallVerification::default(); + + // Build a map of function name to parameters + let mut functions: HashMap = HashMap::new(); + for tool in tools { + functions.insert(tool.function.name.clone(), &tool.function.parameters); + } + + for tool_call in tool_calls { + if !verification.is_valid { + break; + } + + let func_name = &tool_call.function.name; + + // Parse arguments as JSON + let func_args: HashMap = match serde_json::from_str(&tool_call.function.arguments) { + Ok(args) => args, + Err(e) => { + verification.is_valid = false; + verification.invalid_tool_call = Some(tool_call.clone()); + verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e); + break; + } + }; + + // Check if function is available + if let Some(function_params) = functions.get(func_name) { + // Check if all required parameters are present + if let Some(required) = function_params.get("required") { + if let Some(required_arr) = required.as_array() { + for required_param in required_arr { + if let Some(param_name) = required_param.as_str() { + if !func_args.contains_key(param_name) { + verification.is_valid = false; + verification.invalid_tool_call = Some(tool_call.clone()); + verification.error_message = format!( + "`{}` is required by the function `{}` but not found in the tool call!", + param_name, func_name + ); + break; + } + } + } + } + } + + // Verify the data type of each parameter + if let Some(properties) = function_params.get("properties") { + if let Some(properties_obj) = properties.as_object() { + for (param_name, param_value) in &func_args { + if let Some(param_schema) = properties_obj.get(param_name) { + if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) { + if self.config.support_data_types.contains(&target_type.to_string()) { + // Validate data type using helper method + match self.validate_or_convert_parameter(param_value, target_type) { + Ok(is_valid) => { + if !is_valid { + verification.is_valid = false; + verification.invalid_tool_call = Some(tool_call.clone()); + verification.error_message = format!( + "Parameter `{}` is expected to have the data type `{}`, got incompatible type.", + param_name, target_type + ); + break; + } + } + Err(_) => { + verification.is_valid = false; + verification.invalid_tool_call = Some(tool_call.clone()); + verification.error_message = format!( + "Parameter `{}` is expected to have the data type `{}`, got incompatible type.", + param_name, target_type + ); + break; + } + } + } else { + verification.is_valid = false; + verification.invalid_tool_call = Some(tool_call.clone()); + verification.error_message = format!("Data type `{}` is not supported.", target_type); + break; + } + } + } else { + verification.is_valid = false; + verification.invalid_tool_call = Some(tool_call.clone()); + verification.error_message = format!( + "Parameter `{}` is not defined in the function `{}`.", + param_name, func_name + ); + break; + } + } + } + } + } else { + verification.is_valid = false; + verification.invalid_tool_call = Some(tool_call.clone()); + verification.error_message = format!("{} is not available!", func_name); + } + } + + verification + } + + /// Formats the system prompt with tools + pub fn format_system_prompt(&self, tools: &[Tool]) -> Result { + let tools_str = self.convert_tools(tools)?; + let today_date = chrono::Local::now().format("%Y-%m-%d").to_string(); + + let system_prompt = self + .config + .task_prompt + .replace("{today_date}", &today_date) + .replace("{tools}", &tools_str) + + &self.config.format_prompt; + + Ok(system_prompt) + } + + /// Processes messages and formats them appropriately for the model + pub fn process_messages( + &self, + messages: &[Message], + tools: Option<&[Tool]>, + extra_instruction: Option<&str>, + max_tokens: usize, + metadata: Option<&HashMap>, + ) -> Result> { + let mut processed_messages = Vec::new(); + + // Add system message with tools if provided + if let Some(tools) = tools { + let system_prompt = self.format_system_prompt(tools)?; + processed_messages.push(Message { + role: Role::System, + content: MessageContent::Text(system_prompt), + name: None, + tool_calls: None, + tool_call_id: None, + }); + } + + // Process each message + for (idx, message) in messages.iter().enumerate() { + let mut role = message.role.clone(); + let mut content = match &message.content { + MessageContent::Text(text) => text.clone(), + MessageContent::Parts(_) => String::new(), + }; + + // Handle tool calls + if let Some(tool_calls) = &message.tool_calls { + if !tool_calls.is_empty() { + role = Role::Assistant; + let tool_call_json = serde_json::to_string(&tool_calls[0].function)?; + content = format!("\n{}\n", tool_call_json); + } + } else if role == Role::Tool { + role = Role::User; + + // Check if we should optimize context window + let optimize_context = metadata + .and_then(|m| m.get("optimize_context_window")) + .and_then(|v| v.as_str()) + .map(|s| s.to_lowercase() == "true") + .unwrap_or(false); + + if optimize_context { + content = "\n\n".to_string(); + } else { + // Get the tool call from previous message + if idx > 0 { + if let MessageContent::Text(prev_content) = &messages[idx - 1].content { + let mut tool_call_msg = prev_content.clone(); + + // Strip markdown code blocks + if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") { + tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string(); + if tool_call_msg.starts_with("json") { + tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string(); + } + } + + // Extract function name + if let Ok(parsed) = serde_json::from_str::(&tool_call_msg) { + if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) { + if let Some(first_tool_call) = tool_calls_arr.first() { + let func_name = first_tool_call + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("no_name"); + + let tool_response = json!({ + "name": func_name, + "result": content, + }); + + content = format!("\n{}\n", + serde_json::to_string(&tool_response)?); + } + } + } + } + } + } + } + + processed_messages.push(Message { + role, + content: MessageContent::Text(content), + name: message.name.clone(), + tool_calls: None, + tool_call_id: None, + }); + } + + // Ensure last message is from user + if let Some(last) = processed_messages.last() { + if last.role != Role::User { + return Err(FunctionCallingError::InvalidModelResponse( + "Last message must be from user".to_string(), + )); + } + } + + // Add extra instruction if provided + if let Some(instruction) = extra_instruction { + if let Some(last) = processed_messages.last_mut() { + if let MessageContent::Text(content) = &mut last.content { + content.push_str("\n"); + content.push_str(instruction); + } + } + } + + // Truncate messages if they exceed max_tokens + let processed_messages = self.truncate_messages(processed_messages, max_tokens); + + Ok(processed_messages) + } + + /// Truncates messages to fit within max_tokens limit + fn truncate_messages(&self, messages: Vec, max_tokens: usize) -> Vec { + let mut num_tokens = 0; + let mut conversation_idx = 0; + + // Keep system message if present + if let Some(first) = messages.first() { + if first.role == Role::System { + if let MessageContent::Text(content) = &first.content { + num_tokens += content.len() / 4; // Approximate 4 chars per token + } + conversation_idx = 1; + } + } + + // Calculate from the end backwards + let mut message_idx = messages.len(); + for i in (conversation_idx..messages.len()).rev() { + if let MessageContent::Text(content) = &messages[i].content { + num_tokens += content.len() / 4; + if num_tokens >= max_tokens { + if messages[i].role == Role::User { + break; + } + } + } + message_idx = i; + } + + // Return system message + truncated conversation + let mut result = Vec::new(); + if conversation_idx > 0 { + result.push(messages[0].clone()); + } + result.extend_from_slice(&messages[message_idx..]); + + result + } + + /// Prefills a message by adding an assistant message with the prefix + pub fn prefill_message(&self, mut messages: Vec, prefill: &str) -> Vec { + messages.push(Message { + role: Role::Assistant, + content: MessageContent::Text(prefill.to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }); + messages + } + + /// Converts internal Message format to async-openai's ChatCompletionRequestMessage format + fn convert_to_openai_messages(&self, messages: &[Message]) -> Result> { + let mut openai_messages = Vec::new(); + + for message in messages { + let content_str = match &message.content { + MessageContent::Text(text) => text.clone(), + MessageContent::Parts(_) => String::new(), // Handle parts if needed + }; + + let openai_message = match message.role { + Role::System => { + ChatCompletionRequestMessage::System( + ChatCompletionRequestSystemMessage { + content: content_str.into(), + name: message.name.clone(), + } + ) + }, + Role::User | Role::Tool => { + // Convert both user and tool roles to user messages + ChatCompletionRequestMessage::User( + ChatCompletionRequestUserMessage { + content: content_str.into(), + name: message.name.clone(), + } + ) + }, + Role::Assistant => { + #[allow(deprecated)] + let msg = ChatCompletionRequestAssistantMessage { + content: Some(content_str.into()), + name: message.name.clone(), + tool_calls: None, + refusal: None, + audio: None, + function_call: None, + }; + ChatCompletionRequestMessage::Assistant(msg) + }, + }; + + openai_messages.push(openai_message); + } + + Ok(openai_messages) + } + + pub async fn function_calling_chat( + &self, + request: ChatCompletionsRequest, + ) -> Result { + use tracing::{info, error}; + + info!("[Arch-Function] - ChatCompletion"); + + let messages = self.process_messages( + &request.messages, + request.tools.as_deref(), + None, + self.config.generation_params.max_tokens as usize, + request.metadata.as_ref(), + )?; + + info!("[request to arch-fc]: model: {}, messages count: {}", + self.model_name, messages.len()); + + let use_agent_orchestrator = request.metadata + .as_ref() + .and_then(|m| m.get("use_agent_orchestrator")) + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + let prefilled_messages = self.prefill_message(messages.clone(), &self.default_prefix); + let openai_messages = self.convert_to_openai_messages(&prefilled_messages)?; + + let mut request_args = CreateChatCompletionRequestArgs::default(); + request_args + .model(&self.model_name) + .messages(openai_messages) + .stream(true) + .temperature(self.config.generation_params.temperature) + .top_p(self.config.generation_params.top_p) + .max_tokens(self.config.generation_params.max_tokens); + + if let Some(true) = self.config.generation_params.logprobs { + request_args.logprobs(true); + if let Some(top_logprobs) = self.config.generation_params.top_logprobs { + request_args.top_logprobs(top_logprobs as u8); + } + } + + let request_builder = request_args + .build() + .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build request: {}", e)))?; + + let mut stream = self.openai_client + .chat() + .create_stream(request_builder) + .await + .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream creation failed: {}", e)))?; + + let mut model_response = String::new(); + + if use_agent_orchestrator { + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?; + if let Some(choice) = chunk.choices.first() { + if let Some(content) = &choice.delta.content { + model_response.push_str(content); + } + } + } + info!("[Agent Orchestrator]: response received"); + } else { + if let Some(tools) = request.tools.as_ref() { + let mut hallucination_state = HallucinationState::new(tools); + let mut has_tool_calls = None; + let mut has_hallucination = false; + + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?; + if let Some(choice) = chunk.choices.first() { + if let Some(content) = &choice.delta.content { + let logprobs: Vec = if let Some(logprobs_data) = &choice.logprobs { + if let Some(content_vec) = &logprobs_data.content { + if let Some(token_logprob) = content_vec.first() { + token_logprob.top_logprobs + .iter() + .map(|top| top.logprob as f64) + .collect() + } else { + vec![] + } + } else { + vec![] + } + } else { + vec![] + }; + + if hallucination_state.append_and_check_token_hallucination(content.clone(), logprobs) { + has_hallucination = true; + break; + } + + if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() { + let collected_content = hallucination_state.tokens.join(""); + has_tool_calls = Some(collected_content.contains("tool_calls")); + } + } + } + } + + if has_tool_calls == Some(true) && has_hallucination { + info!("[Hallucination]: {}", hallucination_state.error_message); + + let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix); + let clarify_openai_messages = self.convert_to_openai_messages(&clarify_messages)?; + + let clarify_request = CreateChatCompletionRequestArgs::default() + .model(&self.model_name) + .messages(clarify_openai_messages) + .stream(false) + .temperature(self.config.generation_params.temperature) + .build() + .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to build clarify request: {}", e)))?; + + let retry_response = self.openai_client + .chat() + .create(clarify_request) + .await + .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Clarify request failed: {}", e)))?; + + if let Some(choice) = retry_response.choices.first() { + if let Some(content) = &choice.message.content { + model_response = content.clone(); + } + } + } else { + model_response = hallucination_state.tokens.join(""); + } + } else { + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Stream error: {}", e)))?; + if let Some(choice) = chunk.choices.first() { + if let Some(content) = &choice.delta.content { + model_response.push_str(content); + } + } + } + } + } + + let response_dict = self.parse_model_response(&model_response); + + info!("[arch-fc]: raw model response: {}", response_dict.raw_response); + + let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) { + ResponseMessage { + role: Role::Assistant, + content: response_dict.response.clone(), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: None, + } + } else if !response_dict.required_functions.is_empty() { + if !use_agent_orchestrator { + ResponseMessage { + role: Role::Assistant, + content: Some(response_dict.clarification.clone()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: None, + } + } else { + ResponseMessage { + role: Role::Assistant, + content: Some(String::new()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: None, + } + } + } else if !response_dict.tool_calls.is_empty() { + if response_dict.is_valid { + if !use_agent_orchestrator { + if let Some(tools) = request.tools.as_ref() { + let verification = self.verify_tool_calls(tools, &response_dict.tool_calls); + + if verification.is_valid { + info!("[Tool calls]: {:?}", + response_dict.tool_calls.iter() + .map(|tc| &tc.function) + .collect::>() + ); + ResponseMessage { + role: Role::Assistant, + content: Some(String::new()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: Some(response_dict.tool_calls.clone()), + } + } else { + error!("Invalid tool call - {}", verification.error_message); + ResponseMessage { + role: Role::Assistant, + content: Some(String::new()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: None, + } + } + } else { + error!("Tool calls present but no tools provided in request"); + ResponseMessage { + role: Role::Assistant, + content: Some(String::new()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: None, + } + } + } else { + info!("[Tool calls]: {:?}", + response_dict.tool_calls.iter() + .map(|tc| &tc.function) + .collect::>() + ); + ResponseMessage { + role: Role::Assistant, + content: Some(String::new()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: Some(response_dict.tool_calls.clone()), + } + } + } else { + error!("Invalid tool calls in response: {}", response_dict.error_message); + ResponseMessage { + role: Role::Assistant, + content: Some(String::new()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: None, + } + } + } else { + error!("Invalid model response - {}", model_response); + ResponseMessage { + role: Role::Assistant, + content: Some(String::new()), + refusal: None, + annotations: None, + audio: None, + function_call: None, + tool_calls: None, + } + }; + + let chat_completion_response = ChatCompletionsResponse { + id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), + object: Some("chat.completion".to_string()), + created: chrono::Utc::now().timestamp() as u64, + model: request.model.clone(), + choices: vec![Choice { + index: 0, + message: model_message, + finish_reason: Some(FinishReason::Stop), + logprobs: None, + }], + usage: Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + prompt_tokens_details: None, + completion_tokens_details: None, + }, + system_fingerprint: None, + service_tier: None, + metadata: None, + }; + + info!("[response arch-fc]: {:?}", chat_completion_response); + + Ok(chat_completion_response) + } +} + +// ============================================================================ +// ARCH AGENT HANDLER +// ============================================================================ + +/// Handler for Arch Agent (extends ArchFunctionHandler with specialized behavior) +pub struct ArchAgentHandler { + pub function_handler: ArchFunctionHandler, +} + +impl ArchAgentHandler { + /// Creates a new ArchAgentHandler + pub fn new(model_name: String, endpoint_url: String) -> Self { + let config = ArchAgentConfig::default(); + Self { + function_handler: ArchFunctionHandler::new( + model_name, + ArchFunctionConfig { + task_prompt: config.task_prompt, + format_prompt: config.format_prompt, + generation_params: GenerationParams { + temperature: config.generation_params.temperature, + top_p: config.generation_params.top_p, + top_k: config.generation_params.top_k, + max_tokens: config.generation_params.max_tokens, + stop_token_ids: config.generation_params.stop_token_ids, + logprobs: config.generation_params.logprobs, + top_logprobs: config.generation_params.top_logprobs, + }, + support_data_types: config.support_data_types, + }, + endpoint_url, + ), + } + } + + /// Converts tools with special handling for empty parameters + /// This is the key difference from ArchFunctionHandler + pub fn convert_tools(&self, tools: &[Tool]) -> Result { + let mut converted = Vec::new(); + + for tool in tools { + let mut tool_copy = tool.clone(); + + // Delete parameters key if its empty + if let Some(props) = tool_copy.function.parameters.get("properties") { + if props.is_object() && props.as_object().unwrap().is_empty() { + // Create new parameters without properties + if let Some(params_obj) = tool_copy.function.parameters.as_object_mut() { + params_obj.remove("properties"); + } + } + } + + converted.push(serde_json::to_string(&tool_copy.function)?); + } + + Ok(converted.join("\n")) + } +} + +// ============================================================================ +// HTTP HANDLER FOR FUNCTION CALLING ENDPOINT +// ============================================================================ + +fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + +pub async fn function_calling_chat_handler( + req: Request, + llm_provider_url: String, +) -> std::result::Result>, hyper::Error> { + + use hermesllm::apis::openai::ChatCompletionsRequest; + let whole_body = req.collect().await?.to_bytes(); + + // Parse as JSON Value first to modify it + let mut body_json: Value = match serde_json::from_slice(&whole_body) { + Ok(json) => json, + Err(e) => { + error!("Failed to parse request body as JSON: {}", e); + let mut response = Response::new(full( + serde_json::json!({ + "error": format!("Invalid request body: {}", e) + }).to_string() + )); + *response.status_mut() = StatusCode::BAD_REQUEST; + response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); + return Ok(response); + } + }; + + // Add "model": "Arch-Function" to the request + if let Some(obj) = body_json.as_object_mut() { + obj.insert("model".to_string(), json!("Arch-Function")); + } + + // Parse as ChatCompletionsRequest + let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) { + Ok(req) => { + info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default()); + req + }, + Err(e) => { + error!("Failed to parse request body: {}", e); + let mut response = Response::new(full( + serde_json::json!({ + "error": format!("Invalid request body: {}", e) + }).to_string() + )); + *response.status_mut() = StatusCode::BAD_REQUEST; + response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); + return Ok(response); + } + }; + + // Determine which handler to use based on metadata + let use_agent_orchestrator = chat_request.metadata + .as_ref() + .and_then(|m| m.get("use_agent_orchestrator")) + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + info!("Use agent orchestrator: {}", use_agent_orchestrator); + + // Create the appropriate handler + let handler_name = if use_agent_orchestrator { + "Arch-Agent" + } else { + "Arch-Function" + }; + + // Call the handler + let final_response = if use_agent_orchestrator { + let handler = ArchAgentHandler::new( + chat_request.model.clone(), + llm_provider_url.clone(), + ); + handler.function_handler.function_calling_chat(chat_request).await + } else { + let handler = ArchFunctionHandler::new( + chat_request.model.clone(), + ArchFunctionConfig::default(), + llm_provider_url.clone(), + ); + handler.function_calling_chat(chat_request).await + }; + + match final_response { + Ok(response_data) => { + let response_json = serde_json::to_string(&response_data).unwrap_or_else(|e| { + error!("Failed to serialize response: {}", e); + serde_json::json!({"error": "Failed to serialize response"}).to_string() + }); + + let mut response = Response::new(full(response_json)); + *response.status_mut() = StatusCode::OK; + response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); + + Ok(response) + } + Err(e) => { + error!("[{}] - Error in function calling: {}", handler_name, e); + + let error_response = serde_json::json!({ + "error": format!("[{}] - Error in function calling: {}", handler_name, e) + }); + + let mut response = Response::new(full(error_response.to_string())); + *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); + Ok(response) + } + } +} + + +// ============================================================================ +// TESTS +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_arch_function_config_default() { + let config = ArchFunctionConfig::default(); + assert!(config.task_prompt.contains("helpful assistant")); + assert!(config.format_prompt.contains("JSON formats")); + assert_eq!(config.generation_params.temperature, 0.1); + assert_eq!(config.support_data_types.len(), 14); // 8 Python-style + 6 JSON Schema names + } + + #[test] + fn test_arch_agent_config_default() { + let config = ArchAgentConfig::default(); + assert_eq!(config.generation_params.temperature, 0.01); // Different from ArchFunctionConfig + } + + #[test] + fn test_fix_json_string_valid() { + let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); + let json_str = r#"{"name": "test", "value": 123}"#; + let result = handler.fix_json_string(json_str); + assert!(result.is_ok()); + } + + #[test] + fn test_fix_json_string_missing_bracket() { + let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); + let json_str = r#"{"name": "test", "value": 123"#; + let result = handler.fix_json_string(json_str); + assert!(result.is_ok()); + let fixed = result.unwrap(); + assert!(fixed.contains("}")); + } + + #[test] + fn test_parse_model_response_with_tool_calls() { + let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); + let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#; + let result = handler.parse_model_response(content); + + assert!(result.is_valid); + assert_eq!(result.tool_calls.len(), 1); + assert_eq!(result.tool_calls[0].function.name, "get_weather"); + } + + #[test] + fn test_parse_model_response_with_clarification() { + let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); + let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#; + let result = handler.parse_model_response(content); + + assert!(result.is_valid); + assert_eq!(result.required_functions.len(), 1); + assert_eq!(result.clarification, "What location?"); + } + + #[test] + fn test_convert_data_type_int_to_float() { + let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); + let value = json!(42); + let result = handler.convert_data_type(&value, "float"); + assert!(result.is_ok()); + assert!(result.unwrap().is_f64()); + } +} + +// ============================================================================ +// HALLUCINATION DETECTION MODULE +// ============================================================================ + +/// Mask token types for tracking parsing state +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MaskToken { + FunctionName, + ParameterValue, + ParameterName, + NotUsed, + ToolCall, +} + +/// Uncertainty metrics calculated from log probabilities +#[derive(Debug, Clone)] +pub struct UncertaintyMetrics { + pub entropy: f64, + pub varentropy: f64, + pub probability: f64, +} + +/// Calculates uncertainty metrics from log probabilities +/// +/// This is a simplified Rust implementation that avoids torch/tensor dependencies. +/// Uses basic statistical calculations instead of tensor operations. +pub fn calculate_uncertainty(log_probs: &[f64]) -> UncertaintyMetrics { + if log_probs.is_empty() { + return UncertaintyMetrics { + entropy: 0.0, + varentropy: 0.0, + probability: 0.0, + }; + } + + // Convert log probabilities to probabilities + let token_probs: Vec = log_probs.iter().map(|&lp| lp.exp()).collect(); + + // Calculate entropy: -sum(p * log(p)) / log(2) + let mut entropy = 0.0; + for i in 0..log_probs.len() { + entropy -= log_probs[i] * token_probs[i]; + } + entropy /= 2_f64.ln(); // Convert to bits + + // Calculate variance of entropy + let mut varentropy = 0.0; + for i in 0..log_probs.len() { + let diff = log_probs[i] / 2_f64.ln() + entropy; + varentropy += token_probs[i] * diff * diff; + } + + // Get the top probability + let probability = token_probs.first().copied().unwrap_or(0.0); + + UncertaintyMetrics { + entropy, + varentropy, + probability, + } +} + +/// Checks if uncertainty metrics exceed thresholds +pub fn check_threshold( + entropy: f64, + varentropy: f64, + thresholds: &HallucinationThresholds, +) -> bool { + entropy > thresholds.entropy && varentropy > thresholds.varentropy +} + +/// Checks if a parameter is required in the function description +pub fn is_parameter_required( + function_description: &Value, + parameter_name: &str, +) -> bool { + if let Some(required) = function_description.get("required") { + if let Some(required_arr) = required.as_array() { + return required_arr.iter().any(|v| v.as_str() == Some(parameter_name)); + } + } + false +} + +/// Checks if a parameter has a specific property +pub fn is_parameter_property( + function_description: &Value, + parameter_name: &str, + property_name: &str, +) -> bool { + if let Some(properties) = function_description.get("properties") { + if let Some(param_info) = properties.get(parameter_name) { + return param_info.get(property_name).is_some(); + } + } + false +} + +/// State for hallucination detection during streaming +/// +/// This is a simplified version of the Python HallucinationState that doesn't +/// require torch/tensor dependencies. It provides the core functionality needed +/// for detecting hallucinations during function calling. +#[derive(Debug)] +pub struct HallucinationState { + pub tokens: Vec, + pub logprobs: Vec>, + pub state: Option, + pub mask: Vec, + pub parameter_name_done: bool, + pub hallucination: bool, + pub error_message: String, + pub parameter_name: Vec, + pub token_probs_map: Vec<(String, f64, f64, f64)>, + pub function_properties: HashMap, + pub open_bracket: bool, + pub bracket: Option, + pub function_name: String, + pub check_parameter_name: HashMap, + pub thresholds: HallucinationThresholds, +} + +impl HallucinationState { + /// Creates a new HallucinationState with function definitions + pub fn new(functions: &[Tool]) -> Self { + let function_properties: HashMap = functions + .iter() + .map(|tool| { + ( + tool.function.name.clone(), + tool.function.parameters.clone(), + ) + }) + .collect(); + + Self { + tokens: Vec::new(), + logprobs: Vec::new(), + state: None, + mask: Vec::new(), + parameter_name_done: false, + hallucination: false, + error_message: String::new(), + parameter_name: Vec::new(), + token_probs_map: Vec::new(), + function_properties, + open_bracket: false, + bracket: None, + function_name: String::new(), + check_parameter_name: HashMap::new(), + thresholds: HallucinationThresholds::default(), + } + } + + /// Appends a token and checks for hallucination + pub fn append_and_check_token_hallucination( + &mut self, + token: String, + logprob: Vec, + ) -> bool { + self.tokens.push(token); + self.logprobs.push(logprob); + self.process_token(); + self.hallucination + } + + /// Resets internal parameters + fn reset_parameters(&mut self) { + self.state = None; + self.parameter_name_done = false; + self.hallucination = false; + self.error_message.clear(); + self.open_bracket = false; + self.bracket = None; + self.check_parameter_name.clear(); + } + + /// Processes the current token and updates state + fn process_token(&mut self) { + let content: String = self.tokens.join("").replace(' ', ""); + + // Handle end of tool call + if content.ends_with(END_TOOL_CALL_TOKEN) { + self.reset_parameters(); + } + + // Function name extraction logic + if self.state.as_deref() == Some("function_name") { + if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) { + self.mask.push(MaskToken::FunctionName); + } else { + self.state = None; + self.get_function_name(); + } + } + + // Check for function name start + if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) { + self.state = Some("function_name".to_string()); + } + + // Parameter name extraction logic + if self.state.as_deref() == Some("parameter_name") + && !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) { + self.mask.push(MaskToken::ParameterName); + } else if self.state.as_deref() == Some("parameter_name") + && PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) { + self.state = None; + self.parameter_name_done = true; + self.get_parameter_name(); + } else if self.parameter_name_done + && !self.open_bracket + && PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) { + self.state = Some("parameter_name".to_string()); + } + + // First parameter value start + if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) { + self.state = Some("parameter_name".to_string()); + } + + // Parameter value extraction logic + if self.state.as_deref() == Some("parameter_value") + && !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) { + + // Check for brackets + if let Some(last_token) = self.tokens.last() { + let open_brackets: Vec = last_token + .trim() + .chars() + .filter(|&c| c == '(' || c == '{' || c == '[') + .collect(); + + if !open_brackets.is_empty() { + self.open_bracket = true; + self.bracket = Some(open_brackets[0]); + } + + if self.open_bracket { + let closing = match self.bracket { + Some('(') => ')', + Some('{') => '}', + Some('[') => ']', + _ => '\0', + }; + if last_token.trim().contains(closing) { + self.open_bracket = false; + self.bracket = None; + } + } + + // Check if token has actual value content + let has_non_punct = last_token.trim().chars().any(|c| !c.is_ascii_punctuation()); + if has_non_punct && !last_token.trim().is_empty() { + self.mask.push(MaskToken::ParameterValue); + + // Check hallucination for required parameters without enum + if self.function_properties.contains_key(&self.function_name) { + if self.mask.len() > 1 + && self.mask[self.mask.len() - 2] != MaskToken::ParameterValue + && !self.parameter_name.is_empty() + { + let last_param = self.parameter_name[self.parameter_name.len() - 1].clone(); + if let Some(func_props) = self.function_properties.get(&self.function_name) { + if is_parameter_required(func_props, &last_param) + && !is_parameter_property(func_props, &last_param, "enum") + && !self.check_parameter_name.contains_key(&last_param) + { + self.check_logprob(); + self.check_parameter_name.insert(last_param, true); + } + } + } + } else if !self.function_name.is_empty() { + self.check_logprob(); + self.error_message = format!( + "Function name {} not found in function properties", + self.function_name + ); + } + } else { + self.mask.push(MaskToken::NotUsed); + } + } + } else if self.state.as_deref() == Some("parameter_value") + && !self.open_bracket + && PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) { + self.state = None; + } else if self.parameter_name_done + && PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) { + self.state = Some("parameter_value".to_string()); + } + + // Maintain consistency between tokens and mask + if self.mask.len() != self.tokens.len() { + self.mask.push(MaskToken::NotUsed); + } + } + + /// Checks log probability and detects hallucination + fn check_logprob(&mut self) { + if let Some(probs) = self.logprobs.last() { + let metrics = calculate_uncertainty(probs); + + if let Some(token) = self.tokens.last() { + self.token_probs_map.push(( + token.clone(), + metrics.entropy, + metrics.varentropy, + metrics.probability, + )); + + if check_threshold(metrics.entropy, metrics.varentropy, &self.thresholds) { + self.hallucination = true; + self.error_message = format!( + "token '{}' is uncertain. Generated response:\n{}", + token, + self.tokens.join("") + ); + } + } + } + } + + /// Counts consecutive tokens of a specific type in the mask + fn count_consecutive_token(&self, token_type: MaskToken) -> usize { + if self.mask.is_empty() || self.mask.last() != Some(&token_type) { + return 0; + } + + self.mask + .iter() + .rev() + .take_while(|&&t| t == token_type) + .count() + } + + /// Extracts the parameter name from recent tokens + fn get_parameter_name(&mut self) { + let p_len = self.count_consecutive_token(MaskToken::ParameterName); + if p_len > 0 && self.tokens.len() > 1 { + let start_idx = self.tokens.len().saturating_sub(p_len + 1); + let end_idx = self.tokens.len().saturating_sub(1); + let parameter_name: String = self.tokens[start_idx..end_idx].join(""); + self.parameter_name.push(parameter_name); + } + } + + /// Extracts the function name from recent tokens + fn get_function_name(&mut self) { + let f_len = self.count_consecutive_token(MaskToken::FunctionName); + if f_len > 0 && self.tokens.len() > 1 { + let start_idx = self.tokens.len().saturating_sub(f_len + 1); + let end_idx = self.tokens.len().saturating_sub(1); + self.function_name = self.tokens[start_idx..end_idx].join(""); + } + } +} + +#[cfg(test)] +mod hallucination_tests { + use super::*; + + #[test] + fn test_calculate_uncertainty() { + let log_probs = vec![-0.1, -2.0, -3.0]; + let metrics = calculate_uncertainty(&log_probs); + assert!(metrics.entropy >= 0.0); + assert!(metrics.varentropy >= 0.0); + assert!(metrics.probability > 0.0 && metrics.probability <= 1.0); + } + + #[test] + fn test_calculate_uncertainty_empty() { + let log_probs: Vec = vec![]; + let metrics = calculate_uncertainty(&log_probs); + assert_eq!(metrics.entropy, 0.0); + assert_eq!(metrics.varentropy, 0.0); + assert_eq!(metrics.probability, 0.0); + } + + #[test] + fn test_check_threshold() { + let thresholds = HallucinationThresholds::default(); + assert!(check_threshold(0.001, 0.001, &thresholds)); + assert!(!check_threshold(0.00001, 0.00001, &thresholds)); + } + + #[test] + fn test_is_parameter_required() { + let func_desc = json!({ + "required": ["param1", "param2"] + }); + assert!(is_parameter_required(&func_desc, "param1")); + assert!(!is_parameter_required(&func_desc, "param3")); + } + + #[test] + fn test_is_parameter_property() { + let func_desc = json!({ + "properties": { + "param1": { + "type": "string", + "enum": ["a", "b"] + } + } + }); + assert!(is_parameter_property(&func_desc, "param1", "enum")); + assert!(!is_parameter_property(&func_desc, "param1", "default")); + } + + #[test] + fn test_check_value_type() { + let handler = ArchFunctionHandler::new( + "test-model".to_string(), + ArchFunctionConfig::default(), + "http://localhost:8000".to_string() + ); + + // Test integer types + assert!(handler.check_value_type(&json!(42), "integer")); + assert!(handler.check_value_type(&json!(42), "int")); + assert!(!handler.check_value_type(&json!(3.14), "integer")); + + // Test number types (accepts both int and float) + assert!(handler.check_value_type(&json!(3.14), "number")); + assert!(handler.check_value_type(&json!(42), "number")); + assert!(handler.check_value_type(&json!(3.14), "float")); + + // Test boolean + assert!(handler.check_value_type(&json!(true), "boolean")); + assert!(handler.check_value_type(&json!(false), "bool")); + assert!(!handler.check_value_type(&json!("true"), "boolean")); + + // Test string + assert!(handler.check_value_type(&json!("hello"), "string")); + assert!(handler.check_value_type(&json!("hello"), "str")); + assert!(!handler.check_value_type(&json!(123), "string")); + + // Test array + assert!(handler.check_value_type(&json!([1, 2, 3]), "array")); + assert!(handler.check_value_type(&json!([1, 2, 3]), "list")); + assert!(!handler.check_value_type(&json!({}), "array")); + + // Test object + assert!(handler.check_value_type(&json!({"key": "value"}), "object")); + assert!(handler.check_value_type(&json!({"key": "value"}), "dict")); + assert!(!handler.check_value_type(&json!([]), "object")); + + // Test unknown type (should return true) + assert!(handler.check_value_type(&json!(42), "unknown_type")); + } + + #[test] + fn test_validate_or_convert_parameter() { + let handler = ArchFunctionHandler::new( + "test-model".to_string(), + ArchFunctionConfig::default(), + "http://localhost:8000".to_string() + ); + + // Test valid type - no conversion needed + assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap()); + assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap()); + + // Test integer to float conversion (convert_data_type supports this) + let result = handler.validate_or_convert_parameter(&json!(42), "float"); + assert!(result.is_ok()); + assert!(result.unwrap()); // Should be valid after conversion + + // Test invalid type that cannot be converted + // A string cannot be converted to integer (convert_data_type doesn't support this) + let result = handler.validate_or_convert_parameter(&json!("abc"), "integer"); + // Since convert_data_type returns Ok(value.clone()) for unsupported conversions, + // the validation will fail because "abc" string is not an integer + assert!(!result.unwrap()); + + // Test number accepting both int and float + assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap()); + assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap()); + } + + #[test] + fn test_hallucination_state_new() { + let tools = vec![Tool { + tool_type: "function".to_string(), + function: hermesllm::apis::openai::Function { + name: "test_func".to_string(), + description: Some("Test function".to_string()), + parameters: json!({"type": "object"}), + strict: None, + }, + }]; + + let state = HallucinationState::new(&tools); + assert_eq!(state.tokens.len(), 0); + assert!(!state.hallucination); + assert!(state.function_properties.contains_key("test_func")); + } +} diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index 66c5449b..2583b41e 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,9 +1,11 @@ pub mod agent_chat_completions; pub mod agent_selector; -pub mod chat_completions; +pub mod router; pub mod models; +pub mod function_calling; pub mod pipeline_processor; pub mod response_handler; +pub mod utils; #[cfg(test)] mod integration_tests; diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/router.rs similarity index 89% rename from crates/brightstaff/src/handlers/chat_completions.rs rename to crates/brightstaff/src/handlers/router.rs index 1b15e389..d27bab55 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/router.rs @@ -6,18 +6,15 @@ use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use hermesllm::clients::SupportedAPIs; use hermesllm::{ProviderRequest, ProviderRequestType}; use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full, StreamBody}; -use hyper::body::Frame; +use http_body_util::{BodyExt, Full}; use hyper::header::{self}; use hyper::{Request, Response, StatusCode}; use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; -use tokio_stream::StreamExt; use tracing::{debug, info, warn}; use crate::router::llm_router::RouterService; +use crate::handlers::utils::{create_streaming_response, PassthroughProcessor}; fn full>(chunk: T) -> BoxBody { Full::new(chunk.into()) @@ -25,7 +22,7 @@ fn full>(chunk: T) -> BoxBody { .boxed() } -pub async fn chat( +pub async fn router_chat( request: Request, router_service: Arc, full_qualified_llm_provider_url: String, @@ -237,34 +234,12 @@ pub async fn chat( headers.insert(header_name, header_value.clone()); } - // channel to create async stream - let (tx, rx) = mpsc::channel::(16); + // Use the streaming utility with a passthrough processor (no modification of chunks) + let byte_stream = llm_response.bytes_stream(); + let processor = PassthroughProcessor; + let streaming_response = create_streaming_response(byte_stream, processor, 16); - // Spawn a task to send data as it becomes available - tokio::spawn(async move { - let mut byte_stream = llm_response.bytes_stream(); - - while let Some(item) = byte_stream.next().await { - let item = match item { - Ok(item) => item, - Err(err) => { - warn!("Error receiving chunk: {:?}", err); - break; - } - }; - - if tx.send(item).await.is_err() { - warn!("Receiver dropped"); - break; - } - } - }); - - let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); - - let stream_body = BoxBody::new(StreamBody::new(stream)); - - match response.body(stream_body) { + match response.body(streaming_response.body) { Ok(response) => Ok(response), Err(err) => { let err_msg = format!("Failed to create response: {}", err); diff --git a/crates/brightstaff/src/handlers/utils.rs b/crates/brightstaff/src/handlers/utils.rs new file mode 100644 index 00000000..2d000874 --- /dev/null +++ b/crates/brightstaff/src/handlers/utils.rs @@ -0,0 +1,93 @@ +use bytes::Bytes; +use http_body_util::combinators::BoxBody; +use http_body_util::StreamBody; +use hyper::body::Frame; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; +use tracing::warn; + +/// Trait for processing streaming chunks +/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging) +pub trait StreamProcessor: Send + 'static { + /// Process an incoming chunk of bytes + fn process_chunk(&mut self, chunk: Bytes) -> Result, String>; + + /// Called when streaming completes successfully + fn on_complete(&mut self) {} + + /// Called when streaming encounters an error + fn on_error(&mut self, _error: &str) {} +} + +/// A no-op processor that just forwards chunks as-is +pub struct PassthroughProcessor; + +impl StreamProcessor for PassthroughProcessor { + fn process_chunk(&mut self, chunk: Bytes) -> Result, String> { + Ok(Some(chunk)) + } +} + +/// Result of creating a streaming response +pub struct StreamingResponse { + pub body: BoxBody, + pub processor_handle: tokio::task::JoinHandle<()>, +} + +pub fn create_streaming_response( + mut byte_stream: S, + mut processor: P, + buffer_size: usize, +) -> StreamingResponse +where + S: StreamExt> + Send + Unpin + 'static, + P: StreamProcessor, +{ + let (tx, rx) = mpsc::channel::(buffer_size); + + // Spawn a task to process and forward chunks + let processor_handle = tokio::spawn(async move { + while let Some(item) = byte_stream.next().await { + let chunk = match item { + Ok(chunk) => chunk, + Err(err) => { + let err_msg = format!("Error receiving chunk: {:?}", err); + warn!("{}", err_msg); + processor.on_error(&err_msg); + break; + } + }; + + // Process the chunk + match processor.process_chunk(chunk) { + Ok(Some(processed_chunk)) => { + if tx.send(processed_chunk).await.is_err() { + warn!("Receiver dropped"); + break; + } + } + Ok(None) => { + // Skip this chunk + continue; + } + Err(err) => { + warn!("Processor error: {}", err); + processor.on_error(&err); + break; + } + } + } + + processor.on_complete(); + }); + + // Convert channel receiver to HTTP stream + let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); + let stream_body = BoxBody::new(StreamBody::new(stream)); + + StreamingResponse { + body: stream_body, + processor_handle, + } +} diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 57dd9fe9..265ee5ba 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,6 +1,7 @@ use brightstaff::handlers::agent_chat_completions::agent_chat; -use brightstaff::handlers::chat_completions::chat; +use brightstaff::handlers::router::router_chat; use brightstaff::handlers::models::list_models; +use brightstaff::handlers::function_calling::{function_calling_chat_handler}; use brightstaff::router::llm_router::RouterService; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; @@ -125,7 +126,7 @@ async fn main() -> Result<(), Box> { (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => { let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path()); - chat(req, router_service, fully_qualified_url, model_aliases) + router_chat(req, router_service, fully_qualified_url, model_aliases) .with_context(parent_cx) .await } @@ -142,6 +143,14 @@ async fn main() -> Result<(), Box> { .with_context(parent_cx) .await } + + (&Method::POST, "/function_calling") => { + let fully_qualified_url = + format!("{}{}", llm_provider_url, "/v1"); + function_calling_chat_handler(req, fully_qualified_url) + .with_context(parent_cx) + .await + } (&Method::GET, "/v1/models" | "/agents/v1/models") => { Ok(list_models(llm_providers).await) } diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 82c5d1a1..90a180ba 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -385,6 +385,8 @@ pub struct ChatCompletionsResponse { pub usage: Usage, pub system_fingerprint: Option, pub service_tier: Option, + // This isn't a standard OpenAI field, but we include it for extensibility + pub metadata: Option>, } impl Default for ChatCompletionsResponse { @@ -398,6 +400,7 @@ impl Default for ChatCompletionsResponse { usage: Usage::default(), system_fingerprint: None, service_tier: None, + metadata: None, } } } diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index f09b2c04..54fda8c4 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -316,6 +316,17 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent { // Create a new transformed event based on the original let mut transformed_event = sse_event; + // Handle [DONE] marker early - don't try to parse as JSON + if transformed_event.is_done() { + // For OpenAI client API, keep [DONE] as-is + // For Anthropic client API, it will be transformed via ProviderStreamResponseType + if matches!(client_api, SupportedAPIs::OpenAIChatCompletions(_)) { + // Keep the [DONE] marker as-is for OpenAI clients + transformed_event.sse_transform_buffer = "data: [DONE]".to_string(); + return Ok(transformed_event); + } + } + // If has data, parse the data as a provider stream response (business logic layer) if transformed_event.data.is_some() { let data_str = transformed_event.data.as_ref().unwrap(); diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs index acbdb420..b44afc96 100644 --- a/crates/hermesllm/src/transforms/response/to_openai.rs +++ b/crates/hermesllm/src/transforms/response/to_openai.rs @@ -83,8 +83,7 @@ impl TryFrom for ChatCompletionsResponse { model: resp.model, choices: vec![choice], usage, - system_fingerprint: None, - service_tier: None, + ..Default::default() }) } } @@ -169,8 +168,7 @@ impl TryFrom for ChatCompletionsResponse { model, choices: vec![choice], usage, - system_fingerprint: None, - service_tier: None, + ..Default::default() }) } } diff --git a/tests/modelserver/test_hallucination.py b/tests/modelserver/test_hallucination.py deleted file mode 100644 index 323db3fc..00000000 --- a/tests/modelserver/test_hallucination.py +++ /dev/null @@ -1,44 +0,0 @@ -import os -import pytest -import requests -import logging -import yaml - -pytestmark = pytest.mark.skip( - reason="Skipping entire test file as hallucination is not enabled for archfc 1.1 yet" -) - -MODEL_SERVER_ENDPOINT = os.getenv( - "MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling" -) - -# Load test data from YAML file -script_dir = os.path.dirname(__file__) - -# Construct the full path to the YAML file -yaml_file_path = os.path.join(script_dir, "test_hallucination_data.yaml") - -# Load test data from YAML file -with open(yaml_file_path, "r") as file: - test_data_yaml = yaml.safe_load(file) - - -@pytest.mark.parametrize( - "test_data", - [ - pytest.param(test_case, id=test_case["id"]) - for test_case in test_data_yaml["test_cases"] - ], -) -def test_model_server(test_data): - input = test_data["input"] - expected = test_data["expected"] - - response = requests.post(MODEL_SERVER_ENDPOINT, json=input) - assert response.status_code == 200 - assert response.headers["content-type"] == "application/json" - - response_json = response.json() - assert response_json - metadata = response_json.get("metadata", {}) - assert (metadata["hallucination"].lower() == "true") == expected[0]["hallucination"] diff --git a/tests/modelserver/test_hallucination_data.yaml b/tests/modelserver/test_hallucination_data.yaml deleted file mode 100644 index 935a8f5f..00000000 --- a/tests/modelserver/test_hallucination_data.yaml +++ /dev/null @@ -1,257 +0,0 @@ -test_cases: - - id: "[WEATHER AGENT] - single turn, single tool, prompt prefilling" - input: - messages: - - role: "user" - content: "what is the weather forecast for seattle?" - tools: - - type: "function" - function: - name: "get_current_weather" - description: "Get current weather at a location." - parameters: - type: "object" - properties: - location: - type: "string" - description: "The location to get the weather for" - format: "City, State" - days: - type: "integer" - description: "The number of days for the request." - required: - - location - - days - expected: - - type: "metadata" - hallucination: false - - - id: "[WEATHER AGENT] - single turn, single tool, hallucination" - input: - messages: - - role: "user" - content: "what is the weather in Seattle in days?" - tools: - - type: "function" - function: - name: "get_current_weather" - description: "Get current weather at a location." - parameters: - type: "object" - properties: - location: - type: "str" - description: "The location to get the weather for" - format: "City, State" - days: - type: "int" - description: "the number of days for the request." - required: ["location", "days"] - expected: - - type: "metadata" - hallucination: true - - - id: "[WEATHER AGENT] - multi turn, single tool, all params passed" - input: - messages: - - role: "user" - content: "how is the weather in chicago for next 5 days?" - - role: "assistant" - content: "Can you tell me your location and how many days you want?" - - role: "user" - content: "Seattle" - - role: "assistant" - content: "Can you please provide me the days for the weather forecast?" - - role: "user" - content: "5 days" - tools: - - type: "function" - function: - name: "get_current_weather" - description: "Get current weather at a location." - parameters: - type: "object" - properties: - location: - type: "str" - description: "The location to get the weather for" - format: "City, State" - days: - type: "int" - description: "the number of days for the request." - required: ["location", "days"] - expected: - - type: "metadata" - hallucination: false - - - id: "[WEATHER AGENT] - multi turn, single tool, clarification" - input: - messages: - - role: "user" - content: "how is the weather for next 5 days?" - - role: "assistant" - content: "Can you tell me your location and how many days you want?" - - role: "user" - content: "Seattle" - - role: "assistant" - content: "Can you please provide me the days for the weather forecast?" - - role: "user" - content: "Sorry, the location is actually los angeles in 5 days" - tools: - - type: "function" - function: - name: "get_current_weather" - description: "Get current weather at a location." - parameters: - type: "object" - properties: - location: - type: "str" - description: "The location to get the weather for" - format: "City, State" - days: - type: "int" - description: "the number of days for the request." - required: ["location", "days"] - expected: - - type: "metadata" - hallucination: false - - - id: "[SALE AGENT] - single turn, single tool, hallucination region" - input: - messages: - - role: "user" - content: "get me sales opportunities of tech" - tools: - - type: "function" - function: - name: "sales_opportunity" - description: "Retrieve potential sales opportunities based for a particular industry type in a region." - parameters: - type: "object" - properties: - region: - type: "str" - description: "Geographical region to identify sales opportunities." - industry: - type: "str" - description: "Industry type." - max_results: - type: "int" - description: "Maximum number of sales opportunities to retrieve." - default: 20 - required: ["region", "industry"] - expected: - - type: "metadata" - hallucination: true - - - id: "[SALE AGENT] - single turn, single tool, hallucination industry" - input: - messages: - - role: "user" - content: "get me sales opportunities in NA" - tools: - - type: "function" - function: - name: "sales_opportunity" - description: "Retrieve potential sales opportunities based for a particular industry type in a region." - parameters: - type: "object" - properties: - region: - type: "str" - description: "Geographical region to identify sales opportunities." - industry: - type: "str" - description: "Industry type." - max_results: - type: "int" - description: "Maximum number of sales opportunities to retrieve." - default: 20 - required: ["region", "industry"] - expected: - - type: "metadata" - hallucination: true - - - id: "[PRODUCT AGENT] - single turn, single tool, hallucination industry" - input: - messages: - - role: "user" - content: "get me sales opportunities in NA" - tools: - - type: "function" - function: - name: "product_recommendation" - description: "Place an order for an iphone with user_id 195 and location is 1600 pensylvania ave" - parameters: - type: "object" - properties: - user_id: - type: "str" - description: "Unique identifier for the user." - category: - type: "str" - description: "Product category for recommendations." - max_results: - type: "int" - description: "Maximum number of recommended products to show." - default: 10 - required: ["user_id", "category"] - - type: "function" - function: - name: "place_order" - description: "Place and pay for an order for one or more products to ship to the an address." - parameters: - type: "object" - properties: - user_id: - type: "str" - description: "Unique identifier for the user placing the order." - product_ids: - type: "array" - description: "List of product IDs to include in the order." - shipping_address: - type: "str" - description: "Shipping address for the order." - payment_method: - type: "str" - description: "Payment method for the order." - required: ["user_id", "product_ids", "shipping_address", "payment_method"] - - type: "function" - function: - name: "sales_opportunity" - description: "Retrieve potential sales opportunities based for a particular industry type in a region." - parameters: - type: "object" - properties: - region: - type: "str" - description: "Geographical region to identify sales opportunities." - industry: - type: "str" - description: "Industry type." - max_results: - type: "int" - description: "Maximum number of sales opportunities to retrieve." - default: 20 - required: ["region", "industry"] - - type: "function" - function: - name: "query_database" - description: "Perform a database query to retrieve or update information." - parameters: - type: "object" - properties: - query: - type: "str" - description: "SQL query string to execute against the database." - parameters: - type: "array" - description: "List of parameters to safely inject into the SQL query (to prevent SQL injection)." - operation: - type: "str" - description: "Type of operation." - required: ["query", "operation"] - expected: - - type: "metadata" - hallucination: true diff --git a/tests/modelserver/test_modelserver.py b/tests/modelserver/test_modelserver.py index 4596606f..f18c803c 100644 --- a/tests/modelserver/test_modelserver.py +++ b/tests/modelserver/test_modelserver.py @@ -10,7 +10,7 @@ pytestmark = pytest.mark.skip( ) MODEL_SERVER_ENDPOINT = os.getenv( - "MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling" + "MODEL_SERVER_ENDPOINT", "http://localhost:12000/function_calling" ) # Load test data from YAML file diff --git a/tests/rest/api_model_server.rest b/tests/rest/api_model_server.rest index 5fdbf968..9c094c19 100644 --- a/tests/rest/api_model_server.rest +++ b/tests/rest/api_model_server.rest @@ -1,4 +1,4 @@ -@model_server_endpoint = http://localhost:51000 +@model_server_endpoint = http://localhost:12000 @archfc_endpoint = https://archfc.katanemo.dev ### talk to function calling endpoint diff --git a/tests/rest/insurance_agent.rest b/tests/rest/insurance_agent.rest index c45ebb85..f5a86f8f 100644 --- a/tests/rest/insurance_agent.rest +++ b/tests/rest/insurance_agent.rest @@ -1,4 +1,4 @@ -@model_server_endpoint = http://localhost:51000 +@model_server_endpoint = http://localhost:12000 @archfc_endpoint = https://archfc.katanemo.dev ### multi turn conversation with intent, except parameter gathering @@ -54,26 +54,8 @@ Content-Type: application/json } ] } -### talk to Arch-Intent directly for completion -POST https://archfc.katanemo.dev/v1/chat/completions HTTP/1.1 -Content-Type: application/json - -{ - "model": "Arch-Intent", - "messages": [ - { - "role": "system", - "content": "You are a helpful assistant.\n\nYou task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.\n\n\n{\"index\": \"T0\", \"type\": \"function\", \"function\": {\"name\": \"weather_forecast\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"str\"}, \"days\": {\"type\": \"int\"}}, \"required\": [\"city\", \"days\"]}}}\n\n\nProvide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:\n- First line must read 'Yes' or 'No'.\n- If yes, a second line must include a comma-separated list of tool indexes.\n" - }, - { "role": "user", "content": "hi" } - ], - "stream": false -} - - - -### multi turn conversation with correct parameters +### multi turn conversation with intent, except parameter gathering POST {{model_server_endpoint}}/function_calling HTTP/1.1 Content-Type: application/json @@ -125,21 +107,6 @@ Content-Type: application/json } ] } -### talk to Arch-Intent directly for completion, expect No -POST https://archfc.katanemo.dev/v1/chat/completions HTTP/1.1 -Content-Type: application/json - -{ - "model": "Arch-Intent", - "messages": [ - { - "role": "system", - "content": "You are a helpful assistant.\n\nYou task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.\n\n\n{\"index\": \"T0\", \"type\": \"function\", \"function\": {\"name\": \"weather_forecast\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"str\"}, \"days\": {\"type\": \"int\"}}, \"required\": [\"city\", \"days\"]}}}\n\n\nProvide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:\n- First line must read 'Yes' or 'No'.\n- If yes, a second line must include a comma-separated list of tool indexes.\n" - }, - { "role": "user", "content": "what is your name" } - ], - "stream": false -} ### multi turn conversation with correct parameters POST {{model_server_endpoint}}/function_calling HTTP/1.1 diff --git a/tests/rest/network_agent.rest b/tests/rest/network_agent.rest index dc03fa6c..07f746ca 100644 --- a/tests/rest/network_agent.rest +++ b/tests/rest/network_agent.rest @@ -1,4 +1,4 @@ -@model_server_endpoint = http://localhost:51000 +@model_server_endpoint = http://localhost:12000 @archfc_endpoint = https://archfc.katanemo.dev ### single turn function calling all parameters insurance agent summary