diff --git a/.github/workflows/e2e_archgw.yml b/.github/workflows/e2e_archgw.yml index 633a32d8..4c3bade3 100644 --- a/.github/workflows/e2e_archgw.yml +++ b/.github/workflows/e2e_archgw.yml @@ -24,7 +24,7 @@ jobs: - name: build arch docker image run: | - cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.5 -t katanemo/archgw:latest + cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.7 -t katanemo/archgw:latest - name: start archgw env: diff --git a/.github/workflows/e2e_test_currency_convert.yml b/.github/workflows/e2e_test_currency_convert.yml index de29ed72..352245f0 100644 --- a/.github/workflows/e2e_test_currency_convert.yml +++ b/.github/workflows/e2e_test_currency_convert.yml @@ -24,7 +24,7 @@ jobs: - name: build arch docker image run: | - docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.5 + docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.7 - name: install poetry run: | diff --git a/.github/workflows/e2e_test_preference_based_routing.yml b/.github/workflows/e2e_test_preference_based_routing.yml index c1ab7050..db8cb3d5 100644 --- a/.github/workflows/e2e_test_preference_based_routing.yml +++ b/.github/workflows/e2e_test_preference_based_routing.yml @@ -24,7 +24,7 @@ jobs: - name: build arch docker image run: | - docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.5 + docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.7 - name: install poetry run: | diff --git a/.github/workflows/validate_arch_config.yml b/.github/workflows/validate_arch_config.yml index c3dbfeb3..901600ab 100644 --- a/.github/workflows/validate_arch_config.yml +++ b/.github/workflows/validate_arch_config.yml @@ -24,7 +24,7 @@ jobs: - name: build arch docker image run: | - docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.5 + docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.7 - name: validate arch config run: | diff --git a/README.md b/README.md index d71cf1b4..e93db8a7 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@
-_The proxy server and the universal data plane for AI-native apps._

-Arch handles the *pesky low-level work* in building AI agents like clarifying vague user inputs, routing prompts to the right agents, calling tools for simple tasks, and unifying access to large language models (LLMs) - all without locking you into a framework. Move faster by focusing on the high-level logic of your agents. +_Arch is a smart proxy server designed as a modular edge and AI gateway for agentic apps_

