adding function_calling functionality via rust

This commit is contained in:
Salman Paracha 2025-11-11 21:09:20 -08:00
parent 126b029345
commit d6a1b70594
17 changed files with 2338 additions and 389 deletions

View file

@ -304,6 +304,14 @@ def validate_and_render_schema():
}
)
# Always add arch-function model provider
updated_model_providers.append(
{
"name": "arch-function",
"provider_interface": "arch",
"model": "Arch-Function",
}
)
config_yaml["model_providers"] = deepcopy(updated_model_providers)
listeners_with_provider = 0

312
crates/Cargo.lock generated
View file

@ -78,6 +78,43 @@ dependencies = [
"serde_json",
]
[[package]]
name = "async-openai"
version = "0.30.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bf39a15c8d613eb61892dc9a287c02277639ebead41ee611ad23aaa613f1a82"
dependencies = [
"async-openai-macros",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand 0.9.2",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.12",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-openai-macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "async-trait"
version = "0.1.88"
@ -130,6 +167,20 @@ dependencies = [
"time",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.16",
"instant",
"pin-project-lite",
"rand 0.8.5",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.75"
@ -201,7 +252,9 @@ dependencies = [
name = "brightstaff"
version = "0.1.0"
dependencies = [
"async-openai",
"bytes",
"chrono",
"common",
"eventsource-client",
"eventsource-stream",
@ -219,6 +272,7 @@ dependencies = [
"opentelemetry-stdout",
"opentelemetry_sdk",
"pretty_assertions",
"rand 0.9.2",
"reqwest",
"serde",
"serde_json",
@ -231,6 +285,7 @@ dependencies = [
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"uuid",
]
[[package]]
@ -281,6 +336,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.41"
@ -289,8 +350,10 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link",
]
@ -336,6 +399,16 @@ dependencies = [
"libc",
]
[[package]]
name = "core-foundation"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@ -426,6 +499,37 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "derive_builder"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "derive_builder_macro"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
"syn 2.0.101",
]
[[package]]
name = "diff"
version = "0.1.13"
@ -650,6 +754,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
@ -685,8 +795,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
@ -696,9 +808,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
"wasm-bindgen",
]
[[package]]
@ -934,7 +1048,7 @@ dependencies = [
"hyper 0.14.32",
"log",
"rustls 0.21.12",
"rustls-native-certs",
"rustls-native-certs 0.6.3",
"tokio",
"tokio-rustls 0.24.1",
]
@ -949,6 +1063,7 @@ dependencies = [
"hyper 1.6.0",
"hyper-util",
"rustls 0.23.27",
"rustls-native-certs 0.8.2",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.26.2",
@ -1181,6 +1296,15 @@ dependencies = [
"serde",
]
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if",
]
[[package]]
name = "ipnet"
version = "2.11.0"
@ -1285,6 +1409,12 @@ version = "0.4.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "lru-slab"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "matchers"
version = "0.1.0"
@ -1312,6 +1442,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -1354,7 +1494,7 @@ dependencies = [
"hyper 1.6.0",
"hyper-util",
"log",
"rand 0.9.1",
"rand 0.9.2",
"regex",
"serde_json",
"serde_urlencoded",
@ -1374,7 +1514,7 @@ dependencies = [
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework 2.11.1",
"security-framework-sys",
"tempfile",
]
@ -1581,7 +1721,7 @@ dependencies = [
"glob",
"opentelemetry",
"percent-encoding",
"rand 0.9.1",
"rand 0.9.2",
"serde_json",
"thiserror 2.0.12",
"tracing",
@ -1770,6 +1910,61 @@ dependencies = [
"log",
]
[[package]]
name = "quinn"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
dependencies = [
"bytes",
"cfg_aliases",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash 2.1.1",
"rustls 0.23.27",
"socket2",
"thiserror 2.0.12",
"tokio",
"tracing",
"web-time",
]
[[package]]
name = "quinn-proto"
version = "0.11.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
dependencies = [
"bytes",
"getrandom 0.3.3",
"lru-slab",
"rand 0.9.2",
"ring",
"rustc-hash 2.1.1",
"rustls 0.23.27",
"rustls-pki-types",
"slab",
"thiserror 2.0.12",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.40"
@ -1798,9 +1993,9 @@ dependencies = [
[[package]]
name = "rand"
version = "0.9.1"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
@ -1941,10 +2136,14 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"native-tls",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls 0.23.27",
"rustls-native-certs 0.8.2",
"rustls-pki-types",
"serde",
"serde_json",
@ -1952,6 +2151,7 @@ dependencies = [
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-rustls 0.26.2",
"tokio-util",
"tower 0.5.2",
"tower-http",
@ -1963,6 +2163,22 @@ dependencies = [
"web-sys",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.14"
@ -1989,6 +2205,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustix"
version = "1.0.7"
@ -2021,6 +2243,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321"
dependencies = [
"once_cell",
"ring",
"rustls-pki-types",
"rustls-webpki 0.103.3",
"subtle",
@ -2036,7 +2259,19 @@ dependencies = [
"openssl-probe",
"rustls-pemfile",
"schannel",
"security-framework",
"security-framework 2.11.1",
]
[[package]]
name = "rustls-native-certs"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9980d917ebb0c0536119ba501e90834767bffc3d60641457fd84a1f3fd337923"
dependencies = [
"openssl-probe",
"rustls-pki-types",
"schannel",
"security-framework 3.5.1",
]
[[package]]
@ -2054,6 +2289,7 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79"
dependencies = [
"web-time",
"zeroize",
]
@ -2142,6 +2378,16 @@ version = "3.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
@ -2149,7 +2395,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework"
version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef"
dependencies = [
"bitflags",
"core-foundation 0.10.1",
"core-foundation-sys",
"libc",
"security-framework-sys",
@ -2420,7 +2679,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation 0.9.4",
"system-configuration-sys",
]
@ -2509,7 +2768,7 @@ dependencies = [
"fancy-regex",
"lazy_static",
"parking_lot",
"rustc-hash",
"rustc-hash 1.1.0",
]
[[package]]
@ -2553,6 +2812,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.45.1"
@ -2829,6 +3103,12 @@ version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-ident"
version = "1.0.18"
@ -2870,6 +3150,18 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "uuid"
version = "1.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2"
dependencies = [
"getrandom 0.3.3",
"js-sys",
"serde",
"wasm-bindgen",
]
[[package]]
name = "valuable"
version = "0.1.1"

View file

@ -4,7 +4,9 @@ version = "0.1.0"
edition = "2021"
[dependencies]
async-openai = "0.30.1"
bytes = "1.10.1"
chrono = "0.4"
common = { version = "0.1.0", path = "../common" }
eventsource-client = "0.15.0"
eventsource-stream = "0.2.3"
@ -21,6 +23,7 @@ opentelemetry-otlp = {version="0.29.0", features=["trace", "tonic", "grpc-tonic"
opentelemetry-stdout = "0.29.0"
opentelemetry_sdk = "0.29.0"
pretty_assertions = "1.4.1"
rand = "0.9.2"
reqwest = { version = "0.12.15", features = ["stream"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
@ -32,6 +35,7 @@ tokio-stream = "0.1"
time = { version = "0.3", features = ["formatting", "macros"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.0", features = ["v4", "serde"] }
[dev-dependencies]
mockito = "1.0"

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,11 @@
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod chat_completions;
pub mod router;
pub mod models;
pub mod function_calling;
pub mod pipeline_processor;
pub mod response_handler;
pub mod utils;
#[cfg(test)]
mod integration_tests;

View file

@ -6,18 +6,15 @@ use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::clients::SupportedAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use http_body_util::{BodyExt, Full};
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::handlers::utils::{create_streaming_response, PassthroughProcessor};
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
@ -25,7 +22,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
.boxed()
}
pub async fn chat(
pub async fn router_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
@ -237,34 +234,12 @@ pub async fn chat(
headers.insert(header_name, header_value.clone());
}
// channel to create async stream
let (tx, rx) = mpsc::channel::<Bytes>(16);
// Use the streaming utility with a passthrough processor (no modification of chunks)
let byte_stream = llm_response.bytes_stream();
let processor = PassthroughProcessor;
let streaming_response = create_streaming_response(byte_stream, processor, 16);
// Spawn a task to send data as it becomes available
tokio::spawn(async move {
let mut byte_stream = llm_response.bytes_stream();
while let Some(item) = byte_stream.next().await {
let item = match item {
Ok(item) => item,
Err(err) => {
warn!("Error receiving chunk: {:?}", err);
break;
}
};
if tx.send(item).await.is_err() {
warn!("Receiver dropped");
break;
}
}
});
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
match response.body(stream_body) {
match response.body(streaming_response.body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);

View file

@ -0,0 +1,93 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody;
use hyper::body::Frame;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::warn;
/// Trait for processing streaming chunks
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
pub trait StreamProcessor: Send + 'static {
/// Process an incoming chunk of bytes
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String>;
/// Called when streaming completes successfully
fn on_complete(&mut self) {}
/// Called when streaming encounters an error
fn on_error(&mut self, _error: &str) {}
}
/// A no-op processor that just forwards chunks as-is
pub struct PassthroughProcessor;
impl StreamProcessor for PassthroughProcessor {
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
Ok(Some(chunk))
}
}
/// Result of creating a streaming response
pub struct StreamingResponse {
pub body: BoxBody<Bytes, hyper::Error>,
pub processor_handle: tokio::task::JoinHandle<()>,
}
pub fn create_streaming_response<S, P>(
mut byte_stream: S,
mut processor: P,
buffer_size: usize,
) -> StreamingResponse
where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor,
{
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);
// Spawn a task to process and forward chunks
let processor_handle = tokio::spawn(async move {
while let Some(item) = byte_stream.next().await {
let chunk = match item {
Ok(chunk) => chunk,
Err(err) => {
let err_msg = format!("Error receiving chunk: {:?}", err);
warn!("{}", err_msg);
processor.on_error(&err_msg);
break;
}
};
// Process the chunk
match processor.process_chunk(chunk) {
Ok(Some(processed_chunk)) => {
if tx.send(processed_chunk).await.is_err() {
warn!("Receiver dropped");
break;
}
}
Ok(None) => {
// Skip this chunk
continue;
}
Err(err) => {
warn!("Processor error: {}", err);
processor.on_error(&err);
break;
}
}
}
processor.on_complete();
});
// Convert channel receiver to HTTP stream
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
StreamingResponse {
body: stream_body,
processor_handle,
}
}

View file

@ -1,6 +1,7 @@
use brightstaff::handlers::agent_chat_completions::agent_chat;
use brightstaff::handlers::chat_completions::chat;
use brightstaff::handlers::router::router_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::function_calling::{function_calling_chat_handler};
use brightstaff::router::llm_router::RouterService;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
@ -125,7 +126,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url, model_aliases)
router_chat(req, router_service, fully_qualified_url, model_aliases)
.with_context(parent_cx)
.await
}
@ -142,6 +143,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.with_context(parent_cx)
.await
}
(&Method::POST, "/function_calling") => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, "/v1");
function_calling_chat_handler(req, fully_qualified_url)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
Ok(list_models(llm_providers).await)
}

View file

@ -385,6 +385,8 @@ pub struct ChatCompletionsResponse {
pub usage: Usage,
pub system_fingerprint: Option<String>,
pub service_tier: Option<String>,
// This isn't a standard OpenAI field, but we include it for extensibility
pub metadata: Option<HashMap<String, Value>>,
}
impl Default for ChatCompletionsResponse {
@ -398,6 +400,7 @@ impl Default for ChatCompletionsResponse {
usage: Usage::default(),
system_fingerprint: None,
service_tier: None,
metadata: None,
}
}
}

View file

@ -316,6 +316,17 @@ impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedUpstreamAPIs)> for SseEvent {
// Create a new transformed event based on the original
let mut transformed_event = sse_event;
// Handle [DONE] marker early - don't try to parse as JSON
if transformed_event.is_done() {
// For OpenAI client API, keep [DONE] as-is
// For Anthropic client API, it will be transformed via ProviderStreamResponseType
if matches!(client_api, SupportedAPIs::OpenAIChatCompletions(_)) {
// Keep the [DONE] marker as-is for OpenAI clients
transformed_event.sse_transform_buffer = "data: [DONE]".to_string();
return Ok(transformed_event);
}
}
// If has data, parse the data as a provider stream response (business logic layer)
if transformed_event.data.is_some() {
let data_str = transformed_event.data.as_ref().unwrap();

View file

@ -83,8 +83,7 @@ impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
model: resp.model,
choices: vec![choice],
usage,
system_fingerprint: None,
service_tier: None,
..Default::default()
})
}
}
@ -169,8 +168,7 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
model,
choices: vec![choice],
usage,
system_fingerprint: None,
service_tier: None,
..Default::default()
})
}
}

View file

@ -1,44 +0,0 @@
import os
import pytest
import requests
import logging
import yaml
pytestmark = pytest.mark.skip(
reason="Skipping entire test file as hallucination is not enabled for archfc 1.1 yet"
)
MODEL_SERVER_ENDPOINT = os.getenv(
"MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
)
# Load test data from YAML file
script_dir = os.path.dirname(__file__)
# Construct the full path to the YAML file
yaml_file_path = os.path.join(script_dir, "test_hallucination_data.yaml")
# Load test data from YAML file
with open(yaml_file_path, "r") as file:
test_data_yaml = yaml.safe_load(file)
@pytest.mark.parametrize(
"test_data",
[
pytest.param(test_case, id=test_case["id"])
for test_case in test_data_yaml["test_cases"]
],
)
def test_model_server(test_data):
input = test_data["input"]
expected = test_data["expected"]
response = requests.post(MODEL_SERVER_ENDPOINT, json=input)
assert response.status_code == 200
assert response.headers["content-type"] == "application/json"
response_json = response.json()
assert response_json
metadata = response_json.get("metadata", {})
assert (metadata["hallucination"].lower() == "true") == expected[0]["hallucination"]

View file

@ -1,257 +0,0 @@
test_cases:
- id: "[WEATHER AGENT] - single turn, single tool, prompt prefilling"
input:
messages:
- role: "user"
content: "what is the weather forecast for seattle?"
tools:
- type: "function"
function:
name: "get_current_weather"
description: "Get current weather at a location."
parameters:
type: "object"
properties:
location:
type: "string"
description: "The location to get the weather for"
format: "City, State"
days:
type: "integer"
description: "The number of days for the request."
required:
- location
- days
expected:
- type: "metadata"
hallucination: false
- id: "[WEATHER AGENT] - single turn, single tool, hallucination"
input:
messages:
- role: "user"
content: "what is the weather in Seattle in days?"
tools:
- type: "function"
function:
name: "get_current_weather"
description: "Get current weather at a location."
parameters:
type: "object"
properties:
location:
type: "str"
description: "The location to get the weather for"
format: "City, State"
days:
type: "int"
description: "the number of days for the request."
required: ["location", "days"]
expected:
- type: "metadata"
hallucination: true
- id: "[WEATHER AGENT] - multi turn, single tool, all params passed"
input:
messages:
- role: "user"
content: "how is the weather in chicago for next 5 days?"
- role: "assistant"
content: "Can you tell me your location and how many days you want?"
- role: "user"
content: "Seattle"
- role: "assistant"
content: "Can you please provide me the days for the weather forecast?"
- role: "user"
content: "5 days"
tools:
- type: "function"
function:
name: "get_current_weather"
description: "Get current weather at a location."
parameters:
type: "object"
properties:
location:
type: "str"
description: "The location to get the weather for"
format: "City, State"
days:
type: "int"
description: "the number of days for the request."
required: ["location", "days"]
expected:
- type: "metadata"
hallucination: false
- id: "[WEATHER AGENT] - multi turn, single tool, clarification"
input:
messages:
- role: "user"
content: "how is the weather for next 5 days?"
- role: "assistant"
content: "Can you tell me your location and how many days you want?"
- role: "user"
content: "Seattle"
- role: "assistant"
content: "Can you please provide me the days for the weather forecast?"
- role: "user"
content: "Sorry, the location is actually los angeles in 5 days"
tools:
- type: "function"
function:
name: "get_current_weather"
description: "Get current weather at a location."
parameters:
type: "object"
properties:
location:
type: "str"
description: "The location to get the weather for"
format: "City, State"
days:
type: "int"
description: "the number of days for the request."
required: ["location", "days"]
expected:
- type: "metadata"
hallucination: false
- id: "[SALE AGENT] - single turn, single tool, hallucination region"
input:
messages:
- role: "user"
content: "get me sales opportunities of tech"
tools:
- type: "function"
function:
name: "sales_opportunity"
description: "Retrieve potential sales opportunities based for a particular industry type in a region."
parameters:
type: "object"
properties:
region:
type: "str"
description: "Geographical region to identify sales opportunities."
industry:
type: "str"
description: "Industry type."
max_results:
type: "int"
description: "Maximum number of sales opportunities to retrieve."
default: 20
required: ["region", "industry"]
expected:
- type: "metadata"
hallucination: true
- id: "[SALE AGENT] - single turn, single tool, hallucination industry"
input:
messages:
- role: "user"
content: "get me sales opportunities in NA"
tools:
- type: "function"
function:
name: "sales_opportunity"
description: "Retrieve potential sales opportunities based for a particular industry type in a region."
parameters:
type: "object"
properties:
region:
type: "str"
description: "Geographical region to identify sales opportunities."
industry:
type: "str"
description: "Industry type."
max_results:
type: "int"
description: "Maximum number of sales opportunities to retrieve."
default: 20
required: ["region", "industry"]
expected:
- type: "metadata"
hallucination: true
- id: "[PRODUCT AGENT] - single turn, single tool, hallucination industry"
input:
messages:
- role: "user"
content: "get me sales opportunities in NA"
tools:
- type: "function"
function:
name: "product_recommendation"
description: "Place an order for an iphone with user_id 195 and location is 1600 pensylvania ave"
parameters:
type: "object"
properties:
user_id:
type: "str"
description: "Unique identifier for the user."
category:
type: "str"
description: "Product category for recommendations."
max_results:
type: "int"
description: "Maximum number of recommended products to show."
default: 10
required: ["user_id", "category"]
- type: "function"
function:
name: "place_order"
description: "Place and pay for an order for one or more products to ship to the an address."
parameters:
type: "object"
properties:
user_id:
type: "str"
description: "Unique identifier for the user placing the order."
product_ids:
type: "array"
description: "List of product IDs to include in the order."
shipping_address:
type: "str"
description: "Shipping address for the order."
payment_method:
type: "str"
description: "Payment method for the order."
required: ["user_id", "product_ids", "shipping_address", "payment_method"]
- type: "function"
function:
name: "sales_opportunity"
description: "Retrieve potential sales opportunities based for a particular industry type in a region."
parameters:
type: "object"
properties:
region:
type: "str"
description: "Geographical region to identify sales opportunities."
industry:
type: "str"
description: "Industry type."
max_results:
type: "int"
description: "Maximum number of sales opportunities to retrieve."
default: 20
required: ["region", "industry"]
- type: "function"
function:
name: "query_database"
description: "Perform a database query to retrieve or update information."
parameters:
type: "object"
properties:
query:
type: "str"
description: "SQL query string to execute against the database."
parameters:
type: "array"
description: "List of parameters to safely inject into the SQL query (to prevent SQL injection)."
operation:
type: "str"
description: "Type of operation."
required: ["query", "operation"]
expected:
- type: "metadata"
hallucination: true