+ Arch handles the *pesky low-level work* in building agentic apps — like applying guardrails, clarifying vague user input, routing prompts to the right agent, and unifying access to any LLM. It’s a language and framework friendly infrastructure layer designed to help you build and ship agentic apps faster. [Quickstart](#Quickstart) • @@ -80,9 +80,9 @@ Arch's CLI allows you to manage and interact with the Arch gateway efficiently. > We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch. This ensures that archgw and its dependencies do not interfere with other packages on your system. ```console -$ python -m venv venv +$ python3.12 -m venv venv $ source venv/bin/activate # On Windows, use: venv\Scripts\activate -$ pip install archgw==0.3.5 +$ pip install archgw==0.3.7 ``` ### Build Agentic Apps with Arch Gateway @@ -148,13 +148,10 @@ endpoints: ```sh $ archgw up arch_config.yaml -2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.1.5 -... +2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.3.7 2024-12-05 16:56:28,485 - cli.utils - INFO - Schema validation successful! 2024-12-05 16:56:28,485 - cli.main - INFO - Starting arch model server and arch gateway -... 2024-12-05 16:56:51,647 - cli.core - INFO - Container is healthy! - ``` Once the gateway is up you can start interacting with at port 10000 using openai chat completion API. diff --git a/arch/supervisord.conf b/arch/supervisord.conf index dfb4d0d2..bec147cc 100644 --- a/arch/supervisord.conf +++ b/arch/supervisord.conf @@ -2,14 +2,14 @@ nodaemon=true [program:brightstaff] -command=sh -c "RUST_LOG=info /app/brightstaff 2>&1 | tee /var/log/brightstaff.log" +command=sh -c "RUST_LOG=debug /app/brightstaff 2>&1 | tee /var/log/brightstaff.log" stdout_logfile=/dev/stdout redirect_stderr=true stdout_logfile_maxbytes=0 stderr_logfile_maxbytes=0 [program:envoy] -command=/bin/sh -c "python /app/config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:info 2>&1 | tee /var/log//envoy.log" +command=/bin/sh -c "python /app/config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug 2>&1 | tee /var/log//envoy.log" stdout_logfile=/dev/stdout redirect_stderr=true stdout_logfile_maxbytes=0 diff --git a/arch/tools/README.md b/arch/tools/README.md index e4e45284..7a33998e 100644 --- a/arch/tools/README.md +++ b/arch/tools/README.md @@ -19,7 +19,7 @@ source venv/bin/activate ### Step 3: Run the build script ```bash -pip install archgw==0.3.5 +pip install archgw==0.3.7 ``` ## Uninstall Instructions: archgw CLI diff --git a/arch/tools/cli/consts.py b/arch/tools/cli/consts.py index 9114f53f..213f0134 100644 --- a/arch/tools/cli/consts.py +++ b/arch/tools/cli/consts.py @@ -10,4 +10,4 @@ SERVICE_NAME_MODEL_SERVER = "model_server" SERVICE_ALL = "all" MODEL_SERVER_LOG_FILE = "~/archgw_logs/modelserver.log" ARCHGW_DOCKER_NAME = "archgw" -ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.3.5") +ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.3.7") diff --git a/arch/tools/poetry.lock b/arch/tools/poetry.lock index 4491dd1f..a8d5e85f 100644 --- a/arch/tools/poetry.lock +++ b/arch/tools/poetry.lock @@ -2,7 +2,7 @@ [[package]] name = "archgw_modelserver" -version = "0.3.5" +version = "0.3.7" description = "A model server for serving models" optional = false python-versions = "*" @@ -104,13 +104,13 @@ i18n = ["Babel (>=2.7)"] [[package]] name = "jsonschema" -version = "4.24.0" +version = "4.25.0" description = "An implementation of JSON Schema validation for Python" optional = false python-versions = ">=3.9" files = [ - {file = "jsonschema-4.24.0-py3-none-any.whl", hash = "sha256:a462455f19f5faf404a7902952b6f0e3ce868f3ee09a359b05eca6673bd8412d"}, - {file = "jsonschema-4.24.0.tar.gz", hash = "sha256:0b4e8069eb12aedfa881333004bccaec24ecef5a8a6a4b6df142b2cc9599d196"}, + {file = "jsonschema-4.25.0-py3-none-any.whl", hash = "sha256:24c2e8da302de79c8b9382fee3e76b355e44d2a4364bb207159ce10b517bd716"}, + {file = "jsonschema-4.25.0.tar.gz", hash = "sha256:e63acf5c11762c0e6672ffb61482bdf57f0876684d8d249c0fe2d730d48bc55f"}, ] [package.dependencies] @@ -121,7 +121,7 @@ rpds-py = ">=0.7.1" [package.extras] format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] -format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "rfc3987-syntax (>=1.1.0)", "uri-template", "webcolors (>=24.6.0)"] [[package]] name = "jsonschema-specifications" @@ -576,4 +576,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e86085ff732244cb68d2e3f7f4c2903f4a8a50cc7e0963324c2506f0de90df11" +content-hash = "1875c613e62e116d557ad2d30491891557b4114a99c7c65b22b26d690e9e268b" diff --git a/arch/tools/pyproject.toml b/arch/tools/pyproject.toml index cf75165c..c62b8656 100644 --- a/arch/tools/pyproject.toml +++ b/arch/tools/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "archgw" -version = "0.3.5" +version = "0.3.7" description = "Python-based CLI tool to manage Arch Gateway." authors = ["Katanemo Labs, Inc."] packages = [ @@ -10,7 +10,7 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.10" -archgw_modelserver = "^0.3.5" +archgw_modelserver = "^0.3.7" click = "^8.1.7" jinja2 = "^3.1.4" jsonschema = "^4.23.0" diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 89c9ee13..37da961f 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -27,10 +27,13 @@ pub async fn chat_completions( router_service: Arc, llm_provider_endpoint: String, ) -> Result>, hyper::Error> { + let request_path = request.uri().path().to_string(); let mut request_headers = request.headers().clone(); let chat_request_bytes = request.collect().await?.to_bytes(); + debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes)); + let chat_request_parsed = serde_json::from_slice::(&chat_request_bytes) .inspect_err(|err| { warn!( @@ -61,20 +64,15 @@ pub async fn chat_completions( // remove metadata from the request let mut chat_request_user_preferences_removed = chat_request_parsed; if let Some(metadata) = chat_request_user_preferences_removed.get_mut("metadata") { - info!("Removing metadata from request"); + debug!("Removing metadata from request"); if let Some(m) = metadata.as_object_mut() { m.remove("archgw_preference_config"); - info!("Removed archgw_preference_config from metadata"); + debug!("Removed archgw_preference_config from metadata"); } - // metadata.as_object_mut().map(|m| { - // m.remove("archgw_preference_config"); - // info!("Removed archgw_preference_config from metadata"); - // }); - // if metadata is empty, remove it if metadata.as_object().map_or(false, |m| m.is_empty()) { - info!("Removing empty metadata from request"); + debug!("Removing empty metadata from request"); chat_request_user_preferences_removed .as_object_mut() .map(|m| m.remove("metadata")); @@ -102,9 +100,33 @@ pub async fn chat_completions( .as_ref() .and_then(|s| serde_yaml::from_str(s).ok()); + let latest_message_for_log = + chat_completion_request + .messages + .last() + .map_or("None".to_string(), |msg| { + msg.content.as_ref().map_or("None".to_string(), |content| { + content.to_string().replace('\n', "\\n") + }) + }); + + const MAX_MESSAGE_LENGTH: usize = 50; + let latest_message_for_log = if latest_message_for_log.len() > MAX_MESSAGE_LENGTH { + format!("{}...", &latest_message_for_log[..MAX_MESSAGE_LENGTH]) + } else { + latest_message_for_log + }; + + info!( + "request received, request type: chat_completion, usage preferences from request: {}, request path: {}, latest message: {}", + usage_preferences.is_some(), + request_path, + latest_message_for_log + ); + debug!("usage preferences from request: {:?}", usage_preferences); - let mut determined_route = match router_service + let model_name = match router_service .determine_route( &chat_completion_request.messages, trace_parent.clone(), @@ -112,7 +134,16 @@ pub async fn chat_completions( ) .await { - Ok(route) => route, + Ok(route) => match route { + Some((_, model_name)) => model_name, + None => { + debug!( + "No route determined, using default model from request: {}", + chat_completion_request.model + ); + chat_completion_request.model.clone() + } + }, Err(err) => { let err_msg = format!("Failed to determine route: {}", err); let mut internal_error = Response::new(full(err_msg)); @@ -121,14 +152,14 @@ pub async fn chat_completions( } }; - if determined_route.is_none() { - debug!("No LLM model selected, using default from request"); - determined_route = Some(chat_completion_request.model.clone()); - } + debug!( + "sending request to llm provider: {}, with model hint: {}", + llm_provider_endpoint, model_name + ); - info!( - "sending request to llm provider: {} with llm model: {:?}", - llm_provider_endpoint, determined_route + request_headers.insert( + ARCH_PROVIDER_HINT_HEADER, + header::HeaderValue::from_str(&model_name).unwrap(), ); if let Some(trace_parent) = trace_parent { @@ -138,13 +169,6 @@ pub async fn chat_completions( ); } - if let Some(selected_route) = determined_route { - request_headers.insert( - ARCH_PROVIDER_HINT_HEADER, - header::HeaderValue::from_str(&selected_route).unwrap(), - ); - } - let chat_request_parsed_bytes = serde_json::to_string(&chat_request_user_preferences_removed).unwrap(); diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index febab6c2..6de38b5b 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,3 +1,2 @@ pub mod chat_completions; pub mod models; -pub mod preferences; diff --git a/crates/brightstaff/src/handlers/preferences.rs b/crates/brightstaff/src/handlers/preferences.rs deleted file mode 100644 index a9c5a65d..00000000 --- a/crates/brightstaff/src/handlers/preferences.rs +++ /dev/null @@ -1,135 +0,0 @@ -use bytes::Bytes; -use common::configuration::{LlmProvider, ModelUsagePreference}; -use http_body_util::{combinators::BoxBody, BodyExt, Full}; -use hyper::{Request, Response, StatusCode}; -use serde_json; -use std::{collections::HashMap, sync::Arc}; -use tracing::{info, warn}; - -pub async fn list_preferences( - llm_providers: Arc>>, -) -> Response> { - let prov = llm_providers.read().await; - // convert the LlmProvider to UsageBasedProvider - let providers_with_usage = prov - .iter() - .map(|provider| ModelUsagePreference { - name: provider.name.clone(), - model: provider.model.clone().unwrap_or_default(), - usage: provider.usage.clone(), - }) - .collect::>(); - - match serde_json::to_string(&providers_with_usage) { - Ok(json) => { - let body = Full::new(Bytes::from(json)) - .map_err(|never| match never {}) - .boxed(); - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(body) - .unwrap() - } - Err(_) => { - let body = Full::new(Bytes::from_static( - b"{\"error\":\"Failed to serialize models\"}", - )) - .map_err(|never| match never {}) - .boxed(); - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header("Content-Type", "application/json") - .body(body) - .unwrap() - } - } -} - -pub async fn update_preferences( - request: Request, - llm_providers: Arc>>, -) -> Result>, hyper::Error> { - let request_body = request.collect().await?.to_bytes(); - - let usage: Vec = match serde_json::from_slice(&request_body) { - Ok(usage) => usage, - Err(_) => { - let response_body = Full::new(Bytes::from_static(b"Invalid request body: ")) - .map_err(|never| match never {}) - .boxed(); - return Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .header("Content-Type", "text/plain") - .body(response_body) - .unwrap()); - } - }; - - let usage_model_map: HashMap = - usage.into_iter().map(|u| (u.model.clone(), u)).collect(); - - info!( - "Updating usage preferences for models: {:?}", - usage_model_map.keys() - ); - - let mut llm_providers = llm_providers.write().await; - - // ensure that models coming in the request are valid - let llm_provider_names: Vec = llm_providers - .iter() - .map(|provider| provider.name.clone()) - .collect(); - - for model in usage_model_map.keys() { - if !llm_provider_names.contains(model) { - let model_not_found = format!("model not found: {}", model); - warn!("updating preferences: {}", model_not_found); - let response_body = Full::new(model_not_found.into()) - .map_err(|never| match never {}) - .boxed(); - return Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .header("Content-Type", "text/plain") - .body(response_body) - .unwrap()); - } - } - - let mut updated_models_list = Vec::new(); - for provider in llm_providers.iter_mut() { - if let Some(usage_provider) = usage_model_map.get(&provider.name) { - provider.usage = usage_provider.usage.clone(); - updated_models_list.push(ModelUsagePreference { - name: provider.name.clone(), - model: provider.model.clone().unwrap_or_default(), - usage: provider.usage.clone(), - }); - } - } - - if !updated_models_list.is_empty() { - // return list of updated models - let response_body = Full::new(Bytes::from(format!( - "{{\"updated_models\": {}}}", - serde_json::to_string(&updated_models_list).unwrap() - ))) - .map_err(|never| match never {}) - .boxed(); - Ok(Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(response_body) - .unwrap()) - } else { - let response_body = Full::new(Bytes::from_static(b"Provider not found")) - .map_err(|never| match never {}) - .boxed(); - Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .header("Content-Type", "text/plain") - .body(response_body) - .unwrap()) - } -} diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 4e4f18b7..b5bf0204 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,6 +1,5 @@ use brightstaff::handlers::chat_completions::chat_completions; use brightstaff::handlers::models::list_models; -use brightstaff::handlers::preferences::{list_preferences, update_preferences}; use brightstaff::router::llm_router::RouterService; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; @@ -116,12 +115,6 @@ async fn main() -> Result<(), Box> { .with_context(parent_cx) .await } - (&Method::GET, "/v1/router/preferences") => { - Ok(list_preferences(llm_providers).await) - } - (&Method::PUT, "/v1/router/preferences") => { - update_preferences(req, llm_providers).await - } (&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await), (&Method::OPTIONS, "/v1/models") => { let mut response = Response::new(empty()); @@ -156,7 +149,7 @@ async fn main() -> Result<(), Box> { }); tokio::task::spawn(async move { - info!("Accepted connection from {:?}", peer_addr); + debug!("Accepted connection from {:?}", peer_addr); if let Err(err) = http1::Builder::new() // .serve_connection(io, service_fn(chat_completion)) .serve_connection(io, service) diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index c1320c66..fc6d9365 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use common::{ configuration::{LlmProvider, ModelUsagePreference, RoutingPreference}, @@ -48,9 +48,14 @@ impl RouterService { .cloned() .collect::>(); - let llm_routes: Vec = providers_with_usage + let llm_routes: HashMap> = providers_with_usage .iter() - .flat_map(|provider| provider.routing_preferences.clone().unwrap_or_default()) + .filter_map(|provider| { + provider + .routing_preferences + .as_ref() + .map(|prefs| (provider.name.clone(), prefs.clone())) + }) .collect(); let router_model = Arc::new(router_model_v1::RouterModelV1::new( @@ -73,7 +78,7 @@ impl RouterService { messages: &[Message], trace_parent: Option, usage_preferences: Option>, - ) -> Result> { + ) -> Result> { if !self.llm_usage_defined { return Ok(None); } @@ -82,7 +87,7 @@ impl RouterService { .router_model .generate_request(messages, &usage_preferences); - info!( + debug!( "sending request to arch-router model: {}, endpoint: {}", self.router_model.get_model_name(), self.router_url @@ -151,21 +156,21 @@ impl RouterService { if let Some(ContentType::Text(content)) = &chat_completion_response.choices[0].message.content { - let route_name = self.router_model.parse_response(content)?; + let parsed_response = self + .router_model + .parse_response(content, &usage_preferences)?; info!( - "router response: {}, selected_model: {:?}, response time: {}ms", + "arch-router determined route: {}, selected_model: {:?}, response time: {}ms", content.replace("\n", "\\n"), - route_name, + parsed_response, router_response_time.as_millis() ); - if let Some(ref route) = route_name { - if route == "other" { - return Ok(None); - } + if let Some(ref parsed_response) = parsed_response { + return Ok(Some(parsed_response.clone())); } - Ok(route_name) + Ok(None) } else { Ok(None) } diff --git a/crates/brightstaff/src/router/router_model.rs b/crates/brightstaff/src/router/router_model.rs index dafa8776..ec0c1a1f 100644 --- a/crates/brightstaff/src/router/router_model.rs +++ b/crates/brightstaff/src/router/router_model.rs @@ -16,6 +16,10 @@ pub trait RouterModel: Send + Sync { messages: &[Message], usage_preferences: &Option>, ) -> ChatCompletionsRequest; - fn parse_response(&self, content: &str) -> Result>; + fn parse_response( + &self, + content: &str, + usage_preferences: &Option>, + ) -> Result>; fn get_model_name(&self) -> String; } diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 0dcefff6..bd06b525 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use common::{ configuration::{ModelUsagePreference, RoutingPreference}, consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE}, @@ -32,21 +34,30 @@ Based on your analysis, provide your response in the following JSON formats if y pub type Result = std::result::Result; pub struct RouterModelV1 { llm_route_json_str: String, + llm_route_to_model_map: HashMap, routing_model: String, max_token_length: usize, } impl RouterModelV1 { pub fn new( - llm_routes: Vec, + llm_routes: HashMap>, routing_model: String, max_token_length: usize, ) -> Self { + let llm_route_values: Vec = + llm_routes.values().flatten().cloned().collect(); let llm_route_json_str = - serde_json::to_string(&llm_routes).unwrap_or_else(|_| "[]".to_string()); + serde_json::to_string(&llm_route_values).unwrap_or_else(|_| "[]".to_string()); + let llm_route_to_model_map: HashMap = llm_routes + .iter() + .flat_map(|(model, prefs)| prefs.iter().map(|pref| (pref.name.clone(), model.clone()))) + .collect(); + RouterModelV1 { routing_model, max_token_length, llm_route_json_str, + llm_route_to_model_map, } } } @@ -62,7 +73,7 @@ impl RouterModel for RouterModelV1 { fn generate_request( &self, messages: &[Message], - usage_preferences: &Option>, + usage_preferences_from_request: &Option>, ) -> ChatCompletionsRequest { // remove system prompt, tool calls, tool call response and messages without content // if content is empty its likely a tool call @@ -139,31 +150,17 @@ impl RouterModel for RouterModelV1 { }) .collect::>(); - let llm_route_json = usage_preferences - .as_ref() - .map(|prefs| { - let llm_route: Vec = prefs - .iter() - .map(|pref| RoutingPreference { - name: pref.name.clone(), - description: pref.usage.clone().unwrap_or_default(), - }) - .collect(); - serde_json::to_string(&llm_route).unwrap_or_default() - }) - .unwrap_or_else(|| self.llm_route_json_str.clone()); - - let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT - .replace("{routes}", &llm_route_json) - .replace( - "{conversation}", - &serde_json::to_string(&selected_conversation_list).unwrap_or_default(), - ); + // Generate the router request message based on the usage preferences. + // If preferences are passed in request then we use them otherwise we use the default routing model preferences. + let router_message = match convert_to_router_preferences(usage_preferences_from_request) { + Some(prefs) => generate_router_message(&prefs, &selected_conversation_list), + None => generate_router_message(&self.llm_route_json_str, &selected_conversation_list), + }; ChatCompletionsRequest { model: self.routing_model.clone(), messages: vec![Message { - content: Some(ContentType::Text(messages_content)), + content: Some(ContentType::Text(router_message)), role: USER_ROLE.to_string(), }], temperature: Some(0.01), @@ -171,20 +168,57 @@ impl RouterModel for RouterModelV1 { } } - fn parse_response(&self, content: &str) -> Result> { + fn parse_response( + &self, + content: &str, + usage_preferences: &Option>, + ) -> Result> { if content.is_empty() { return Ok(None); } let router_resp_fixed = fix_json_response(content); let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?; - let selected_llm = router_response.route.unwrap_or_default().to_string(); + let selected_route = router_response.route.unwrap_or_default().to_string(); - if selected_llm.is_empty() { + if selected_route.is_empty() || selected_route == "other" { return Ok(None); } - Ok(Some(selected_llm)) + if let Some(usage_preferences) = usage_preferences { + // If usage preferences are defined, we need to find the model that matches the selected route + let model_name: Option = usage_preferences + .iter() + .map(|pref| { + pref.routing_preferences + .iter() + .find(|routing_pref| routing_pref.name == selected_route) + .map(|_| pref.model.clone()) + }) + .find_map(|model| model); + + if let Some(model_name) = model_name { + return Ok(Some((selected_route, model_name))); + } else { + warn!( + "No matching model found for route: {}, usage preferences: {:?}", + selected_route, usage_preferences + ); + return Ok(None); + } + } + + // If no usage preferences are passed in request then use the default routing model preferences + if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() { + return Ok(Some((selected_route, model))); + } + + warn!( + "No model found for route: {}, router model preferences: {:?}", + selected_route, self.llm_route_to_model_map + ); + + Ok(None) } fn get_model_name(&self) -> String { @@ -192,6 +226,37 @@ impl RouterModel for RouterModelV1 { } } +fn generate_router_message(prefs: &str, selected_conversation_list: &Vec) -> String { + ARCH_ROUTER_V1_SYSTEM_PROMPT + .replace("{routes}", prefs) + .replace( + "{conversation}", + &serde_json::to_string(&selected_conversation_list).unwrap_or_default(), + ) +} + +fn convert_to_router_preferences( + prefs_from_request: &Option>, +) -> Option { + if let Some(usage_preferences) = prefs_from_request { + let routing_preferences = usage_preferences + .iter() + .flat_map(|pref| { + pref.routing_preferences + .iter() + .map(|routing_pref| RoutingPreference { + name: routing_pref.name.clone(), + description: routing_pref.description.clone(), + }) + }) + .collect::>(); + + return Some(serde_json::to_string(&routing_preferences).unwrap_or_default()); + } + + None +} + fn fix_json_response(body: &str) -> String { let mut updated_body = body.to_string(); @@ -235,7 +300,7 @@ mod tests { You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -251,15 +316,14 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -310,15 +374,14 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -341,9 +404,11 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); let usage_preferences = Some(vec![ModelUsagePreference { - name: "code-generation".to_string(), model: "claude/claude-3-7-sonnet".to_string(), - usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()), + routing_preferences: vec![RoutingPreference { + name: "code-generation".to_string(), + description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(), + }], }]); let req = router.generate_request(&conversation, &usage_preferences); @@ -358,7 +423,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -375,15 +440,14 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235); @@ -419,7 +483,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -436,15 +500,15 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); + let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200); @@ -480,7 +544,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -497,15 +561,14 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230); @@ -549,7 +612,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -565,15 +628,14 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -619,7 +681,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -635,15 +697,14 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -712,56 +773,64 @@ Based on your analysis, provide your response in the following JSON formats if y #[test] fn test_parse_response() { let routes_str = r#" -[ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} -] -"#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } + "#; + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000); // Case 1: Valid JSON with non-empty route - let input = r#"{"route": "route1"}"#; - let result = router.parse_response(input).unwrap(); - assert_eq!(result, Some("route1".to_string())); + let input = r#"{"route": "Image generation"}"#; + let result = router.parse_response(input, &None).unwrap(); + assert_eq!( + result, + Some(("Image generation".to_string(), "gpt-4o".to_string())) + ); // Case 2: Valid JSON with empty route let input = r#"{"route": ""}"#; - let result = router.parse_response(input).unwrap(); + let result = router.parse_response(input, &None).unwrap(); assert_eq!(result, None); // Case 3: Valid JSON with null route let input = r#"{"route": null}"#; - let result = router.parse_response(input).unwrap(); + let result = router.parse_response(input, &None).unwrap(); assert_eq!(result, None); // Case 4: JSON missing route field let input = r#"{}"#; - let result = router.parse_response(input).unwrap(); + let result = router.parse_response(input, &None).unwrap(); assert_eq!(result, None); // Case 4.1: empty string let input = r#""#; - let result = router.parse_response(input).unwrap(); + let result = router.parse_response(input, &None).unwrap(); assert_eq!(result, None); // Case 5: Malformed JSON let input = r#"{"route": "route1""#; // missing closing } - let result = router.parse_response(input); + let result = router.parse_response(input, &None); assert!(result.is_err()); // Case 6: Single quotes and \n in JSON - let input = "{'route': 'route2'}\\n"; - let result = router.parse_response(input).unwrap(); - assert_eq!(result, Some("route2".to_string())); + let input = "{'route': 'Image generation'}\\n"; + let result = router.parse_response(input, &None).unwrap(); + assert_eq!( + result, + Some(("Image generation".to_string(), "gpt-4o".to_string())) + ); // Case 7: Code block marker - let input = "```json\n{\"route\": \"route1\"}\n```"; - let result = router.parse_response(input).unwrap(); - assert_eq!(result, Some("route1".to_string())); + let input = "```json\n{\"route\": \"Image generation\"}\n```"; + let result = router.parse_response(input, &None).unwrap(); + assert_eq!( + result, + Some(("Image generation".to_string(), "gpt-4o".to_string())) + ); } } diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 0693c09b..186691dc 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -1,6 +1,5 @@ use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models}; use serde::{Deserialize, Serialize}; -use serde_with::skip_serializing_none; use std::collections::HashMap; use std::fmt::Display; @@ -178,12 +177,10 @@ impl Display for LlmProviderType { } } -#[skip_serializing_none] #[derive(Serialize, Deserialize, Debug)] pub struct ModelUsagePreference { - pub name: String, pub model: String, - pub usage: Option, + pub routing_preferences: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs new file mode 100644 index 00000000..0ffe4e8d --- /dev/null +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -0,0 +1,898 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use serde_with::skip_serializing_none; +use std::collections::HashMap; + +use super::ApiDefinition; + +// Enum for all supported Anthropic APIs +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AnthropicApi { + Messages, + // Future APIs can be added here: + // Embeddings, + // etc. +} + +impl ApiDefinition for AnthropicApi { + fn endpoint(&self) -> &'static str { + match self { + AnthropicApi::Messages => "/v1/messages", + } + } + + fn from_endpoint(endpoint: &str) -> Option { + match endpoint { + "/v1/messages" => Some(AnthropicApi::Messages), + _ => None, + } + } + + fn supports_streaming(&self) -> bool { + match self { + AnthropicApi::Messages => true, + } + } + + fn supports_tools(&self) -> bool { + match self { + AnthropicApi::Messages => true, + } + } + + fn supports_vision(&self) -> bool { + match self { + AnthropicApi::Messages => true, + } + } + + fn all_variants() -> Vec { + vec![ + AnthropicApi::Messages, + ] + } +} + +// Service tier enum for request priority +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ServiceTier { + Auto, + StandardOnly, +} + +// Thinking configuration +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ThinkingConfig { + pub enabled: bool, +} + +// MCP Server types +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum McpServerType { + Url, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct McpToolConfiguration { + pub allowed_tools: Option>, + pub enabled: Option, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct McpServer { + pub name: String, + #[serde(rename = "type")] + pub server_type: McpServerType, + pub url: String, + pub authorization_token: Option, + pub tool_configuration: Option, +} + + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesRequest { + pub model: String, + pub messages: Vec, + pub max_tokens: u32, + pub container: Option, + pub mcp_servers: Option>, + pub system: Option, + pub metadata: Option>, + pub service_tier: Option, + pub thinking: Option, + + pub temperature: Option, + pub top_p: Option, + pub top_k: Option, + pub stream: Option, + pub stop_sequences: Option>, + pub tools: Option>, + pub tool_choice: Option, + +} + + +// Messages API specific types +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum MessagesRole { + User, + Assistant, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum MessagesContentBlock { + Text { + text: String, + }, + Thinking { + text: String, + }, + Image { + source: MessagesImageSource, + }, + Document { + source: MessagesDocumentSource, + }, + ToolUse { + id: String, + name: String, + input: Value, + }, + ToolResult { + tool_use_id: String, + is_error: Option, + content: Vec, + }, + ServerToolUse { + id: String, + name: String, + input: Value, + }, + WebSearchToolResult { + tool_use_id: String, + is_error: Option, + content: Vec, + }, + CodeExecutionToolResult { + tool_use_id: String, + is_error: Option, + content: Vec, + }, + McpToolUse { + id: String, + name: String, + input: Value, + }, + McpToolResult { + tool_use_id: String, + is_error: Option, + content: Vec, + }, + ContainerUpload { + id: String, + name: String, + media_type: String, + data: String, + }, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +pub enum MessagesImageSource { + Base64 { + media_type: String, + data: String, + }, + Url { + url: String, + }, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +pub enum MessagesDocumentSource { + Base64 { + media_type: String, + data: String, + }, + Url { + url: String, + }, + File { + file_id: String, + }, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum MessagesMessageContent { + Single(String), + Blocks(Vec), +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum MessagesSystemPrompt { + Single(String), + Blocks(Vec), +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesMessage { + pub role: MessagesRole, + pub content: MessagesMessageContent, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesTool { + pub name: String, + pub description: Option, + pub input_schema: Value, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum MessagesToolChoiceType { + Auto, + Any, + Tool, + None, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesToolChoice { + #[serde(rename = "type")] + pub kind: MessagesToolChoiceType, + pub name: Option, + pub disable_parallel_tool_use: Option, +} + + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum MessagesStopReason { + EndTurn, + MaxTokens, + StopSequence, + ToolUse, + PauseTurn, + Refusal, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesUsage { + pub input_tokens: u32, + pub output_tokens: u32, + pub cache_creation_input_tokens: Option, + pub cache_read_input_tokens: Option, +} + +// Container response object +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesContainer { + pub id: String, + #[serde(rename = "type")] + pub container_type: String, + pub name: String, + pub status: String, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesResponse { + pub id: String, + #[serde(rename = "type")] + pub obj_type: String, + pub role: MessagesRole, + pub content: Vec, + pub model: String, + pub stop_reason: MessagesStopReason, + pub stop_sequence: Option, + pub usage: MessagesUsage, + pub container: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum MessagesStreamEvent { + MessageStart { + message: MessagesStreamMessage, + }, + ContentBlockStart { + index: u32, + content_block: MessagesContentBlock, + }, + ContentBlockDelta { + index: u32, + delta: MessagesContentDelta, + }, + ContentBlockStop { + index: u32, + }, + MessageDelta { + delta: MessagesMessageDelta, + usage: MessagesUsage, + }, + MessageStop, + Ping, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesStreamMessage { + pub id: String, + #[serde(rename = "type")] + pub obj_type: String, + pub role: MessagesRole, + pub content: Vec, // Initially empty + pub model: String, + pub stop_reason: Option, + pub stop_sequence: Option, + pub usage: MessagesUsage, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(tag = "type")] +pub enum MessagesContentDelta { + #[serde(rename = "text_delta")] + TextDelta { text: String }, + #[serde(rename = "input_json_delta")] + InputJsonDelta { partial_json: String }, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MessagesMessageDelta { + pub stop_reason: MessagesStopReason, + pub stop_sequence: Option, +} + +// Helper functions for API detection and conversion +impl MessagesRequest { + pub fn api_type() -> AnthropicApi { + AnthropicApi::Messages + } +} + +impl MessagesResponse { + pub fn api_type() -> AnthropicApi { + AnthropicApi::Messages + } +} + +impl MessagesStreamEvent { + pub fn api_type() -> AnthropicApi { + AnthropicApi::Messages + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_anthropic_required_fields() { + // Create a JSON object with only required fields + let original_json = json!({ + "model": "claude-3-sonnet-20240229", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "max_tokens": 100 + }); + + // Deserialize JSON into MessagesRequest + let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); + + // Validate required fields are properly set + assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); + assert_eq!(deserialized_request.messages.len(), 1); + assert_eq!(deserialized_request.max_tokens, 100); + + let message = &deserialized_request.messages[0]; + assert_eq!(message.role, MessagesRole::User); + if let MessagesMessageContent::Single(content) = &message.content { + assert_eq!(content, "Hello"); + } else { + panic!("Expected single content"); + } + + // Validate optional fields are None + assert!(deserialized_request.system.is_none()); + assert!(deserialized_request.container.is_none()); + assert!(deserialized_request.mcp_servers.is_none()); + assert!(deserialized_request.service_tier.is_none()); + assert!(deserialized_request.thinking.is_none()); + assert!(deserialized_request.temperature.is_none()); + assert!(deserialized_request.top_p.is_none()); + assert!(deserialized_request.top_k.is_none()); + assert!(deserialized_request.stream.is_none()); + assert!(deserialized_request.stop_sequences.is_none()); + assert!(deserialized_request.tools.is_none()); + assert!(deserialized_request.tool_choice.is_none()); + assert!(deserialized_request.metadata.is_none()); + + // Serialize back to JSON and compare + let serialized_json = serde_json::to_value(&deserialized_request).unwrap(); + assert_eq!(original_json, serialized_json); + } + + #[test] + fn test_anthropic_optional_fields() { + // Create a JSON object with optional fields set + let original_json = json!({ + "model": "claude-3-sonnet-20240229", + "messages": [ + { + "role": "user", + "content": "Hello" + } + ], + "max_tokens": 100, + "temperature": 0.7, + "top_p": 0.9, + "system": "You are a helpful assistant", + "service_tier": "auto", + "thinking": { + "enabled": true + }, + "metadata": { + "user_id": "123" + } + }); + + // Deserialize JSON into MessagesRequest + let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); + + // Validate required fields + assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); + assert_eq!(deserialized_request.messages.len(), 1); + assert_eq!(deserialized_request.max_tokens, 100); + + // Validate optional fields are properly set + assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6); + assert!((deserialized_request.top_p.unwrap() - 0.9).abs() < 1e-6); + assert_eq!(deserialized_request.service_tier, Some(ServiceTier::Auto)); + + if let Some(MessagesSystemPrompt::Single(system)) = &deserialized_request.system { + assert_eq!(system, "You are a helpful assistant"); + } else { + panic!("Expected single system prompt"); + } + + if let Some(thinking) = &deserialized_request.thinking { + assert_eq!(thinking.enabled, true); + } else { + panic!("Expected thinking config"); + } + + assert!(deserialized_request.metadata.is_some()); + + // Validate fields not in JSON are None + assert!(deserialized_request.container.is_none()); + assert!(deserialized_request.mcp_servers.is_none()); + assert!(deserialized_request.top_k.is_none()); + assert!(deserialized_request.stream.is_none()); + assert!(deserialized_request.stop_sequences.is_none()); + assert!(deserialized_request.tools.is_none()); + assert!(deserialized_request.tool_choice.is_none()); + + // Serialize back to JSON and compare (handle floating point precision) + let serialized_json = serde_json::to_value(&deserialized_request).unwrap(); + + // Compare all fields except floating point ones + assert_eq!(serialized_json["model"], original_json["model"]); + assert_eq!(serialized_json["messages"], original_json["messages"]); + assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]); + assert_eq!(serialized_json["system"], original_json["system"]); + assert_eq!(serialized_json["service_tier"], original_json["service_tier"]); + assert_eq!(serialized_json["thinking"], original_json["thinking"]); + assert_eq!(serialized_json["metadata"], original_json["metadata"]); + + // Handle floating point fields with tolerance + let original_temp = original_json["temperature"].as_f64().unwrap(); + let serialized_temp = serialized_json["temperature"].as_f64().unwrap(); + assert!((original_temp - serialized_temp).abs() < 1e-6); + + let original_top_p = original_json["top_p"].as_f64().unwrap(); + let serialized_top_p = serialized_json["top_p"].as_f64().unwrap(); + assert!((original_top_p - serialized_top_p).abs() < 1e-6); + } + + #[test] + fn test_anthropic_nested_types() { + // Create a comprehensive JSON object with nested types - a MessagesRequest with complex message content and tools + let original_json = json!({ + "model": "claude-3-sonnet-20240229", + "max_tokens": 1000, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What can you see in this image and what's the weather like?" + }, + { + "type": "image", + "source": { + "base64": { + "media_type": "image/jpeg", + "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + } + } + } + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "text": "Let me analyze the image and then check the weather..." + }, + { + "type": "text", + "text": "I can see the image. Let me check the weather for you." + }, + { + "type": "tool_use", + "id": "toolu_weather123", + "name": "get_weather", + "input": { + "location": "San Francisco, CA" + } + } + ] + } + ], + "tools": [ + { + "name": "get_weather", + "description": "Get current weather information for a location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + } + } + ], + "tool_choice": { + "type": "auto" + }, + "system": [ + { + "type": "text", + "text": "You are a helpful assistant that can analyze images and provide weather information." + } + ] + }); + + // Deserialize JSON into MessagesRequest + let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap(); + + // Validate top-level fields + assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229"); + assert_eq!(deserialized_request.max_tokens, 1000); + assert_eq!(deserialized_request.messages.len(), 2); + + // Validate first message (user with text and image content) + let user_message = &deserialized_request.messages[0]; + assert_eq!(user_message.role, MessagesRole::User); + if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content { + assert_eq!(content_blocks.len(), 2); + + // Validate text content block + if let MessagesContentBlock::Text { text } = &content_blocks[0] { + assert_eq!(text, "What can you see in this image and what's the weather like?"); + } else { + panic!("Expected text content block"); + } + + // Validate image content block + if let MessagesContentBlock::Image { ref source } = content_blocks[1] { + if let MessagesImageSource::Base64 { media_type, data } = source { + assert_eq!(media_type, "image/jpeg"); + assert_eq!(data, "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="); + } else { + panic!("Expected base64 image source"); + } + } else { + panic!("Expected image content block"); + } + } else { + panic!("Expected content blocks for user message"); + } + + // Validate second message (assistant with thinking, text, and tool use) + let assistant_message = &deserialized_request.messages[1]; + assert_eq!(assistant_message.role, MessagesRole::Assistant); + if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content { + assert_eq!(content_blocks.len(), 3); + + // Validate thinking content block + if let MessagesContentBlock::Thinking { text } = &content_blocks[0] { + assert_eq!(text, "Let me analyze the image and then check the weather..."); + } else { + panic!("Expected thinking content block"); + } + + // Validate text content block + if let MessagesContentBlock::Text { text } = &content_blocks[1] { + assert_eq!(text, "I can see the image. Let me check the weather for you."); + } else { + panic!("Expected text content block"); + } + + // Validate tool use content block + if let MessagesContentBlock::ToolUse { ref id, ref name, ref input } = content_blocks[2] { + assert_eq!(id, "toolu_weather123"); + assert_eq!(name, "get_weather"); + assert_eq!(input["location"], "San Francisco, CA"); + } else { + panic!("Expected tool use content block"); + } + } else { + panic!("Expected content blocks for assistant message"); + } + + // Validate tools array + assert!(deserialized_request.tools.is_some()); + let tools = deserialized_request.tools.as_ref().unwrap(); + assert_eq!(tools.len(), 1); + + let tool = &tools[0]; + assert_eq!(tool.name, "get_weather"); + assert_eq!(tool.description, Some("Get current weather information for a location".to_string())); + assert_eq!(tool.input_schema["type"], "object"); + assert!(tool.input_schema["properties"]["location"].is_object()); + + // Validate tool choice + assert!(deserialized_request.tool_choice.is_some()); + let tool_choice = deserialized_request.tool_choice.as_ref().unwrap(); + assert_eq!(tool_choice.kind, MessagesToolChoiceType::Auto); + assert!(tool_choice.name.is_none()); + + // Validate system prompt with content blocks + assert!(deserialized_request.system.is_some()); + if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system { + assert_eq!(system_blocks.len(), 1); + if let MessagesContentBlock::Text { text } = &system_blocks[0] { + assert_eq!(text, "You are a helpful assistant that can analyze images and provide weather information."); + } else { + panic!("Expected text content block in system prompt"); + } + } else { + panic!("Expected system prompt with content blocks"); + } + + // Serialize back to JSON and compare + let serialized_json = serde_json::to_value(&deserialized_request).unwrap(); + assert_eq!(original_json, serialized_json); + } + + #[test] + fn test_anthropic_mcp_server_configuration() { + // Test MCP Server configuration with JSON-first approach + let mcp_server_json = json!({ + "name": "test-server", + "type": "url", + "url": "https://example.com/mcp", + "authorization_token": "secret-token", + "tool_configuration": { + "allowed_tools": ["tool1", "tool2"], + "enabled": true + } + }); + + let deserialized_mcp: McpServer = serde_json::from_value(mcp_server_json.clone()).unwrap(); + assert_eq!(deserialized_mcp.name, "test-server"); + assert_eq!(deserialized_mcp.server_type, McpServerType::Url); + assert_eq!(deserialized_mcp.url, "https://example.com/mcp"); + assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string())); + + if let Some(tool_config) = &deserialized_mcp.tool_configuration { + assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()])); + assert_eq!(tool_config.enabled, Some(true)); + } else { + panic!("Expected tool configuration"); + } + + let serialized_mcp_json = serde_json::to_value(&deserialized_mcp).unwrap(); + assert_eq!(mcp_server_json, serialized_mcp_json); + + // Test MCP Server with minimal configuration (optional fields as None) + let minimal_mcp_json = json!({ + "name": "minimal-server", + "type": "url", + "url": "https://minimal.com/mcp" + }); + + let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap(); + assert_eq!(deserialized_minimal.name, "minimal-server"); + assert_eq!(deserialized_minimal.server_type, McpServerType::Url); + assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp"); + assert!(deserialized_minimal.authorization_token.is_none()); + assert!(deserialized_minimal.tool_configuration.is_none()); + + let serialized_minimal_json = serde_json::to_value(&deserialized_minimal).unwrap(); + assert_eq!(minimal_mcp_json, serialized_minimal_json); + } + + #[test] + fn test_anthropic_response_types() { + // Test MessagesResponse deserialization + let response_json = json!({ + "id": "msg_01ABC123", + "type": "message", + "role": "assistant", + "content": [ + { + "type": "text", + "text": "Hello! How can I help you today?" + } + ], + "model": "claude-3-sonnet-20240229", + "stop_reason": "end_turn", + "usage": { + "input_tokens": 10, + "output_tokens": 25, + "cache_creation_input_tokens": 5, + "cache_read_input_tokens": 3 + } + }); + + let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap(); + assert_eq!(deserialized_response.id, "msg_01ABC123"); + assert_eq!(deserialized_response.obj_type, "message"); + assert_eq!(deserialized_response.role, MessagesRole::Assistant); + assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229"); + assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn); + assert!(deserialized_response.stop_sequence.is_none()); + assert!(deserialized_response.container.is_none()); + + // Check content + assert_eq!(deserialized_response.content.len(), 1); + if let MessagesContentBlock::Text { text } = &deserialized_response.content[0] { + assert_eq!(text, "Hello! How can I help you today?"); + } else { + panic!("Expected text content block"); + } + + // Check usage + assert_eq!(deserialized_response.usage.input_tokens, 10); + assert_eq!(deserialized_response.usage.output_tokens, 25); + assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5)); + assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3)); + + let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap(); + assert_eq!(response_json, serialized_response_json); + + // Test streaming event + let stream_event_json = json!({ + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": " How" + } + }); + + let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap(); + if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event { + assert_eq!(index, 0); + if let MessagesContentDelta::TextDelta { text } = delta { + assert_eq!(text, " How"); + } else { + panic!("Expected text delta"); + } + } else { + panic!("Expected content block delta event"); + } + + let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap(); + assert_eq!(stream_event_json, serialized_event_json); + } + + #[test] + fn test_anthropic_tool_use_content() { + // Test tool use and tool result content blocks + let tool_use_json = json!({ + "type": "tool_use", + "id": "toolu_01ABC123", + "name": "get_weather", + "input": { + "location": "San Francisco, CA" + } + }); + + let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap(); + if let MessagesContentBlock::ToolUse { ref id, ref name, ref input } = deserialized_tool_use { + assert_eq!(id, "toolu_01ABC123"); + assert_eq!(name, "get_weather"); + assert_eq!(input["location"], "San Francisco, CA"); + } else { + panic!("Expected tool use content block"); + } + + let serialized_tool_use_json = serde_json::to_value(&deserialized_tool_use).unwrap(); + assert_eq!(tool_use_json, serialized_tool_use_json); + + // Test tool result content block + let tool_result_json = json!({ + "type": "tool_result", + "tool_use_id": "toolu_01ABC123", + "content": [ + { + "type": "text", + "text": "The weather in San Francisco is sunny, 72°F" + } + ] + }); + + let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap(); + if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content } = deserialized_tool_result { + assert_eq!(tool_use_id, "toolu_01ABC123"); + assert!(is_error.is_none()); + assert_eq!(content.len(), 1); + if let MessagesContentBlock::Text { text } = &content[0] { + assert_eq!(text, "The weather in San Francisco is sunny, 72°F"); + } else { + panic!("Expected text content in tool result"); + } + } else { + panic!("Expected tool result content block"); + } + + let serialized_tool_result_json = serde_json::to_value(&deserialized_tool_result).unwrap(); + assert_eq!(tool_result_json, serialized_tool_result_json); + } + + #[test] + fn test_anthropic_api_provider_trait_implementation() { + // Test that AnthropicApi implements ApiDefinition trait correctly + let api = AnthropicApi::Messages; + + // Test trait methods + assert_eq!(api.endpoint(), "/v1/messages"); + assert!(api.supports_streaming()); + assert!(api.supports_tools()); + assert!(api.supports_vision()); + + // Test from_endpoint trait method + let found_api = AnthropicApi::from_endpoint("/v1/messages"); + assert_eq!(found_api, Some(AnthropicApi::Messages)); + + let not_found = AnthropicApi::from_endpoint("/v1/unknown"); + assert_eq!(not_found, None); + + // Test all_variants + let all_variants = AnthropicApi::all_variants(); + assert_eq!(all_variants.len(), 1); + assert_eq!(all_variants[0], AnthropicApi::Messages); + } +} diff --git a/crates/hermesllm/src/apis/mod.rs b/crates/hermesllm/src/apis/mod.rs new file mode 100644 index 00000000..78b634d5 --- /dev/null +++ b/crates/hermesllm/src/apis/mod.rs @@ -0,0 +1,197 @@ +pub mod anthropic; +pub mod openai; + +// Re-export all types for convenience +pub use anthropic::*; +pub use openai::*; + +/// Common trait that all API definitions must implement +/// +/// This trait ensures consistency across different AI provider API definitions +/// and makes it easy to add new providers like Gemini, Claude, etc. +/// +/// Note: This is different from the `ApiProvider` enum in `clients::endpoints` +/// which represents provider identification, while this trait defines API capabilities. +/// +/// # Benefits +/// +/// - **Consistency**: All API providers implement the same interface +/// - **Extensibility**: Easy to add new providers without breaking existing code +/// - **Type Safety**: Compile-time guarantees that all providers implement required methods +/// - **Discoverability**: Clear documentation of what capabilities each API supports +/// +/// # Example implementation for a new provider: +/// +/// ```rust,ignore +/// use serde::{Deserialize, Serialize}; +/// use super::ApiDefinition; +/// +/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +/// pub enum GeminiApi { +/// GenerateContent, +/// ChatCompletions, +/// } +/// +/// impl GeminiApi { +/// pub fn endpoint(&self) -> &'static str { +/// match self { +/// GeminiApi::GenerateContent => "/v1/models/gemini-pro:generateContent", +/// GeminiApi::ChatCompletions => "/v1/models/gemini-pro:chat", +/// } +/// } +/// +/// pub fn from_endpoint(endpoint: &str) -> Option { +/// match endpoint { +/// "/v1/models/gemini-pro:generateContent" => Some(GeminiApi::GenerateContent), +/// "/v1/models/gemini-pro:chat" => Some(GeminiApi::ChatCompletions), +/// _ => None, +/// } +/// } +/// +/// pub fn supports_streaming(&self) -> bool { +/// match self { +/// GeminiApi::GenerateContent => true, +/// GeminiApi::ChatCompletions => true, +/// } +/// } +/// +/// pub fn supports_tools(&self) -> bool { +/// match self { +/// GeminiApi::GenerateContent => true, +/// GeminiApi::ChatCompletions => false, +/// } +/// } +/// +/// pub fn supports_vision(&self) -> bool { +/// match self { +/// GeminiApi::GenerateContent => true, +/// GeminiApi::ChatCompletions => false, +/// } +/// } +/// } +/// +/// impl ApiDefinition for GeminiApi { +/// fn endpoint(&self) -> &'static str { +/// self.endpoint() +/// } +/// +/// fn from_endpoint(endpoint: &str) -> Option { +/// Self::from_endpoint(endpoint) +/// } +/// +/// fn supports_streaming(&self) -> bool { +/// self.supports_streaming() +/// } +/// +/// fn supports_tools(&self) -> bool { +/// self.supports_tools() +/// } +/// +/// fn supports_vision(&self) -> bool { +/// self.supports_vision() +/// } +/// } +/// +/// // Now you can use generic code that works with any API: +/// fn print_api_info(api: &T) { +/// println!("Endpoint: {}", api.endpoint()); +/// println!("Supports streaming: {}", api.supports_streaming()); +/// println!("Supports tools: {}", api.supports_tools()); +/// println!("Supports vision: {}", api.supports_vision()); +/// } +/// +/// // Works with both OpenAI and Anthropic (and future Gemini) +/// print_api_info(&OpenAIApi::ChatCompletions); +/// print_api_info(&AnthropicApi::Messages); +/// print_api_info(&GeminiApi::GenerateContent); +/// ``` +pub trait ApiDefinition { + /// Returns the endpoint path for this API + fn endpoint(&self) -> &'static str; + + /// Creates an API instance from an endpoint path + fn from_endpoint(endpoint: &str) -> Option + where + Self: Sized; + + /// Returns whether this API supports streaming responses + fn supports_streaming(&self) -> bool; + + /// Returns whether this API supports tool/function calling + fn supports_tools(&self) -> bool; + + /// Returns whether this API supports vision/image processing + fn supports_vision(&self) -> bool; + + /// Returns all variants of this API enum + fn all_variants() -> Vec + where + Self: Sized; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generic_api_functionality() { + // Test that our generic API functionality works with both providers + fn test_api(api: &T) { + let endpoint = api.endpoint(); + assert!(!endpoint.is_empty()); + assert!(endpoint.starts_with('/')); + } + + test_api(&OpenAIApi::ChatCompletions); + test_api(&AnthropicApi::Messages); + } + + #[test] + fn test_api_detection_from_endpoints() { + // Test that we can detect APIs from endpoints using the trait + let endpoints = vec![ + "/v1/chat/completions", + "/v1/messages", + "/v1/unknown" + ]; + + let mut detected_apis = Vec::new(); + + for endpoint in endpoints { + if let Some(api) = OpenAIApi::from_endpoint(endpoint) { + detected_apis.push(format!("OpenAI: {:?}", api)); + } else if let Some(api) = AnthropicApi::from_endpoint(endpoint) { + detected_apis.push(format!("Anthropic: {:?}", api)); + } else { + detected_apis.push("Unknown API".to_string()); + } + } + + assert_eq!(detected_apis, vec![ + "OpenAI: ChatCompletions", + "Anthropic: Messages", + "Unknown API" + ]); + } + + #[test] + fn test_all_variants_method() { + // Test that all_variants returns the expected variants + let openai_variants = OpenAIApi::all_variants(); + assert_eq!(openai_variants.len(), 1); + assert!(openai_variants.contains(&OpenAIApi::ChatCompletions)); + + let anthropic_variants = AnthropicApi::all_variants(); + assert_eq!(anthropic_variants.len(), 1); + assert!(anthropic_variants.contains(&AnthropicApi::Messages)); + + // Verify each variant has a valid endpoint + for variant in openai_variants { + assert!(!variant.endpoint().is_empty()); + } + + for variant in anthropic_variants { + assert!(!variant.endpoint().is_empty()); + } + } +} diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs new file mode 100644 index 00000000..7f75c6be --- /dev/null +++ b/crates/hermesllm/src/apis/openai.rs @@ -0,0 +1,883 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use serde_with::skip_serializing_none; +use std::collections::HashMap; + +use super::ApiDefinition; + +// ============================================================================ +// OPENAI API ENUMERATION +// ============================================================================ + +/// Enum for all supported OpenAI APIs +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum OpenAIApi { + ChatCompletions, + // Future APIs can be added here: + // Embeddings, + // FineTuning, + // etc. +} + +impl ApiDefinition for OpenAIApi { + fn endpoint(&self) -> &'static str { + match self { + OpenAIApi::ChatCompletions => "/v1/chat/completions", + } + } + + fn from_endpoint(endpoint: &str) -> Option { + match endpoint { + "/v1/chat/completions" => Some(OpenAIApi::ChatCompletions), + _ => None, + } + } + + fn supports_streaming(&self) -> bool { + match self { + OpenAIApi::ChatCompletions => true, + } + } + + fn supports_tools(&self) -> bool { + match self { + OpenAIApi::ChatCompletions => true, + } + } + + fn supports_vision(&self) -> bool { + match self { + OpenAIApi::ChatCompletions => true, + } + } + + fn all_variants() -> Vec { + vec![ + OpenAIApi::ChatCompletions, + ] + } +} + +/// Chat completions API request +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +pub struct ChatCompletionsRequest { + pub messages: Vec, + pub model: String, + // pub audio: Option