View file

@ -10,7 +10,7 @@ pytestmark = pytest.mark.skip(
)
MODEL_SERVER_ENDPOINT = os.getenv(
"MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
"MODEL_SERVER_ENDPOINT", "http://localhost:12000/function_calling"
)
# Load test data from YAML file

View file

@ -1,4 +1,4 @@
@model_server_endpoint = http://localhost:51000
@model_server_endpoint = http://localhost:12000
@archfc_endpoint = https://archfc.katanemo.dev
### talk to function calling endpoint

View file

@ -1,4 +1,4 @@
@model_server_endpoint = http://localhost:51000
@model_server_endpoint = http://localhost:12000
@archfc_endpoint = https://archfc.katanemo.dev
### multi turn conversation with intent, except parameter gathering
@ -54,26 +54,8 @@ Content-Type: application/json
}
]
}
### talk to Arch-Intent directly for completion
POST https://archfc.katanemo.dev/v1/chat/completions HTTP/1.1
Content-Type: application/json
{
"model": "Arch-Intent",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant.\n\nYou task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.\n\n<tools>\n{\"index\": \"T0\", \"type\": \"function\", \"function\": {\"name\": \"weather_forecast\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"str\"}, \"days\": {\"type\": \"int\"}}, \"required\": [\"city\", \"days\"]}}}\n</tools>\n\nProvide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:\n- First line must read 'Yes' or 'No'.\n- If yes, a second line must include a comma-separated list of tool indexes.\n"
},
{ "role": "user", "content": "hi" }
],
"stream": false
}
### multi turn conversation with correct parameters
### multi turn conversation with intent, except parameter gathering
POST {{model_server_endpoint}}/function_calling HTTP/1.1
Content-Type: application/json
@ -125,21 +107,6 @@ Content-Type: application/json
}
]
}
### talk to Arch-Intent directly for completion, expect No
POST https://archfc.katanemo.dev/v1/chat/completions HTTP/1.1
Content-Type: application/json
{
"model": "Arch-Intent",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant.\n\nYou task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.\n\n<tools>\n{\"index\": \"T0\", \"type\": \"function\", \"function\": {\"name\": \"weather_forecast\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"str\"}, \"days\": {\"type\": \"int\"}}, \"required\": [\"city\", \"days\"]}}}\n</tools>\n\nProvide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:\n- First line must read 'Yes' or 'No'.\n- If yes, a second line must include a comma-separated list of tool indexes.\n"
},
{ "role": "user", "content": "what is your name" }
],
"stream": false
}
### multi turn conversation with correct parameters
POST {{model_server_endpoint}}/function_calling HTTP/1.1

View file

@ -1,4 +1,4 @@
@model_server_endpoint = http://localhost:51000
@model_server_endpoint = http://localhost:12000
@archfc_endpoint = https://archfc.katanemo.dev
### single turn function calling all parameters insurance agent summary