Merge branch 'main' into adil/update_torch

This commit is contained in:
Adil Hafeez 2025-08-10 22:10:33 -07:00
commit 5bc96b139e
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
42 changed files with 4363 additions and 457 deletions

View file

@ -24,7 +24,7 @@ jobs:
- name: build arch docker image
run: |
cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.5 -t katanemo/archgw:latest
cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.7 -t katanemo/archgw:latest
- name: start archgw
env:

View file

@ -24,7 +24,7 @@ jobs:
- name: build arch docker image
run: |
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.5
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.7
- name: install poetry
run: |

View file

@ -24,7 +24,7 @@ jobs:
- name: build arch docker image
run: |
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.5
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.7
- name: install poetry
run: |

View file

@ -24,7 +24,7 @@ jobs:
- name: build arch docker image
run: |
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.5
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.7
- name: validate arch config
run: |

View file

@ -4,8 +4,8 @@
<div align="center">
_The proxy server and the universal data plane for AI-native apps._<br><br>
Arch handles the *pesky low-level work* in building AI agents like clarifying vague user inputs, routing prompts to the right agents, calling tools for simple tasks, and unifying access to large language models (LLMs) - all without locking you into a framework. Move faster by focusing on the high-level logic of your agents.
_Arch is a smart proxy server designed as a modular edge and AI gateway for agentic apps_<br><br>
Arch handles the *pesky low-level work* in building agentic apps — like applying guardrails, clarifying vague user input, routing prompts to the right agent, and unifying access to any LLM. Its a language and framework friendly infrastructure layer designed to help you build and ship agentic apps faster.
[Quickstart](#Quickstart) •
@ -80,9 +80,9 @@ Arch's CLI allows you to manage and interact with the Arch gateway efficiently.
> We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch. This ensures that archgw and its dependencies do not interfere with other packages on your system.
```console
$ python -m venv venv
$ python3.12 -m venv venv
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
$ pip install archgw==0.3.5
$ pip install archgw==0.3.7
```
### Build Agentic Apps with Arch Gateway
@ -148,13 +148,10 @@ endpoints:
```sh
$ archgw up arch_config.yaml
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.1.5
...
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.3.7
2024-12-05 16:56:28,485 - cli.utils - INFO - Schema validation successful!
2024-12-05 16:56:28,485 - cli.main - INFO - Starting arch model server and arch gateway
...
2024-12-05 16:56:51,647 - cli.core - INFO - Container is healthy!
```
Once the gateway is up you can start interacting with at port 10000 using openai chat completion API.

View file

@ -2,14 +2,14 @@
nodaemon=true
[program:brightstaff]
command=sh -c "RUST_LOG=info /app/brightstaff 2>&1 | tee /var/log/brightstaff.log"
command=sh -c "RUST_LOG=debug /app/brightstaff 2>&1 | tee /var/log/brightstaff.log"
stdout_logfile=/dev/stdout
redirect_stderr=true
stdout_logfile_maxbytes=0
stderr_logfile_maxbytes=0
[program:envoy]
command=/bin/sh -c "python /app/config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:info 2>&1 | tee /var/log//envoy.log"
command=/bin/sh -c "python /app/config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:debug 2>&1 | tee /var/log//envoy.log"
stdout_logfile=/dev/stdout
redirect_stderr=true
stdout_logfile_maxbytes=0

View file

@ -19,7 +19,7 @@ source venv/bin/activate
### Step 3: Run the build script
```bash
pip install archgw==0.3.5
pip install archgw==0.3.7
```
## Uninstall Instructions: archgw CLI

View file

@ -10,4 +10,4 @@ SERVICE_NAME_MODEL_SERVER = "model_server"
SERVICE_ALL = "all"
MODEL_SERVER_LOG_FILE = "~/archgw_logs/modelserver.log"
ARCHGW_DOCKER_NAME = "archgw"
ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.3.5")
ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.3.7")

12
arch/tools/poetry.lock generated
View file

@ -2,7 +2,7 @@
[[package]]
name = "archgw_modelserver"
version = "0.3.5"
version = "0.3.7"
description = "A model server for serving models"
optional = false
python-versions = "*"
@ -104,13 +104,13 @@ i18n = ["Babel (>=2.7)"]
[[package]]
name = "jsonschema"
version = "4.24.0"
version = "4.25.0"
description = "An implementation of JSON Schema validation for Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "jsonschema-4.24.0-py3-none-any.whl", hash = "sha256:a462455f19f5faf404a7902952b6f0e3ce868f3ee09a359b05eca6673bd8412d"},
{file = "jsonschema-4.24.0.tar.gz", hash = "sha256:0b4e8069eb12aedfa881333004bccaec24ecef5a8a6a4b6df142b2cc9599d196"},
{file = "jsonschema-4.25.0-py3-none-any.whl", hash = "sha256:24c2e8da302de79c8b9382fee3e76b355e44d2a4364bb207159ce10b517bd716"},
{file = "jsonschema-4.25.0.tar.gz", hash = "sha256:e63acf5c11762c0e6672ffb61482bdf57f0876684d8d249c0fe2d730d48bc55f"},
]
[package.dependencies]
@ -121,7 +121,7 @@ rpds-py = ">=0.7.1"
[package.extras]
format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"]
format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "rfc3987-syntax (>=1.1.0)", "uri-template", "webcolors (>=24.6.0)"]
[[package]]
name = "jsonschema-specifications"
@ -576,4 +576,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "e86085ff732244cb68d2e3f7f4c2903f4a8a50cc7e0963324c2506f0de90df11"
content-hash = "1875c613e62e116d557ad2d30491891557b4114a99c7c65b22b26d690e9e268b"

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw"
version = "0.3.5"
version = "0.3.7"
description = "Python-based CLI tool to manage Arch Gateway."
authors = ["Katanemo Labs, Inc."]
packages = [
@ -10,7 +10,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10"
archgw_modelserver = "^0.3.5"
archgw_modelserver = "^0.3.7"
click = "^8.1.7"
jinja2 = "^3.1.4"
jsonschema = "^4.23.0"

View file

@ -27,10 +27,13 @@ pub async fn chat_completions(
router_service: Arc<RouterService>,
llm_provider_endpoint: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
let chat_request_parsed = serde_json::from_slice::<serde_json::Value>(&chat_request_bytes)
.inspect_err(|err| {
warn!(
@ -61,20 +64,15 @@ pub async fn chat_completions(
// remove metadata from the request
let mut chat_request_user_preferences_removed = chat_request_parsed;
if let Some(metadata) = chat_request_user_preferences_removed.get_mut("metadata") {
info!("Removing metadata from request");
debug!("Removing metadata from request");
if let Some(m) = metadata.as_object_mut() {
m.remove("archgw_preference_config");
info!("Removed archgw_preference_config from metadata");
debug!("Removed archgw_preference_config from metadata");
}
// metadata.as_object_mut().map(|m| {
// m.remove("archgw_preference_config");
// info!("Removed archgw_preference_config from metadata");
// });
// if metadata is empty, remove it
if metadata.as_object().map_or(false, |m| m.is_empty()) {
info!("Removing empty metadata from request");
debug!("Removing empty metadata from request");
chat_request_user_preferences_removed
.as_object_mut()
.map(|m| m.remove("metadata"));
@ -102,9 +100,33 @@ pub async fn chat_completions(
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
let latest_message_for_log =
chat_completion_request
.messages
.last()
.map_or("None".to_string(), |msg| {
msg.content.as_ref().map_or("None".to_string(), |content| {
content.to_string().replace('\n', "\\n")
})
});
const MAX_MESSAGE_LENGTH: usize = 50;
let latest_message_for_log = if latest_message_for_log.len() > MAX_MESSAGE_LENGTH {
format!("{}...", &latest_message_for_log[..MAX_MESSAGE_LENGTH])
} else {
latest_message_for_log
};
info!(
"request received, request type: chat_completion, usage preferences from request: {}, request path: {}, latest message: {}",
usage_preferences.is_some(),
request_path,
latest_message_for_log
);
debug!("usage preferences from request: {:?}", usage_preferences);
let mut determined_route = match router_service
let model_name = match router_service
.determine_route(
&chat_completion_request.messages,
trace_parent.clone(),
@ -112,7 +134,16 @@ pub async fn chat_completions(
)
.await
{
Ok(route) => route,
Ok(route) => match route {
Some((_, model_name)) => model_name,
None => {
debug!(
"No route determined, using default model from request: {}",
chat_completion_request.model
);
chat_completion_request.model.clone()
}
},
Err(err) => {
let err_msg = format!("Failed to determine route: {}", err);
let mut internal_error = Response::new(full(err_msg));
@ -121,14 +152,14 @@ pub async fn chat_completions(
}
};
if determined_route.is_none() {
debug!("No LLM model selected, using default from request");
determined_route = Some(chat_completion_request.model.clone());
}
debug!(
"sending request to llm provider: {}, with model hint: {}",
llm_provider_endpoint, model_name
);
info!(
"sending request to llm provider: {} with llm model: {:?}",
llm_provider_endpoint, determined_route
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&model_name).unwrap(),
);
if let Some(trace_parent) = trace_parent {
@ -138,13 +169,6 @@ pub async fn chat_completions(
);
}
if let Some(selected_route) = determined_route {
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&selected_route).unwrap(),
);
}
let chat_request_parsed_bytes =
serde_json::to_string(&chat_request_user_preferences_removed).unwrap();

View file

@ -1,3 +1,2 @@
pub mod chat_completions;
pub mod models;
pub mod preferences;

View file

@ -1,135 +0,0 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelUsagePreference};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use serde_json;
use std::{collections::HashMap, sync::Arc};
use tracing::{info, warn};
pub async fn list_preferences(
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let prov = llm_providers.read().await;
// convert the LlmProvider to UsageBasedProvider
let providers_with_usage = prov
.iter()
.map(|provider| ModelUsagePreference {
name: provider.name.clone(),
model: provider.model.clone().unwrap_or_default(),
usage: provider.usage.clone(),
})
.collect::<Vec<ModelUsagePreference>>();
match serde_json::to_string(&providers_with_usage) {
Ok(json) => {
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
Err(_) => {
let body = Full::new(Bytes::from_static(
b"{\"error\":\"Failed to serialize models\"}",
))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
}
}
pub async fn update_preferences(
request: Request<hyper::body::Incoming>,
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_body = request.collect().await?.to_bytes();
let usage: Vec<ModelUsagePreference> = match serde_json::from_slice(&request_body) {
Ok(usage) => usage,
Err(_) => {
let response_body = Full::new(Bytes::from_static(b"Invalid request body: "))
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap());
}
};
let usage_model_map: HashMap<String, ModelUsagePreference> =
usage.into_iter().map(|u| (u.model.clone(), u)).collect();
info!(
"Updating usage preferences for models: {:?}",
usage_model_map.keys()
);
let mut llm_providers = llm_providers.write().await;
// ensure that models coming in the request are valid
let llm_provider_names: Vec<String> = llm_providers
.iter()
.map(|provider| provider.name.clone())
.collect();
for model in usage_model_map.keys() {
if !llm_provider_names.contains(model) {
let model_not_found = format!("model not found: {}", model);
warn!("updating preferences: {}", model_not_found);
let response_body = Full::new(model_not_found.into())
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap());
}
}
let mut updated_models_list = Vec::new();
for provider in llm_providers.iter_mut() {
if let Some(usage_provider) = usage_model_map.get(&provider.name) {
provider.usage = usage_provider.usage.clone();
updated_models_list.push(ModelUsagePreference {
name: provider.name.clone(),
model: provider.model.clone().unwrap_or_default(),
usage: provider.usage.clone(),
});
}
}
if !updated_models_list.is_empty() {
// return list of updated models
let response_body = Full::new(Bytes::from(format!(
"{{\"updated_models\": {}}}",
serde_json::to_string(&updated_models_list).unwrap()
)))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(response_body)
.unwrap())
} else {
let response_body = Full::new(Bytes::from_static(b"Provider not found"))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap())
}
}

View file

@ -1,6 +1,5 @@
use brightstaff::handlers::chat_completions::chat_completions;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::preferences::{list_preferences, update_preferences};
use brightstaff::router::llm_router::RouterService;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
@ -116,12 +115,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/router/preferences") => {
Ok(list_preferences(llm_providers).await)
}
(&Method::PUT, "/v1/router/preferences") => {
update_preferences(req, llm_providers).await
}
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
(&Method::OPTIONS, "/v1/models") => {
let mut response = Response::new(empty());
@ -156,7 +149,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
});
tokio::task::spawn(async move {
info!("Accepted connection from {:?}", peer_addr);
debug!("Accepted connection from {:?}", peer_addr);
if let Err(err) = http1::Builder::new()
// .serve_connection(io, service_fn(chat_completion))
.serve_connection(io, service)

View file

@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
@ -48,9 +48,14 @@ impl RouterService {
.cloned()
.collect::<Vec<LlmProvider>>();
let llm_routes: Vec<RoutingPreference> = providers_with_usage
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
.iter()
.flat_map(|provider| provider.routing_preferences.clone().unwrap_or_default())
.filter_map(|provider| {
provider
.routing_preferences
.as_ref()
.map(|prefs| (provider.name.clone(), prefs.clone()))
})
.collect();
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
@ -73,7 +78,7 @@ impl RouterService {
messages: &[Message],
trace_parent: Option<String>,
usage_preferences: Option<Vec<ModelUsagePreference>>,
) -> Result<Option<String>> {
) -> Result<Option<(String, String)>> {
if !self.llm_usage_defined {
return Ok(None);
}
@ -82,7 +87,7 @@ impl RouterService {
.router_model
.generate_request(messages, &usage_preferences);
info!(
debug!(
"sending request to arch-router model: {}, endpoint: {}",
self.router_model.get_model_name(),
self.router_url
@ -151,21 +156,21 @@ impl RouterService {
if let Some(ContentType::Text(content)) =
&chat_completion_response.choices[0].message.content
{
let route_name = self.router_model.parse_response(content)?;
let parsed_response = self
.router_model
.parse_response(content, &usage_preferences)?;
info!(
"router response: {}, selected_model: {:?}, response time: {}ms",
"arch-router determined route: {}, selected_model: {:?}, response time: {}ms",
content.replace("\n", "\\n"),
route_name,
parsed_response,
router_response_time.as_millis()
);
if let Some(ref route) = route_name {
if route == "other" {
return Ok(None);
}
if let Some(ref parsed_response) = parsed_response {
return Ok(Some(parsed_response.clone()));
}
Ok(route_name)
Ok(None)
} else {
Ok(None)
}

View file

@ -16,6 +16,10 @@ pub trait RouterModel: Send + Sync {
messages: &[Message],
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest;
fn parse_response(&self, content: &str) -> Result<Option<String>>;
fn parse_response(
&self,
content: &str,
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> Result<Option<(String, String)>>;
fn get_model_name(&self) -> String;
}

View file

@ -1,3 +1,5 @@
use std::collections::HashMap;
use common::{
configuration::{ModelUsagePreference, RoutingPreference},
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
@ -32,21 +34,30 @@ Based on your analysis, provide your response in the following JSON formats if y
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub struct RouterModelV1 {
llm_route_json_str: String,
llm_route_to_model_map: HashMap<String, String>,
routing_model: String,
max_token_length: usize,
}
impl RouterModelV1 {
pub fn new(
llm_routes: Vec<RoutingPreference>,
llm_routes: HashMap<String, Vec<RoutingPreference>>,
routing_model: String,
max_token_length: usize,
) -> Self {
let llm_route_values: Vec<RoutingPreference> =
llm_routes.values().flatten().cloned().collect();
let llm_route_json_str =
serde_json::to_string(&llm_routes).unwrap_or_else(|_| "[]".to_string());
serde_json::to_string(&llm_route_values).unwrap_or_else(|_| "[]".to_string());
let llm_route_to_model_map: HashMap<String, String> = llm_routes
.iter()
.flat_map(|(model, prefs)| prefs.iter().map(|pref| (pref.name.clone(), model.clone())))
.collect();
RouterModelV1 {
routing_model,
max_token_length,
llm_route_json_str,
llm_route_to_model_map,
}
}
}
@ -62,7 +73,7 @@ impl RouterModel for RouterModelV1 {
fn generate_request(
&self,
messages: &[Message],
usage_preferences: &Option<Vec<ModelUsagePreference>>,
usage_preferences_from_request: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest {
// remove system prompt, tool calls, tool call response and messages without content
// if content is empty its likely a tool call
@ -139,31 +150,17 @@ impl RouterModel for RouterModelV1 {
})
.collect::<Vec<Message>>();
let llm_route_json = usage_preferences
.as_ref()
.map(|prefs| {
let llm_route: Vec<RoutingPreference> = prefs
.iter()
.map(|pref| RoutingPreference {
name: pref.name.clone(),
description: pref.usage.clone().unwrap_or_default(),
})
.collect();
serde_json::to_string(&llm_route).unwrap_or_default()
})
.unwrap_or_else(|| self.llm_route_json_str.clone());
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &llm_route_json)
.replace(
"{conversation}",
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
);
// Generate the router request message based on the usage preferences.
// If preferences are passed in request then we use them otherwise we use the default routing model preferences.
let router_message = match convert_to_router_preferences(usage_preferences_from_request) {
Some(prefs) => generate_router_message(&prefs, &selected_conversation_list),
None => generate_router_message(&self.llm_route_json_str, &selected_conversation_list),
};
ChatCompletionsRequest {
model: self.routing_model.clone(),
messages: vec![Message {
content: Some(ContentType::Text(messages_content)),
content: Some(ContentType::Text(router_message)),
role: USER_ROLE.to_string(),
}],
temperature: Some(0.01),
@ -171,20 +168,57 @@ impl RouterModel for RouterModelV1 {
}
}
fn parse_response(&self, content: &str) -> Result<Option<String>> {
fn parse_response(
&self,
content: &str,
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> Result<Option<(String, String)>> {
if content.is_empty() {
return Ok(None);
}
let router_resp_fixed = fix_json_response(content);
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
let selected_llm = router_response.route.unwrap_or_default().to_string();
let selected_route = router_response.route.unwrap_or_default().to_string();
if selected_llm.is_empty() {
if selected_route.is_empty() || selected_route == "other" {
return Ok(None);
}
Ok(Some(selected_llm))
if let Some(usage_preferences) = usage_preferences {
// If usage preferences are defined, we need to find the model that matches the selected route
let model_name: Option<String> = usage_preferences
.iter()
.map(|pref| {
pref.routing_preferences
.iter()
.find(|routing_pref| routing_pref.name == selected_route)
.map(|_| pref.model.clone())
})
.find_map(|model| model);
if let Some(model_name) = model_name {
return Ok(Some((selected_route, model_name)));
} else {
warn!(
"No matching model found for route: {}, usage preferences: {:?}",
selected_route, usage_preferences
);
return Ok(None);
}
}
// If no usage preferences are passed in request then use the default routing model preferences
if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() {
return Ok(Some((selected_route, model)));
}
warn!(
"No model found for route: {}, router model preferences: {:?}",
selected_route, self.llm_route_to_model_map
);
Ok(None)
}
fn get_model_name(&self) -> String {
@ -192,6 +226,37 @@ impl RouterModel for RouterModelV1 {
}
}
fn generate_router_message(prefs: &str, selected_conversation_list: &Vec<Message>) -> String {
ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", prefs)
.replace(
"{conversation}",
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
)
}
fn convert_to_router_preferences(
prefs_from_request: &Option<Vec<ModelUsagePreference>>,
) -> Option<String> {
if let Some(usage_preferences) = prefs_from_request {
let routing_preferences = usage_preferences
.iter()
.flat_map(|pref| {
pref.routing_preferences
.iter()
.map(|routing_pref| RoutingPreference {
name: routing_pref.name.clone(),
description: routing_pref.description.clone(),
})
})
.collect::<Vec<RoutingPreference>>();
return Some(serde_json::to_string(&routing_preferences).unwrap_or_default());
}
None
}
fn fix_json_response(body: &str) -> String {
let mut updated_body = body.to_string();
@ -235,7 +300,7 @@ mod tests {
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -251,15 +316,14 @@ Based on your analysis, provide your response in the following JSON formats if y
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -310,15 +374,14 @@ Based on your analysis, provide your response in the following JSON formats if y
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -341,9 +404,11 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let usage_preferences = Some(vec![ModelUsagePreference {
name: "code-generation".to_string(),
model: "claude/claude-3-7-sonnet".to_string(),
usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()),
routing_preferences: vec![RoutingPreference {
name: "code-generation".to_string(),
description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(),
}],
}]);
let req = router.generate_request(&conversation, &usage_preferences);
@ -358,7 +423,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -375,15 +440,14 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
@ -419,7 +483,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -436,15 +500,15 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
@ -480,7 +544,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -497,15 +561,14 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
@ -549,7 +612,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -565,15 +628,14 @@ Based on your analysis, provide your response in the following JSON formats if y
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -619,7 +681,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -635,15 +697,14 @@ Based on your analysis, provide your response in the following JSON formats if y
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -712,56 +773,64 @@ Based on your analysis, provide your response in the following JSON formats if y
#[test]
fn test_parse_response() {
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
// Case 1: Valid JSON with non-empty route
let input = r#"{"route": "route1"}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route1".to_string()));
let input = r#"{"route": "Image generation"}"#;
let result = router.parse_response(input, &None).unwrap();
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
// Case 2: Valid JSON with empty route
let input = r#"{"route": ""}"#;
let result = router.parse_response(input).unwrap();
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 3: Valid JSON with null route
let input = r#"{"route": null}"#;
let result = router.parse_response(input).unwrap();
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 4: JSON missing route field
let input = r#"{}"#;
let result = router.parse_response(input).unwrap();
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 4.1: empty string
let input = r#""#;
let result = router.parse_response(input).unwrap();
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 5: Malformed JSON
let input = r#"{"route": "route1""#; // missing closing }
let result = router.parse_response(input);
let result = router.parse_response(input, &None);
assert!(result.is_err());
// Case 6: Single quotes and \n in JSON
let input = "{'route': 'route2'}\\n";
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route2".to_string()));
let input = "{'route': 'Image generation'}\\n";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
// Case 7: Code block marker
let input = "```json\n{\"route\": \"route1\"}\n```";
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route1".to_string()));
let input = "```json\n{\"route\": \"Image generation\"}\n```";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
}
}

View file

@ -1,6 +1,5 @@
use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use std::fmt::Display;
@ -178,12 +177,10 @@ impl Display for LlmProviderType {
}
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug)]
pub struct ModelUsagePreference {
pub name: String,
pub model: String,
pub usage: Option<String>,
pub routing_preferences: Vec<RoutingPreference>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -0,0 +1,898 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
// Enum for all supported Anthropic APIs
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AnthropicApi {
Messages,
// Future APIs can be added here:
// Embeddings,
// etc.
}
impl ApiDefinition for AnthropicApi {
fn endpoint(&self) -> &'static str {
match self {
AnthropicApi::Messages => "/v1/messages",
}
}
fn from_endpoint(endpoint: &str) -> Option<Self> {
match endpoint {
"/v1/messages" => Some(AnthropicApi::Messages),
_ => None,
}
}
fn supports_streaming(&self) -> bool {
match self {
AnthropicApi::Messages => true,
}
}
fn supports_tools(&self) -> bool {
match self {
AnthropicApi::Messages => true,
}
}
fn supports_vision(&self) -> bool {
match self {
AnthropicApi::Messages => true,
}
}
fn all_variants() -> Vec<Self> {
vec![
AnthropicApi::Messages,
]
}
}
// Service tier enum for request priority
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
Auto,
StandardOnly,
}
// Thinking configuration
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ThinkingConfig {
pub enabled: bool,
}
// MCP Server types
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum McpServerType {
Url,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct McpToolConfiguration {
pub allowed_tools: Option<Vec<String>>,
pub enabled: Option<bool>,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct McpServer {
pub name: String,
#[serde(rename = "type")]
pub server_type: McpServerType,
pub url: String,
pub authorization_token: Option<String>,
pub tool_configuration: Option<McpToolConfiguration>,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesRequest {
pub model: String,
pub messages: Vec<MessagesMessage>,
pub max_tokens: u32,
pub container: Option<String>,
pub mcp_servers: Option<Vec<McpServer>>,
pub system: Option<MessagesSystemPrompt>,
pub metadata: Option<HashMap<String, Value>>,
pub service_tier: Option<ServiceTier>,
pub thinking: Option<ThinkingConfig>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub top_k: Option<u32>,
pub stream: Option<bool>,
pub stop_sequences: Option<Vec<String>>,
pub tools: Option<Vec<MessagesTool>>,
pub tool_choice: Option<MessagesToolChoice>,
}
// Messages API specific types
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessagesRole {
User,
Assistant,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesContentBlock {
Text {
text: String,
},
Thinking {
text: String,
},
Image {
source: MessagesImageSource,
},
Document {
source: MessagesDocumentSource,
},
ToolUse {
id: String,
name: String,
input: Value,
},
ToolResult {
tool_use_id: String,
is_error: Option<bool>,
content: Vec<MessagesContentBlock>,
},
ServerToolUse {
id: String,
name: String,
input: Value,
},
WebSearchToolResult {
tool_use_id: String,
is_error: Option<bool>,
content: Vec<MessagesContentBlock>,
},
CodeExecutionToolResult {
tool_use_id: String,
is_error: Option<bool>,
content: Vec<MessagesContentBlock>,
},
McpToolUse {
id: String,
name: String,
input: Value,
},
McpToolResult {
tool_use_id: String,
is_error: Option<bool>,
content: Vec<MessagesContentBlock>,
},
ContainerUpload {
id: String,
name: String,
media_type: String,
data: String,
},
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum MessagesImageSource {
Base64 {
media_type: String,
data: String,
},
Url {
url: String,
},
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum MessagesDocumentSource {
Base64 {
media_type: String,
data: String,
},
Url {
url: String,
},
File {
file_id: String,
},
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum MessagesMessageContent {
Single(String),
Blocks(Vec<MessagesContentBlock>),
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum MessagesSystemPrompt {
Single(String),
Blocks(Vec<MessagesContentBlock>),
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesMessage {
pub role: MessagesRole,
pub content: MessagesMessageContent,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesTool {
pub name: String,
pub description: Option<String>,
pub input_schema: Value,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum MessagesToolChoiceType {
Auto,
Any,
Tool,
None,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesToolChoice {
#[serde(rename = "type")]
pub kind: MessagesToolChoiceType,
pub name: Option<String>,
pub disable_parallel_tool_use: Option<bool>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MessagesStopReason {
EndTurn,
MaxTokens,
StopSequence,
ToolUse,
PauseTurn,
Refusal,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesUsage {
pub input_tokens: u32,
pub output_tokens: u32,
pub cache_creation_input_tokens: Option<u32>,
pub cache_read_input_tokens: Option<u32>,
}
// Container response object
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesContainer {
pub id: String,
#[serde(rename = "type")]
pub container_type: String,
pub name: String,
pub status: String,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesResponse {
pub id: String,
#[serde(rename = "type")]
pub obj_type: String,
pub role: MessagesRole,
pub content: Vec<MessagesContentBlock>,
pub model: String,
pub stop_reason: MessagesStopReason,
pub stop_sequence: Option<String>,
pub usage: MessagesUsage,
pub container: Option<MessagesContainer>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum MessagesStreamEvent {
MessageStart {
message: MessagesStreamMessage,
},
ContentBlockStart {
index: u32,
content_block: MessagesContentBlock,
},
ContentBlockDelta {
index: u32,
delta: MessagesContentDelta,
},
ContentBlockStop {
index: u32,
},
MessageDelta {
delta: MessagesMessageDelta,
usage: MessagesUsage,
},
MessageStop,
Ping,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesStreamMessage {
pub id: String,
#[serde(rename = "type")]
pub obj_type: String,
pub role: MessagesRole,
pub content: Vec<Value>, // Initially empty
pub model: String,
pub stop_reason: Option<MessagesStopReason>,
pub stop_sequence: Option<String>,
pub usage: MessagesUsage,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum MessagesContentDelta {
#[serde(rename = "text_delta")]
TextDelta { text: String },
#[serde(rename = "input_json_delta")]
InputJsonDelta { partial_json: String },
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessagesMessageDelta {
pub stop_reason: MessagesStopReason,
pub stop_sequence: Option<String>,
}
// Helper functions for API detection and conversion
impl MessagesRequest {
pub fn api_type() -> AnthropicApi {
AnthropicApi::Messages
}
}
impl MessagesResponse {
pub fn api_type() -> AnthropicApi {
AnthropicApi::Messages
}
}
impl MessagesStreamEvent {
pub fn api_type() -> AnthropicApi {
AnthropicApi::Messages
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_anthropic_required_fields() {
// Create a JSON object with only required fields
let original_json = json!({
"model": "claude-3-sonnet-20240229",
"messages": [
{
"role": "user",
"content": "Hello"
}
],
"max_tokens": 100
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
assert_eq!(deserialized_request.messages.len(), 1);
assert_eq!(deserialized_request.max_tokens, 100);
let message = &deserialized_request.messages[0];
assert_eq!(message.role, MessagesRole::User);
if let MessagesMessageContent::Single(content) = &message.content {
assert_eq!(content, "Hello");
} else {
panic!("Expected single content");
}
// Validate optional fields are None
assert!(deserialized_request.system.is_none());
assert!(deserialized_request.container.is_none());
assert!(deserialized_request.mcp_servers.is_none());
assert!(deserialized_request.service_tier.is_none());
assert!(deserialized_request.thinking.is_none());
assert!(deserialized_request.temperature.is_none());
assert!(deserialized_request.top_p.is_none());
assert!(deserialized_request.top_k.is_none());
assert!(deserialized_request.stream.is_none());
assert!(deserialized_request.stop_sequences.is_none());
assert!(deserialized_request.tools.is_none());
assert!(deserialized_request.tool_choice.is_none());
assert!(deserialized_request.metadata.is_none());
// Serialize back to JSON and compare
let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
assert_eq!(original_json, serialized_json);
}
#[test]
fn test_anthropic_optional_fields() {
// Create a JSON object with optional fields set
let original_json = json!({
"model": "claude-3-sonnet-20240229",
"messages": [
{
"role": "user",
"content": "Hello"
}
],
"max_tokens": 100,
"temperature": 0.7,
"top_p": 0.9,
"system": "You are a helpful assistant",
"service_tier": "auto",
"thinking": {
"enabled": true
},
"metadata": {
"user_id": "123"
}
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
assert_eq!(deserialized_request.messages.len(), 1);
assert_eq!(deserialized_request.max_tokens, 100);
// Validate optional fields are properly set
assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6);
assert!((deserialized_request.top_p.unwrap() - 0.9).abs() < 1e-6);
assert_eq!(deserialized_request.service_tier, Some(ServiceTier::Auto));
if let Some(MessagesSystemPrompt::Single(system)) = &deserialized_request.system {
assert_eq!(system, "You are a helpful assistant");
} else {
panic!("Expected single system prompt");
}
if let Some(thinking) = &deserialized_request.thinking {
assert_eq!(thinking.enabled, true);
} else {
panic!("Expected thinking config");
}
assert!(deserialized_request.metadata.is_some());
// Validate fields not in JSON are None
assert!(deserialized_request.container.is_none());
assert!(deserialized_request.mcp_servers.is_none());
assert!(deserialized_request.top_k.is_none());
assert!(deserialized_request.stream.is_none());
assert!(deserialized_request.stop_sequences.is_none());
assert!(deserialized_request.tools.is_none());
assert!(deserialized_request.tool_choice.is_none());
// Serialize back to JSON and compare (handle floating point precision)
let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
// Compare all fields except floating point ones
assert_eq!(serialized_json["model"], original_json["model"]);
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["system"], original_json["system"]);
assert_eq!(serialized_json["service_tier"], original_json["service_tier"]);
assert_eq!(serialized_json["thinking"], original_json["thinking"]);
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
// Handle floating point fields with tolerance
let original_temp = original_json["temperature"].as_f64().unwrap();
let serialized_temp = serialized_json["temperature"].as_f64().unwrap();
assert!((original_temp - serialized_temp).abs() < 1e-6);
let original_top_p = original_json["top_p"].as_f64().unwrap();
let serialized_top_p = serialized_json["top_p"].as_f64().unwrap();
assert!((original_top_p - serialized_top_p).abs() < 1e-6);
}
#[test]
fn test_anthropic_nested_types() {
// Create a comprehensive JSON object with nested types - a MessagesRequest with complex message content and tools
let original_json = json!({
"model": "claude-3-sonnet-20240229",
"max_tokens": 1000,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What can you see in this image and what's the weather like?"
},
{
"type": "image",
"source": {
"base64": {
"media_type": "image/jpeg",
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
}
}
}
]
},
{
"role": "assistant",
"content": [
{
"type": "thinking",
"text": "Let me analyze the image and then check the weather..."
},
{
"type": "text",
"text": "I can see the image. Let me check the weather for you."
},
{
"type": "tool_use",
"id": "toolu_weather123",
"name": "get_weather",
"input": {
"location": "San Francisco, CA"
}
}
]
}
],
"tools": [
{
"name": "get_weather",
"description": "Get current weather information for a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
}
},
"required": ["location"]
}
}
],
"tool_choice": {
"type": "auto"
},
"system": [
{
"type": "text",
"text": "You are a helpful assistant that can analyze images and provide weather information."
}
]
});
// Deserialize JSON into MessagesRequest
let deserialized_request: MessagesRequest = serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields
assert_eq!(deserialized_request.model, "claude-3-sonnet-20240229");
assert_eq!(deserialized_request.max_tokens, 1000);
assert_eq!(deserialized_request.messages.len(), 2);
// Validate first message (user with text and image content)
let user_message = &deserialized_request.messages[0];
assert_eq!(user_message.role, MessagesRole::User);
if let MessagesMessageContent::Blocks(ref content_blocks) = user_message.content {
assert_eq!(content_blocks.len(), 2);
// Validate text content block
if let MessagesContentBlock::Text { text } = &content_blocks[0] {
assert_eq!(text, "What can you see in this image and what's the weather like?");
} else {
panic!("Expected text content block");
}
// Validate image content block
if let MessagesContentBlock::Image { ref source } = content_blocks[1] {
if let MessagesImageSource::Base64 { media_type, data } = source {
assert_eq!(media_type, "image/jpeg");
assert_eq!(data, "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==");
} else {
panic!("Expected base64 image source");
}
} else {
panic!("Expected image content block");
}
} else {
panic!("Expected content blocks for user message");
}
// Validate second message (assistant with thinking, text, and tool use)
let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, MessagesRole::Assistant);
if let MessagesMessageContent::Blocks(ref content_blocks) = assistant_message.content {
assert_eq!(content_blocks.len(), 3);
// Validate thinking content block
if let MessagesContentBlock::Thinking { text } = &content_blocks[0] {
assert_eq!(text, "Let me analyze the image and then check the weather...");
} else {
panic!("Expected thinking content block");
}
// Validate text content block
if let MessagesContentBlock::Text { text } = &content_blocks[1] {
assert_eq!(text, "I can see the image. Let me check the weather for you.");
} else {
panic!("Expected text content block");
}
// Validate tool use content block
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input } = content_blocks[2] {
assert_eq!(id, "toolu_weather123");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA");
} else {
panic!("Expected tool use content block");
}
} else {
panic!("Expected content blocks for assistant message");
}
// Validate tools array
assert!(deserialized_request.tools.is_some());
let tools = deserialized_request.tools.as_ref().unwrap();
assert_eq!(tools.len(), 1);
let tool = &tools[0];
assert_eq!(tool.name, "get_weather");
assert_eq!(tool.description, Some("Get current weather information for a location".to_string()));
assert_eq!(tool.input_schema["type"], "object");
assert!(tool.input_schema["properties"]["location"].is_object());
// Validate tool choice
assert!(deserialized_request.tool_choice.is_some());
let tool_choice = deserialized_request.tool_choice.as_ref().unwrap();
assert_eq!(tool_choice.kind, MessagesToolChoiceType::Auto);
assert!(tool_choice.name.is_none());
// Validate system prompt with content blocks
assert!(deserialized_request.system.is_some());
if let Some(MessagesSystemPrompt::Blocks(ref system_blocks)) = deserialized_request.system {
assert_eq!(system_blocks.len(), 1);
if let MessagesContentBlock::Text { text } = &system_blocks[0] {
assert_eq!(text, "You are a helpful assistant that can analyze images and provide weather information.");
} else {
panic!("Expected text content block in system prompt");
}
} else {
panic!("Expected system prompt with content blocks");
}
// Serialize back to JSON and compare
let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
assert_eq!(original_json, serialized_json);
}
#[test]
fn test_anthropic_mcp_server_configuration() {
// Test MCP Server configuration with JSON-first approach
let mcp_server_json = json!({
"name": "test-server",
"type": "url",
"url": "https://example.com/mcp",
"authorization_token": "secret-token",
"tool_configuration": {
"allowed_tools": ["tool1", "tool2"],
"enabled": true
}
});
let deserialized_mcp: McpServer = serde_json::from_value(mcp_server_json.clone()).unwrap();
assert_eq!(deserialized_mcp.name, "test-server");
assert_eq!(deserialized_mcp.server_type, McpServerType::Url);
assert_eq!(deserialized_mcp.url, "https://example.com/mcp");
assert_eq!(deserialized_mcp.authorization_token, Some("secret-token".to_string()));
if let Some(tool_config) = &deserialized_mcp.tool_configuration {
assert_eq!(tool_config.allowed_tools, Some(vec!["tool1".to_string(), "tool2".to_string()]));
assert_eq!(tool_config.enabled, Some(true));
} else {
panic!("Expected tool configuration");
}
let serialized_mcp_json = serde_json::to_value(&deserialized_mcp).unwrap();
assert_eq!(mcp_server_json, serialized_mcp_json);
// Test MCP Server with minimal configuration (optional fields as None)
let minimal_mcp_json = json!({
"name": "minimal-server",
"type": "url",
"url": "https://minimal.com/mcp"
});
let deserialized_minimal: McpServer = serde_json::from_value(minimal_mcp_json.clone()).unwrap();
assert_eq!(deserialized_minimal.name, "minimal-server");
assert_eq!(deserialized_minimal.server_type, McpServerType::Url);
assert_eq!(deserialized_minimal.url, "https://minimal.com/mcp");
assert!(deserialized_minimal.authorization_token.is_none());
assert!(deserialized_minimal.tool_configuration.is_none());
let serialized_minimal_json = serde_json::to_value(&deserialized_minimal).unwrap();
assert_eq!(minimal_mcp_json, serialized_minimal_json);
}
#[test]
fn test_anthropic_response_types() {
// Test MessagesResponse deserialization
let response_json = json!({
"id": "msg_01ABC123",
"type": "message",
"role": "assistant",
"content": [
{
"type": "text",
"text": "Hello! How can I help you today?"
}
],
"model": "claude-3-sonnet-20240229",
"stop_reason": "end_turn",
"usage": {
"input_tokens": 10,
"output_tokens": 25,
"cache_creation_input_tokens": 5,
"cache_read_input_tokens": 3
}
});
let deserialized_response: MessagesResponse = serde_json::from_value(response_json.clone()).unwrap();
assert_eq!(deserialized_response.id, "msg_01ABC123");
assert_eq!(deserialized_response.obj_type, "message");
assert_eq!(deserialized_response.role, MessagesRole::Assistant);
assert_eq!(deserialized_response.model, "claude-3-sonnet-20240229");
assert_eq!(deserialized_response.stop_reason, MessagesStopReason::EndTurn);
assert!(deserialized_response.stop_sequence.is_none());
assert!(deserialized_response.container.is_none());
// Check content
assert_eq!(deserialized_response.content.len(), 1);
if let MessagesContentBlock::Text { text } = &deserialized_response.content[0] {
assert_eq!(text, "Hello! How can I help you today?");
} else {
panic!("Expected text content block");
}
// Check usage
assert_eq!(deserialized_response.usage.input_tokens, 10);
assert_eq!(deserialized_response.usage.output_tokens, 25);
assert_eq!(deserialized_response.usage.cache_creation_input_tokens, Some(5));
assert_eq!(deserialized_response.usage.cache_read_input_tokens, Some(3));
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
assert_eq!(response_json, serialized_response_json);
// Test streaming event
let stream_event_json = json!({
"type": "content_block_delta",
"index": 0,
"delta": {
"type": "text_delta",
"text": " How"
}
});
let deserialized_event: MessagesStreamEvent = serde_json::from_value(stream_event_json.clone()).unwrap();
if let MessagesStreamEvent::ContentBlockDelta { index, ref delta } = deserialized_event {
assert_eq!(index, 0);
if let MessagesContentDelta::TextDelta { text } = delta {
assert_eq!(text, " How");
} else {
panic!("Expected text delta");
}
} else {
panic!("Expected content block delta event");
}
let serialized_event_json = serde_json::to_value(&deserialized_event).unwrap();
assert_eq!(stream_event_json, serialized_event_json);
}
#[test]
fn test_anthropic_tool_use_content() {
// Test tool use and tool result content blocks
let tool_use_json = json!({
"type": "tool_use",
"id": "toolu_01ABC123",
"name": "get_weather",
"input": {
"location": "San Francisco, CA"
}
});
let deserialized_tool_use: MessagesContentBlock = serde_json::from_value(tool_use_json.clone()).unwrap();
if let MessagesContentBlock::ToolUse { ref id, ref name, ref input } = deserialized_tool_use {
assert_eq!(id, "toolu_01ABC123");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "San Francisco, CA");
} else {
panic!("Expected tool use content block");
}
let serialized_tool_use_json = serde_json::to_value(&deserialized_tool_use).unwrap();
assert_eq!(tool_use_json, serialized_tool_use_json);
// Test tool result content block
let tool_result_json = json!({
"type": "tool_result",
"tool_use_id": "toolu_01ABC123",
"content": [
{
"type": "text",
"text": "The weather in San Francisco is sunny, 72°F"
}
]
});
let deserialized_tool_result: MessagesContentBlock = serde_json::from_value(tool_result_json.clone()).unwrap();
if let MessagesContentBlock::ToolResult { ref tool_use_id, ref is_error, ref content } = deserialized_tool_result {
assert_eq!(tool_use_id, "toolu_01ABC123");
assert!(is_error.is_none());
assert_eq!(content.len(), 1);
if let MessagesContentBlock::Text { text } = &content[0] {
assert_eq!(text, "The weather in San Francisco is sunny, 72°F");
} else {
panic!("Expected text content in tool result");
}
} else {
panic!("Expected tool result content block");
}
let serialized_tool_result_json = serde_json::to_value(&deserialized_tool_result).unwrap();
assert_eq!(tool_result_json, serialized_tool_result_json);
}
#[test]
fn test_anthropic_api_provider_trait_implementation() {
// Test that AnthropicApi implements ApiDefinition trait correctly
let api = AnthropicApi::Messages;
// Test trait methods
assert_eq!(api.endpoint(), "/v1/messages");
assert!(api.supports_streaming());
assert!(api.supports_tools());
assert!(api.supports_vision());
// Test from_endpoint trait method
let found_api = AnthropicApi::from_endpoint("/v1/messages");
assert_eq!(found_api, Some(AnthropicApi::Messages));
let not_found = AnthropicApi::from_endpoint("/v1/unknown");
assert_eq!(not_found, None);
// Test all_variants
let all_variants = AnthropicApi::all_variants();
assert_eq!(all_variants.len(), 1);
assert_eq!(all_variants[0], AnthropicApi::Messages);
}
}

View file

@ -0,0 +1,197 @@
pub mod anthropic;
pub mod openai;
// Re-export all types for convenience
pub use anthropic::*;
pub use openai::*;
/// Common trait that all API definitions must implement
///
/// This trait ensures consistency across different AI provider API definitions
/// and makes it easy to add new providers like Gemini, Claude, etc.
///
/// Note: This is different from the `ApiProvider` enum in `clients::endpoints`
/// which represents provider identification, while this trait defines API capabilities.
///
/// # Benefits
///
/// - **Consistency**: All API providers implement the same interface
/// - **Extensibility**: Easy to add new providers without breaking existing code
/// - **Type Safety**: Compile-time guarantees that all providers implement required methods
/// - **Discoverability**: Clear documentation of what capabilities each API supports
///
/// # Example implementation for a new provider:
///
/// ```rust,ignore
/// use serde::{Deserialize, Serialize};
/// use super::ApiDefinition;
///
/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
/// pub enum GeminiApi {
/// GenerateContent,
/// ChatCompletions,
/// }
///
/// impl GeminiApi {
/// pub fn endpoint(&self) -> &'static str {
/// match self {
/// GeminiApi::GenerateContent => "/v1/models/gemini-pro:generateContent",
/// GeminiApi::ChatCompletions => "/v1/models/gemini-pro:chat",
/// }
/// }
///
/// pub fn from_endpoint(endpoint: &str) -> Option<Self> {
/// match endpoint {
/// "/v1/models/gemini-pro:generateContent" => Some(GeminiApi::GenerateContent),
/// "/v1/models/gemini-pro:chat" => Some(GeminiApi::ChatCompletions),
/// _ => None,
/// }
/// }
///
/// pub fn supports_streaming(&self) -> bool {
/// match self {
/// GeminiApi::GenerateContent => true,
/// GeminiApi::ChatCompletions => true,
/// }
/// }
///
/// pub fn supports_tools(&self) -> bool {
/// match self {
/// GeminiApi::GenerateContent => true,
/// GeminiApi::ChatCompletions => false,
/// }
/// }
///
/// pub fn supports_vision(&self) -> bool {
/// match self {
/// GeminiApi::GenerateContent => true,
/// GeminiApi::ChatCompletions => false,
/// }
/// }
/// }
///
/// impl ApiDefinition for GeminiApi {
/// fn endpoint(&self) -> &'static str {
/// self.endpoint()
/// }
///
/// fn from_endpoint(endpoint: &str) -> Option<Self> {
/// Self::from_endpoint(endpoint)
/// }
///
/// fn supports_streaming(&self) -> bool {
/// self.supports_streaming()
/// }
///
/// fn supports_tools(&self) -> bool {
/// self.supports_tools()
/// }
///
/// fn supports_vision(&self) -> bool {
/// self.supports_vision()
/// }
/// }
///
/// // Now you can use generic code that works with any API:
/// fn print_api_info<T: ApiDefinition>(api: &T) {
/// println!("Endpoint: {}", api.endpoint());
/// println!("Supports streaming: {}", api.supports_streaming());
/// println!("Supports tools: {}", api.supports_tools());
/// println!("Supports vision: {}", api.supports_vision());
/// }
///
/// // Works with both OpenAI and Anthropic (and future Gemini)
/// print_api_info(&OpenAIApi::ChatCompletions);
/// print_api_info(&AnthropicApi::Messages);
/// print_api_info(&GeminiApi::GenerateContent);
/// ```
pub trait ApiDefinition {
/// Returns the endpoint path for this API
fn endpoint(&self) -> &'static str;
/// Creates an API instance from an endpoint path
fn from_endpoint(endpoint: &str) -> Option<Self>
where
Self: Sized;
/// Returns whether this API supports streaming responses
fn supports_streaming(&self) -> bool;
/// Returns whether this API supports tool/function calling
fn supports_tools(&self) -> bool;
/// Returns whether this API supports vision/image processing
fn supports_vision(&self) -> bool;
/// Returns all variants of this API enum
fn all_variants() -> Vec<Self>
where
Self: Sized;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generic_api_functionality() {
// Test that our generic API functionality works with both providers
fn test_api<T: ApiDefinition>(api: &T) {
let endpoint = api.endpoint();
assert!(!endpoint.is_empty());
assert!(endpoint.starts_with('/'));
}
test_api(&OpenAIApi::ChatCompletions);
test_api(&AnthropicApi::Messages);
}
#[test]
fn test_api_detection_from_endpoints() {
// Test that we can detect APIs from endpoints using the trait
let endpoints = vec![
"/v1/chat/completions",
"/v1/messages",
"/v1/unknown"
];
let mut detected_apis = Vec::new();
for endpoint in endpoints {
if let Some(api) = OpenAIApi::from_endpoint(endpoint) {
detected_apis.push(format!("OpenAI: {:?}", api));
} else if let Some(api) = AnthropicApi::from_endpoint(endpoint) {
detected_apis.push(format!("Anthropic: {:?}", api));
} else {
detected_apis.push("Unknown API".to_string());
}
}
assert_eq!(detected_apis, vec![
"OpenAI: ChatCompletions",
"Anthropic: Messages",
"Unknown API"
]);
}
#[test]
fn test_all_variants_method() {
// Test that all_variants returns the expected variants
let openai_variants = OpenAIApi::all_variants();
assert_eq!(openai_variants.len(), 1);
assert!(openai_variants.contains(&OpenAIApi::ChatCompletions));
let anthropic_variants = AnthropicApi::all_variants();
assert_eq!(anthropic_variants.len(), 1);
assert!(anthropic_variants.contains(&AnthropicApi::Messages));
// Verify each variant has a valid endpoint
for variant in openai_variants {
assert!(!variant.endpoint().is_empty());
}
for variant in anthropic_variants {
assert!(!variant.endpoint().is_empty());
}
}
}

View file

@ -0,0 +1,883 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use super::ApiDefinition;
// ============================================================================
// OPENAI API ENUMERATION
// ============================================================================
/// Enum for all supported OpenAI APIs
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OpenAIApi {
ChatCompletions,
// Future APIs can be added here:
// Embeddings,
// FineTuning,
// etc.
}
impl ApiDefinition for OpenAIApi {
fn endpoint(&self) -> &'static str {
match self {
OpenAIApi::ChatCompletions => "/v1/chat/completions",
}
}
fn from_endpoint(endpoint: &str) -> Option<Self> {
match endpoint {
"/v1/chat/completions" => Some(OpenAIApi::ChatCompletions),
_ => None,
}
}
fn supports_streaming(&self) -> bool {
match self {
OpenAIApi::ChatCompletions => true,
}
}
fn supports_tools(&self) -> bool {
match self {
OpenAIApi::ChatCompletions => true,
}
}
fn supports_vision(&self) -> bool {
match self {
OpenAIApi::ChatCompletions => true,
}
}
fn all_variants() -> Vec<Self> {
vec![
OpenAIApi::ChatCompletions,
]
}
}
/// Chat completions API request
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ChatCompletionsRequest {
pub messages: Vec<Message>,
pub model: String,
// pub audio: Option<Audio> // GOOD FIRST ISSUE: future support for audio input
pub frequency_penalty: Option<f32>,
// Function calling configuration has been deprecated, but we keep it for compatibility
pub function_call: Option<FunctionChoice>,
pub functions: Option<Vec<Tool>>,
pub logit_bias: Option<HashMap<String, i32>>,
pub logprobs: Option<bool>,
pub max_completion_tokens: Option<u32>,
// Maximum tokens in the response has been deprecated, but we keep it for compatibility
pub max_tokens: Option<u32>,
pub modalities: Option<Vec<String>>,
pub metadata: Option<HashMap<String, String>>,
pub n: Option<u32>,
pub presence_penalty: Option<f32>,
pub parallel_tool_calls: Option<bool>,
pub prediction: Option<StaticContent>,
// pub reasoning_effect: Option<bool>, // GOOD FIRST ISSUE: Future support for reasoning effects
pub response_format: Option<Value>,
// pub safety_identifier: Option<String>, // GOOD FIRST ISSUE: Future support for safety identifiers
pub seed: Option<i32>,
pub service_tier: Option<String>,
pub stop: Option<Vec<String>>,
pub store: Option<bool>,
pub stream: Option<bool>,
pub stream_options: Option<StreamOptions>,
pub temperature: Option<f32>,
pub tool_choice: Option<ToolChoice>,
pub tools: Option<Vec<Tool>>,
pub top_p: Option<f32>,
pub top_logprobs: Option<u32>,
pub user: Option<String>,
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
}
// ============================================================================
// CHAT COMPLETIONS API TYPES
// ============================================================================
/// Message role in a chat conversation
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
Tool,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
pub content: MessageContent,
pub role: Role,
pub name: Option<String>,
/// Tool calls made by the assistant (only present for assistant role)
pub tool_calls: Option<Vec<ToolCall>>,
/// ID of the tool call that this message is responding to (only present for tool role)
pub tool_call_id: Option<String>,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ResponseMessage {
pub role: Role,
/// The contents of the message (can be null for some cases)
pub content: Option<String>,
/// The refusal message generated by the model
pub refusal: Option<String>,
/// Annotations for the message, when applicable, as when using the web search tool
pub annotations: Option<Vec<Value>>,
/// If the audio output modality is requested, this object contains data about the audio response
pub audio: Option<Value>,
/// Deprecated and replaced by tool_calls. The name and arguments of a function that should be called
pub function_call: Option<FunctionCall>,
/// The tool calls generated by the model, such as function calls
pub tool_calls: Option<Vec<ToolCall>>,
}
impl ResponseMessage {
/// Convert ResponseMessage to Message for internal processing
/// This is useful for transformations that need to work with the request Message type
pub fn to_message(&self) -> Message {
Message {
role: self.role.clone(),
content: self.content.as_ref()
.map(|s| MessageContent::Text(s.clone()))
.unwrap_or(MessageContent::Text(String::new())),
name: None, // Response messages don't have names in the same way request messages do
tool_calls: self.tool_calls.clone(),
tool_call_id: None, // Response messages don't have tool_call_id
}
}
}
/// In the OpenAI API, this is represented as either:
/// - A string for simple text content
/// - An array of content parts for multimodal content (text + images)
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum MessageContent {
Text(String),
Parts(Vec<ContentPart>),
}
/// Individual content part within a message (text or image)
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum ContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrl },
}
/// Image URL configuration for vision capabilities
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ImageUrl {
pub url: String,
pub detail: Option<String>,
}
/// A single message in a chat conversation
/// A tool call made by the assistant
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub call_type: String,
pub function: FunctionCall,
}
/// Function call within a tool call
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
}
/// Tool definition for function calling
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: Function,
}
/// Function definition within a tool
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Function {
pub name: String,
pub description: Option<String>,
pub parameters: Value,
pub strict: Option<bool>,
}
/// Tool choice string values
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceType {
/// Let the model automatically decide whether to call tools
Auto,
/// Force the model to call at least one tool
Required,
/// Prevent the model from calling any tools
None,
}
/// Tool choice configuration
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
pub enum ToolChoice {
/// String-based tool choice (auto, required, none)
Type(ToolChoiceType),
/// Specific function to call
Function {
#[serde(rename = "type")]
choice_type: String,
function: FunctionChoice,
},
}
/// Specific function choice
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionChoice {
pub name: String,
}
/// Static content for prediction/prefill functionality
///
/// Static predicted output content, such as the content of a text file
/// that is being regenerated.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StaticContent {
/// The type of the predicted content you want to provide.
/// This type is currently always "content".
#[serde(rename = "type")]
pub content_type: String,
/// The content that should be matched when generating a model response.
/// If generated tokens would match this content, the entire model response
/// can be returned much more quickly.
///
/// Can be either:
/// - A string for simple text content
/// - An array of content parts for structured content
pub content: StaticContentType,
}
/// Content type for static/predicted content
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum StaticContentType {
/// Simple text content - the content used for a Predicted Output.
/// This is often the text of a file you are regenerating with minor changes.
Text(String),
/// An array of content parts with a defined type.
/// Can contain text inputs and other supported content types.
Parts(Vec<ContentPart>),
}
/// Chat completions API response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatCompletionsResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Usage,
pub system_fingerprint: Option<String>,
}
/// Finish reason for completion
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
Stop,
Length,
ToolCalls,
ContentFilter,
FunctionCall, // Legacy
}
/// Token usage information
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub prompt_tokens_details: Option<PromptTokensDetails>,
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
/// Detailed breakdown of prompt tokens
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PromptTokensDetails {
pub cached_tokens: Option<u32>,
pub audio_tokens: Option<u32>,
}
/// Detailed breakdown of completion tokens
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CompletionTokensDetails {
pub reasoning_tokens: Option<u32>,
pub audio_tokens: Option<u32>,
pub accepted_prediction_tokens: Option<u32>,
pub rejected_prediction_tokens: Option<u32>,
}
/// A single choice in the response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Choice {
pub index: u32,
pub message: ResponseMessage,
pub finish_reason: Option<FinishReason>,
pub logprobs: Option<Value>,
}
// ============================================================================
// STREAMING API TYPES
// ============================================================================
/// Streaming response from chat completions API
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatCompletionsStreamResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<StreamChoice>,
pub usage: Option<Usage>, // Only in final chunk
pub system_fingerprint: Option<String>,
/// Specifies the processing type used for serving the request
pub service_tier: Option<String>,
}
/// A choice in a streaming response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamChoice {
pub index: u32,
pub delta: MessageDelta,
pub finish_reason: Option<FinishReason>,
pub logprobs: Option<Value>,
}
/// Message delta for streaming updates
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MessageDelta {
pub role: Option<Role>,
pub content: Option<String>,
/// The refusal message generated by the model
pub refusal: Option<String>,
/// Deprecated and replaced by tool_calls. The name and arguments of a function that should be called
pub function_call: Option<FunctionCall>,
pub tool_calls: Option<Vec<ToolCallDelta>>,
}
/// Tool call delta for streaming tool call updates
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCallDelta {
pub index: u32,
pub id: Option<String>,
#[serde(rename = "type")]
pub call_type: Option<String>,
pub function: Option<FunctionCallDelta>,
}
/// Function call delta for streaming function call updates
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct FunctionCallDelta {
pub name: Option<String>,
pub arguments: Option<String>,
}
/// Stream options for controlling streaming behavior
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamOptions {
pub include_usage: Option<bool>,
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_required_fields() {
// Create a JSON object with only required fields
let original_json = json!({
"model": "gpt-4",
"messages": [
{
"content": "Hello, world!",
"role": "user"
}
]
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields are properly set
assert_eq!(deserialized_request.model, "gpt-4");
assert_eq!(deserialized_request.messages.len(), 1);
let message = &deserialized_request.messages[0];
assert_eq!(message.role, Role::User);
if let MessageContent::Text(content) = &message.content {
assert_eq!(content, "Hello, world!");
} else {
panic!("Expected text content");
}
// Serialize the ChatCompletionsRequest back to JSON
let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
assert_eq!(original_json, serialized_json);
}
#[test]
fn test_optional_fields_serialization() {
// Create a JSON object with optional fields set
let original_json = json!({
"model": "gpt-4",
"messages": [
{
"content": "Test message",
"role": "user",
"name": "test_user"
}
],
"temperature": 0.7,
"max_tokens": 150,
"stream": true,
"stream_options": {
"include_usage": true
},
"metadata": {
"user_id": "123"
}
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
// Validate required fields
assert_eq!(deserialized_request.model, "gpt-4");
assert_eq!(deserialized_request.messages.len(), 1);
let message = &deserialized_request.messages[0];
assert_eq!(message.role, Role::User);
if let MessageContent::Text(content) = &message.content {
assert_eq!(content, "Test message");
} else {
panic!("Expected text content");
}
assert_eq!(message.name, Some("test_user".to_string()));
// Validate optional fields are properly set
assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6);
assert_eq!(deserialized_request.max_tokens, Some(150));
assert_eq!(deserialized_request.stream, Some(true));
assert!(deserialized_request.stream_options.is_some());
assert!(deserialized_request.metadata.is_some());
// Validate fields not in JSON are None
assert!(deserialized_request.top_p.is_none());
assert!(deserialized_request.frequency_penalty.is_none());
assert!(deserialized_request.presence_penalty.is_none());
assert!(deserialized_request.stop.is_none());
assert!(deserialized_request.tools.is_none());
// Serialize back to JSON and compare (handle floating point precision)
let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
// Compare all fields except temperature which needs floating point comparison
assert_eq!(serialized_json["model"], original_json["model"]);
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["stream"], original_json["stream"]);
assert_eq!(serialized_json["stream_options"], original_json["stream_options"]);
assert_eq!(serialized_json["metadata"], original_json["metadata"]);
// Handle temperature with floating point tolerance
let original_temp = original_json["temperature"].as_f64().unwrap();
let serialized_temp = serialized_json["temperature"].as_f64().unwrap();
assert!((original_temp - serialized_temp).abs() < 1e-6);
}
#[test]
fn test_nested_types_serialization() {
// Create a comprehensive JSON object with nested types - a ChatCompletionsRequest with complex message content and tools
let original_json = json!({
"model": "gpt-4-vision-preview",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What can you see in this image and what's the weather like in the location shown?"
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/cityscape.jpg",
"detail": "high"
}
}
]
},
{
"role": "assistant",
"content": "I can see a beautiful cityscape. Let me check the weather for you.",
"tool_calls": [
{
"id": "call_weather123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"location\": \"New York, NY\"}"
}
}
]
},
{
"role": "tool",
"content": "Current weather in New York: 72°F, sunny",
"tool_call_id": "call_weather123"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather information for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
}
},
"required": ["location"]
},
"strict": true
}
}
],
"tool_choice": "auto",
"temperature": 0.7,
"max_tokens": 1000,
"prediction": {
"type": "content",
"content": "Based on the image analysis and weather data, I can provide you with comprehensive information."
}
});
// Deserialize JSON into ChatCompletionsRequest
let deserialized_request: ChatCompletionsRequest = serde_json::from_value(original_json.clone()).unwrap();
// Validate top-level fields
assert_eq!(deserialized_request.model, "gpt-4-vision-preview");
assert_eq!(deserialized_request.messages.len(), 3);
assert!((deserialized_request.temperature.unwrap() - 0.7).abs() < 1e-6);
assert_eq!(deserialized_request.max_tokens, Some(1000));
// Validate first message (user with multimodal content)
let user_message = &deserialized_request.messages[0];
assert_eq!(user_message.role, Role::User);
if let MessageContent::Parts(ref content_parts) = user_message.content {
assert_eq!(content_parts.len(), 2);
// Validate text content part
if let ContentPart::Text { text } = &content_parts[0] {
assert_eq!(text, "What can you see in this image and what's the weather like in the location shown?");
} else {
panic!("Expected text content part");
}
// Validate image URL content part
if let ContentPart::ImageUrl { ref image_url } = content_parts[1] {
assert_eq!(image_url.url, "https://example.com/cityscape.jpg");
assert_eq!(image_url.detail, Some("high".to_string()));
} else {
panic!("Expected image URL content part");
}
} else {
panic!("Expected multimodal content parts for user message");
}
// Validate second message (assistant with tool calls)
let assistant_message = &deserialized_request.messages[1];
assert_eq!(assistant_message.role, Role::Assistant);
if let MessageContent::Text(text) = &assistant_message.content {
assert_eq!(text, "I can see a beautiful cityscape. Let me check the weather for you.");
} else {
panic!("Expected text content for assistant message");
}
// Validate tool calls in assistant message
assert!(assistant_message.tool_calls.is_some());
let tool_calls = assistant_message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
let tool_call = &tool_calls[0];
assert_eq!(tool_call.id, "call_weather123");
assert_eq!(tool_call.call_type, "function");
assert_eq!(tool_call.function.name, "get_weather");
assert_eq!(tool_call.function.arguments, "{\"location\": \"New York, NY\"}");
// Validate third message (tool response)
let tool_message = &deserialized_request.messages[2];
assert_eq!(tool_message.role, Role::Tool);
if let MessageContent::Text(text) = &tool_message.content {
assert_eq!(text, "Current weather in New York: 72°F, sunny");
} else {
panic!("Expected text content for tool message");
}
assert_eq!(tool_message.tool_call_id, Some("call_weather123".to_string()));
// Validate tools array
assert!(deserialized_request.tools.is_some());
let tools = deserialized_request.tools.as_ref().unwrap();
assert_eq!(tools.len(), 1);
let tool = &tools[0];
assert_eq!(tool.tool_type, "function");
assert_eq!(tool.function.name, "get_weather");
assert_eq!(tool.function.description, Some("Get current weather information for a location".to_string()));
assert_eq!(tool.function.strict, Some(true));
// Validate tool parameters schema
let parameters = &tool.function.parameters;
assert_eq!(parameters["type"], "object");
assert!(parameters["properties"]["location"].is_object());
assert_eq!(parameters["required"], json!(["location"]));
// Validate tool choice
if let Some(ToolChoice::Type(choice)) = &deserialized_request.tool_choice {
assert_eq!(choice, &ToolChoiceType::Auto);
} else {
panic!("Expected auto tool choice");
}
// Validate prediction
assert!(deserialized_request.prediction.is_some());
let prediction = deserialized_request.prediction.as_ref().unwrap();
assert_eq!(prediction.content_type, "content");
if let StaticContentType::Text(text) = &prediction.content {
assert_eq!(text, "Based on the image analysis and weather data, I can provide you with comprehensive information.");
} else {
panic!("Expected text prediction content");
}
// Serialize back to JSON and compare (handle floating point precision)
let serialized_json = serde_json::to_value(&deserialized_request).unwrap();
// Compare all fields except floating point ones
assert_eq!(serialized_json["model"], original_json["model"]);
assert_eq!(serialized_json["messages"], original_json["messages"]);
assert_eq!(serialized_json["max_tokens"], original_json["max_tokens"]);
assert_eq!(serialized_json["tools"], original_json["tools"]);
assert_eq!(serialized_json["tool_choice"], original_json["tool_choice"]);
assert_eq!(serialized_json["prediction"], original_json["prediction"]);
// Handle floating point field with tolerance
let original_temp = original_json["temperature"].as_f64().unwrap();
let serialized_temp = serialized_json["temperature"].as_f64().unwrap();
assert!((original_temp - serialized_temp).abs() < 1e-6);
}
#[test]
fn test_api_provider_trait() {
// Test the ApiDefinition trait implementation
let api = OpenAIApi::ChatCompletions;
// Test trait methods
assert_eq!(api.endpoint(), "/v1/chat/completions");
assert!(api.supports_streaming());
assert!(api.supports_tools());
assert!(api.supports_vision());
// Test from_endpoint
let found_api = OpenAIApi::from_endpoint("/v1/chat/completions");
assert_eq!(found_api, Some(OpenAIApi::ChatCompletions));
let not_found = OpenAIApi::from_endpoint("/v1/unknown");
assert_eq!(not_found, None);
// Test all_variants
let all_variants = OpenAIApi::all_variants();
assert_eq!(all_variants.len(), 1);
assert_eq!(all_variants[0], OpenAIApi::ChatCompletions);
}
#[test]
fn test_role_specific_behavior() {
// Test 1: User message - basic content, no tool-related fields
let user_json = json!({
"content": "Hello!",
"role": "user",
"name": "user123"
});
let deserialized_user: Message = serde_json::from_value(user_json.clone()).unwrap();
assert_eq!(deserialized_user.role, Role::User);
if let MessageContent::Text(content) = &deserialized_user.content {
assert_eq!(content, "Hello!");
} else {
panic!("Expected text content");
}
assert_eq!(deserialized_user.name, Some("user123".to_string()));
assert!(deserialized_user.tool_calls.is_none());
assert!(deserialized_user.tool_call_id.is_none());
let serialized_user_json = serde_json::to_value(&deserialized_user).unwrap();
assert_eq!(user_json, serialized_user_json);
// Test 2: Assistant message with tool calls
let assistant_json = json!({
"content": "I'll help with that.",
"role": "assistant",
"tool_calls": [
{
"id": "call_456",
"type": "function",
"function": {
"name": "get_weather",
"arguments": r#"{"location":"SF"}"#
}
}
]
});
let deserialized_assistant: Message = serde_json::from_value(assistant_json.clone()).unwrap();
assert_eq!(deserialized_assistant.role, Role::Assistant);
if let MessageContent::Text(content) = &deserialized_assistant.content {
assert_eq!(content, "I'll help with that.");
} else {
panic!("Expected text content");
}
assert!(deserialized_assistant.tool_calls.is_some());
assert!(deserialized_assistant.tool_call_id.is_none());
assert!(deserialized_assistant.name.is_none());
let tool_calls = deserialized_assistant.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id, "call_456");
assert_eq!(tool_calls[0].function.name, "get_weather");
let serialized_assistant_json = serde_json::to_value(&deserialized_assistant).unwrap();
assert_eq!(assistant_json, serialized_assistant_json);
// Test 3: Tool message responding to a call
let tool_json = json!({
"content": "Weather is sunny",
"role": "tool",
"tool_call_id": "call_456"
});
let deserialized_tool: Message = serde_json::from_value(tool_json.clone()).unwrap();
assert_eq!(deserialized_tool.role, Role::Tool);
if let MessageContent::Text(content) = &deserialized_tool.content {
assert_eq!(content, "Weather is sunny");
} else {
panic!("Expected text content");
}
assert_eq!(deserialized_tool.tool_call_id, Some("call_456".to_string()));
assert!(deserialized_tool.tool_calls.is_none());
assert!(deserialized_tool.name.is_none());
let serialized_tool_json = serde_json::to_value(&deserialized_tool).unwrap();
assert_eq!(tool_json, serialized_tool_json);
// Test 4: ResponseMessage vs Message differences
let response_json = json!({
"role": "assistant",
"content": "Response content",
"annotations": [
{"type": "citation"}
]
});
let deserialized_response: ResponseMessage = serde_json::from_value(response_json.clone()).unwrap();
assert_eq!(deserialized_response.role, Role::Assistant);
assert_eq!(deserialized_response.content, Some("Response content".to_string()));
assert!(deserialized_response.annotations.is_some());
assert!(deserialized_response.refusal.is_none());
assert!(deserialized_response.function_call.is_none());
assert!(deserialized_response.tool_calls.is_none());
let serialized_response_json = serde_json::to_value(&deserialized_response).unwrap();
assert_eq!(response_json, serialized_response_json);
// Test conversion from ResponseMessage to Message
let converted = deserialized_response.to_message();
assert_eq!(converted.role, Role::Assistant);
if let MessageContent::Text(text) = converted.content {
assert_eq!(text, "Response content");
} else {
panic!("Expected text content");
}
assert!(converted.name.is_none());
assert!(converted.tool_call_id.is_none());
}
#[test]
fn test_tool_choice_type_serialization() {
// Test that the enum serializes to the correct string values
let auto_choice = ToolChoice::Type(ToolChoiceType::Auto);
let required_choice = ToolChoice::Type(ToolChoiceType::Required);
let none_choice = ToolChoice::Type(ToolChoiceType::None);
let auto_json = serde_json::to_value(&auto_choice).unwrap();
let required_json = serde_json::to_value(&required_choice).unwrap();
let none_json = serde_json::to_value(&none_choice).unwrap();
assert_eq!(auto_json, "auto");
assert_eq!(required_json, "required");
assert_eq!(none_json, "none");
// Test deserialization from string values
let auto_deserialized: ToolChoice = serde_json::from_value(json!("auto")).unwrap();
let required_deserialized: ToolChoice = serde_json::from_value(json!("required")).unwrap();
let none_deserialized: ToolChoice = serde_json::from_value(json!("none")).unwrap();
assert_eq!(auto_deserialized, ToolChoice::Type(ToolChoiceType::Auto));
assert_eq!(required_deserialized, ToolChoice::Type(ToolChoiceType::Required));
assert_eq!(none_deserialized, ToolChoice::Type(ToolChoiceType::None));
// Test that invalid string values fail deserialization (type safety!)
let invalid_result: Result<ToolChoice, _> = serde_json::from_value(json!("invalid"));
assert!(invalid_result.is_err());
}
}

View file

@ -0,0 +1,130 @@
//! Supported endpoint registry for LLM APIs
//!
//! This module provides a simple registry to check which API endpoint paths
//! we support across different providers.
//!
//! # Examples
//!
//! ```rust
//! use hermesllm::clients::endpoints::{is_supported_endpoint, supported_endpoints};
//!
//! // Check if we support an endpoint
//! assert!(is_supported_endpoint("/v1/chat/completions"));
//! assert!(is_supported_endpoint("/v1/messages"));
//! assert!(!is_supported_endpoint("/v1/unknown"));
//!
//! // Get all supported endpoints
//! let endpoints = supported_endpoints();
//! assert_eq!(endpoints.len(), 2);
//! assert!(endpoints.contains(&"/v1/chat/completions"));
//! assert!(endpoints.contains(&"/v1/messages"));
//! ```
use crate::apis::{AnthropicApi, OpenAIApi, ApiDefinition};
/// Check if the given endpoint path is supported
pub fn is_supported_endpoint(endpoint: &str) -> bool {
// Try OpenAI APIs
if OpenAIApi::from_endpoint(endpoint).is_some() {
return true;
}
// Try Anthropic APIs
if AnthropicApi::from_endpoint(endpoint).is_some() {
return true;
}
false
}
/// Get all supported endpoint paths
pub fn supported_endpoints() -> Vec<&'static str> {
let mut endpoints = Vec::new();
// Add all OpenAI endpoints
for api in OpenAIApi::all_variants() {
endpoints.push(api.endpoint());
}
// Add all Anthropic endpoints
for api in AnthropicApi::all_variants() {
endpoints.push(api.endpoint());
}
endpoints
}
/// Identify which provider supports a given endpoint
pub fn identify_provider(endpoint: &str) -> Option<&'static str> {
if OpenAIApi::from_endpoint(endpoint).is_some() {
return Some("openai");
}
if AnthropicApi::from_endpoint(endpoint).is_some() {
return Some("anthropic");
}
None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_supported_endpoint() {
// OpenAI endpoints
assert!(is_supported_endpoint("/v1/chat/completions"));
// Anthropic endpoints
assert!(is_supported_endpoint("/v1/messages"));
// Unsupported endpoints
assert!(!is_supported_endpoint("/v1/unknown"));
assert!(!is_supported_endpoint("/v2/chat"));
assert!(!is_supported_endpoint(""));
}
#[test]
fn test_supported_endpoints() {
let endpoints = supported_endpoints();
assert_eq!(endpoints.len(), 2);
assert!(endpoints.contains(&"/v1/chat/completions"));
assert!(endpoints.contains(&"/v1/messages"));
}
#[test]
fn test_identify_provider() {
assert_eq!(identify_provider("/v1/chat/completions"), Some("openai"));
assert_eq!(identify_provider("/v1/messages"), Some("anthropic"));
assert_eq!(identify_provider("/v1/unknown"), None);
}
#[test]
fn test_endpoints_generated_from_api_definitions() {
let endpoints = supported_endpoints();
// Verify that we get endpoints from all API variants
let openai_endpoints: Vec<_> = OpenAIApi::all_variants()
.iter()
.map(|api| api.endpoint())
.collect();
let anthropic_endpoints: Vec<_> = AnthropicApi::all_variants()
.iter()
.map(|api| api.endpoint())
.collect();
// All OpenAI endpoints should be in the result
for endpoint in openai_endpoints {
assert!(endpoints.contains(&endpoint), "Missing OpenAI endpoint: {}", endpoint);
}
// All Anthropic endpoints should be in the result
for endpoint in anthropic_endpoints {
assert!(endpoints.contains(&endpoint), "Missing Anthropic endpoint: {}", endpoint);
}
// Total should match
assert_eq!(endpoints.len(), OpenAIApi::all_variants().len() + AnthropicApi::all_variants().len());
}
}

View file

@ -0,0 +1,33 @@
//! Helper functions and utilities for API transformations
//! Contains error types and shared utilities
use thiserror::Error;
// ============================================================================
// ERROR TYPES
// ============================================================================
#[derive(Error, Debug)]
pub enum TransformError {
#[error("JSON serialization error: {0}")]
JsonError(#[from] serde_json::Error),
#[error("Unsupported content type: {0}")]
UnsupportedContent(String),
#[error("Invalid tool input format")]
InvalidToolInput,
#[error("Missing required field: {0}")]
MissingField(String),
#[error("Unsupported conversion: {0}")]
UnsupportedConversion(String),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_error_types() {
let error = TransformError::MissingField("test".to_string());
assert!(matches!(error, TransformError::MissingField(_)));
}
}

View file

@ -0,0 +1,9 @@
pub mod lib;
pub mod transformer;
pub mod endpoints;
// Re-export the main items for easier access
pub use lib::*;
pub use endpoints::{is_supported_endpoint, supported_endpoints, identify_provider};
// Note: transformer module contains TryFrom trait implementations that are automatically available

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,12 @@
//! hermesllm: A library for translating LLM API requests and responses
//! between Mistral, Grok, Gemini, and OpenAI-compliant formats.
use std::fmt::Display;
pub mod providers;
pub mod apis;
pub mod clients;
use std::fmt::Display;
pub enum Provider {
Arch,
Mistral,

View file

@ -0,0 +1,2 @@
pub mod providers;
pub mod clients;

View file

@ -35,9 +35,16 @@ pub enum MultiPartContentType {
ImageUrl,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageUrl {
pub url: String,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MultiPartContent {
pub text: Option<String>,
pub image_url: Option<ImageUrl>,
#[serde(rename = "type")]
pub content_type: MultiPartContentType,
}
@ -307,10 +314,12 @@ mod tests {
MultiPartContent {
text: Some("This is a text part.".to_string()),
content_type: MultiPartContentType::Text,
image_url: None,
},
MultiPartContent {
text: Some("https://example.com/image.png".to_string()),
content_type: MultiPartContentType::ImageUrl,
image_url: None,
},
]);
assert_eq!(multi_part_content.to_string(), "This is a text part.");
@ -364,6 +373,61 @@ mod tests {
}
}
#[test]
fn test_chat_completions_request_image_content() {
const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
"stream": true,
"model": "openai/gpt-4o",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "describe this photo pls"
},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...=="
}
}
]
}
]
}"#;
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
assert_eq!(chat_completions_request.model, "openai/gpt-4o");
if let Some(ContentType::MultiPart(multi_part_content)) =
chat_completions_request.messages[0].content.as_ref()
{
assert_eq!(multi_part_content.len(), 2);
assert_eq!(
multi_part_content[0].content_type,
MultiPartContentType::Text
);
assert_eq!(
multi_part_content[0].text,
Some("describe this photo pls".to_string())
);
assert_eq!(
multi_part_content[1].content_type,
MultiPartContentType::ImageUrl
);
assert_eq!(
multi_part_content[1].image_url,
Some(ImageUrl {
url: "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...==".to_string(),
})
);
} else {
panic!("Expected MultiPartContent");
}
}
#[test]
fn test_sse_streaming() {
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}

View file

@ -1,7 +1,7 @@
{
"manifest_version": 3,
"name": "RouteGPT",
"version": "0.1.1",
"version": "0.1.2",
"description": "RouteGPT: Smart Model Routing for ChatGPT.",
"permissions": [
"storage"

View file

@ -17,7 +17,7 @@
}
// Only intercept conversation fetches
if (pathname === '/backend-api/conversation') {
if (pathname === '/backend-api/conversation' || pathname === '/backend-api/f/conversation') {
console.log(`${TAG} matched → proxy via content script`);
const { port1, port2 } = new MessageChannel();

View file

@ -12,6 +12,9 @@ llm_providers:
- access_key: $OPENAI_API_KEY
model: openai/gpt-4o-mini
- access_key: $OPENAI_API_KEY
model: openai/gpt-4.1
- access_key: $OPENAI_API_KEY
model: openai/gpt-4o
default: true

View file

@ -14,9 +14,9 @@ Make sure your machine is up to date with [latest version of archgw]([url](https
2. start archgw in the foreground
```bash
(venv) $ archgw up --service archgw --foreground
2025-05-30 18:00:09,953 - cli.main - INFO - Starting archgw cli version: 0.3.5
2025-05-30 18:00:09,953 - cli.main - INFO - Starting archgw cli version: 0.3.7
2025-05-30 18:00:09,953 - cli.main - INFO - Validating /Users/adilhafeez/src/intelligent-prompt-gateway/demos/use_cases/preference_based_routing/arch_config.yaml
2025-05-30 18:00:10,422 - cli.core - INFO - Starting arch gateway, image name: archgw, tag: katanemo/archgw:0.3.5
2025-05-30 18:00:10,422 - cli.core - INFO - Starting arch gateway, image name: archgw, tag: katanemo/archgw:0.3.7
2025-05-30 18:00:10,662 - cli.core - INFO - archgw status: running, health status: starting
2025-05-30 18:00:11,712 - cli.core - INFO - archgw status: running, health status: starting
2025-05-30 18:00:12,761 - cli.core - INFO - archgw is running and is healthy!

Binary file not shown.

Before

Width:  |  Height:  |  Size: 359 KiB

After

Width:  |  Height:  |  Size: 217 KiB

Before After
Before After

View file

@ -35,45 +35,42 @@ make outbound LLM calls.
Adding custom LLM Provider
--------------------------
We support any OpenAI compliant LLM for example mistral, openai, ollama etc. We offer first class support for openai and ollama. You can easily configure an LLM that communicates over the OpenAI API interface, by following the below guide.
We support any OpenAI compliant LLM for example mistral, openai, ollama etc. We also offer first class support for OpenAI, Anthropic, DeepSeek, Mistral, Groq, and Ollama based models.
You can easily configure an LLM that communicates over the OpenAI API interface, by following the below guide.
For example following code block shows you how to add an ollama-supported LLM in the `arch_config.yaml` file.
For example following code block shows you how to add an ollama-supported LLM in the ``arch_config.yaml`` file.
.. code-block:: yaml
- name: local-llama
provider_interface: openai
model: llama3.2
endpoint: host.docker.internal:11434
llm_providers:
- model: some_custom_llm_provider/llama3.2
provider_interface: openai
base_url: http://host.docker.internal:11434
For example following code block shows you how to add mistral llm provider in the `arch_config.yaml` file.
And in the following code block shows you how to add mistral llm provider in the ``arch_config.yaml`` file.
.. code-block:: yaml
- name: mistral-ai
provider_interface: openai
model: ministral-3b-latest
endpoint: api.mistral.ai:443
protocol: https
llm_providers:
- name: mistral/ministral-3b-latest
access_key: $MISTRAL_API_KEY
Example: Using the OpenAI Python SDK
------------------------------------
.. code-block:: python
from openai import OpenAI
from openai import OpenAI
# Initialize the Arch client
client = OpenAI(base_url="http://127.0.0.12000/")
# Initialize the Arch client
client = OpenAI(base_url="http://127.0.0.1:2000/")
# Define your LLM provider and prompt
llm_provider = "openai"
prompt = "What is the capital of France?"
# Define your model and messages
model = "llama3.2"
messages = [{"role": "user", "content": "What is the capital of France?"}]
# Send the prompt to the LLM through Arch
response = client.completions.create(llm_provider=llm_provider, prompt=prompt)
# Send the messages to the LLM through Arch
response = client.chat.completions.create(model=model, messages=messages)
# Print the response
print("LLM Response:", response)
# Print the response
print("LLM Response:", response.choices[0].message.content)

View file

@ -15,7 +15,7 @@ from sphinxawesome_theme.postprocess import Icons
project = "Arch Docs"
copyright = "2025, Katanemo Labs, Inc"
author = "Katanemo Labs, Inc"
release = " v0.3.5"
release = " v0.3.7"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

View file

@ -20,7 +20,7 @@ Arch is designed to solve these problems by providing a unified, out-of-process
High-level network flow of where Arch Gateway sits in your agentic stack. Designed for both ingress and egress prompt traffic.
Arch is an AI-native proxy server and the universal data plane for AI built by the contributors of Envoy Proxy with the belief that:
`Arch <https://github.com/katanemo/arch>`_ is a smart edge and AI gateway for AI-native apps - built by the contributors of Envoy Proxy with the belief that:
*Prompts are nuanced and opaque user requests, which require the same capabilities as traditional HTTP requests
including secure handling, intelligent routing, robust observability, and integration with backend (API)

View file

@ -3,9 +3,9 @@
Overview
============
`Arch <https://github.com/katanemo/arch>`_ is an AI-native proxy server and the universal data plane for AI - one that is natively designed to handle and process AI prompts, not just network traffic.
`Arch <https://github.com/katanemo/arch>`_ is a smart edge and AI gateway for AI-native apps - one that is natively designed to handle and process prompts, not just network traffic.
Built by contributors to the widely adopted `Envoy Proxy <https://www.envoyproxy.io/>`_, Arch helps you move faster by handling the pesky *low-level* work in AI agent development—fast input clarification, intelligent agent routing, seamless prompt-to-tool integration, and unified LLM access and observability—all without locking you into a framework.
Built by contributors to the widely adopted `Envoy Proxy <https://www.envoyproxy.io/>`_, Arch handles the *pesky low-level work* in building agentic apps — like applying guardrails, clarifying vague user input, routing prompts to the right agent, and unifying access to any LLM. Its a language and framework friendly infrastructure layer designed to help you build and ship agentic apps faster.
In this documentation, you will learn how to quickly set up Arch to trigger API calls via prompts, apply prompt guardrails without writing any application-level logic,

View file

@ -25,7 +25,7 @@ Arch's CLI allows you to manage and interact with the Arch gateway efficiently.
$ python -m venv venv
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
$ pip install archgw==0.3.5
$ pip install archgw==0.3.7
Build AI Agent with Arch Gateway

View file

@ -14,9 +14,9 @@ Welcome to Arch!
<a href="https://www.producthunt.com/posts/arch-3?embed=true&utm_source=badge-top-post-badge&utm_medium=badge&utm_souce=badge-arch&#0045;3" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/top-post-badge.svg?post_id=565761&theme=dark&period=daily&t=1742433071161" alt="Arch - Build&#0032;fast&#0044;&#0032;hyper&#0045;personalized&#0032;agents&#0032;with&#0032;intelligent&#0032;infra | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
`Arch <https://github.com/katanemo/arch>`_ is an AI-native proxy server and the universal data plane for AI - one that is natively designed to handle and process AI prompts, not just network traffic.
`Arch <https://github.com/katanemo/arch>`_ is a smart edge and AI gateway for AI-native apps - one that is natively designed to handle and process prompts, not just network traffic.
Built by contributors to the widely adopted `Envoy Proxy <https://www.envoyproxy.io/>`_, Arch helps you move faster by handling the pesky *low-level* work in AI agent development—fast input clarification, intelligent agent routing, seamless prompt-to-tool integration, and unified LLM access and observability—all without locking you into a framework.
Built by contributors to the widely adopted `Envoy Proxy <https://www.envoyproxy.io/>`_, Arch handles the *pesky low-level work* in building agentic apps — like applying guardrails, clarifying vague user input, routing prompts to the right agent, and unifying access to any LLM. Its a language and framework friendly infrastructure layer designed to help you build and ship agentic apps faster.
.. tab-set::

197
model_server/poetry.lock generated
View file

@ -2,13 +2,13 @@
[[package]]
name = "accelerate"
version = "1.8.1"
version = "1.9.0"
description = "Accelerate"
optional = false
python-versions = ">=3.9.0"
files = [
{file = "accelerate-1.8.1-py3-none-any.whl", hash = "sha256:c47b8994498875a2b1286e945bd4d20e476956056c7941d512334f4eb44ff991"},
{file = "accelerate-1.8.1.tar.gz", hash = "sha256:f60df931671bc4e75077b852990469d4991ce8bd3a58e72375c3c95132034db9"},
{file = "accelerate-1.9.0-py3-none-any.whl", hash = "sha256:c24739a97ade1d54af4549a65f8b6b046adc87e2b3e4d6c66516e32c53d5a8f1"},
{file = "accelerate-1.9.0.tar.gz", hash = "sha256:0e8c61f81af7bf37195b6175a545ed292617dd90563c88f49020aea5b6a0b47f"},
]
[package.dependencies]
@ -29,7 +29,7 @@ sagemaker = ["sagemaker"]
test-dev = ["bitsandbytes", "datasets", "diffusers", "evaluate", "scikit-learn", "scipy", "timm", "torchdata (>=0.8.0)", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
test-fp8 = ["torchao"]
test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-order", "pytest-subtests", "pytest-xdist"]
test-trackers = ["comet-ml", "dvclive", "matplotlib", "mlflow", "swanlab", "tensorboard", "wandb"]
test-trackers = ["comet-ml", "dvclive", "matplotlib", "mlflow", "swanlab", "tensorboard", "trackio", "wandb"]
testing = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-order", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchdata (>=0.8.0)", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
[[package]]
@ -82,15 +82,26 @@ typing_extensions = {version = ">=4", markers = "python_version < \"3.11\""}
[package.extras]
tests = ["mypy (>=1.14.0)", "pytest", "pytest-asyncio"]
[[package]]
name = "backports-asyncio-runner"
version = "1.2.0"
description = "Backport of asyncio.Runner, a context manager that controls event loop life cycle."
optional = false
python-versions = "<3.11,>=3.8"
files = [
{file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"},
{file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"},
]
[[package]]
name = "certifi"
version = "2025.7.9"
version = "2025.7.14"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.7"
files = [
{file = "certifi-2025.7.9-py3-none-any.whl", hash = "sha256:d842783a14f8fdd646895ac26f719a061408834473cfc10203f6a575beb15d39"},
{file = "certifi-2025.7.9.tar.gz", hash = "sha256:c1d2ec05395148ee10cf672ffc28cd37ea0ab0d99f9cc74c43e588cbd111b079"},
{file = "certifi-2025.7.14-py3-none-any.whl", hash = "sha256:6b31f564a415d79ee77df69d757bb49a5bb53bd9f756cbbe24394ffd6fc1f4b2"},
{file = "certifi-2025.7.14.tar.gz", hash = "sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995"},
]
[[package]]
@ -324,13 +335,13 @@ typing = ["typing-extensions (>=4.12.2)"]
[[package]]
name = "fsspec"
version = "2025.5.1"
version = "2025.7.0"
description = "File-system specification"
optional = false
python-versions = ">=3.9"
files = [
{file = "fsspec-2025.5.1-py3-none-any.whl", hash = "sha256:24d3a2e663d5fc735ab256263c4075f374a174c3410c0b25e5bd1970bceaa462"},
{file = "fsspec-2025.5.1.tar.gz", hash = "sha256:2e55e47a540b91843b755e83ded97c6e897fa0942b11490113f09e9c443c2475"},
{file = "fsspec-2025.7.0-py3-none-any.whl", hash = "sha256:8b012e39f63c7d5f10474de957f3ab793b47b45ae7d39f2fb735f8bbe25c0e21"},
{file = "fsspec-2025.7.0.tar.gz", hash = "sha256:786120687ffa54b8283d942929540d8bc5ccfa820deb555a2b5d0ed2b737bf58"},
]
[package.extras]
@ -338,7 +349,7 @@ abfs = ["adlfs"]
adl = ["adlfs"]
arrow = ["pyarrow (>=1)"]
dask = ["dask", "distributed"]
dev = ["pre-commit", "ruff"]
dev = ["pre-commit", "ruff (>=0.5)"]
doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"]
dropbox = ["dropbox", "dropboxdrivefs", "requests"]
full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
@ -380,66 +391,66 @@ grpc = ["grpcio (>=1.44.0,<2.0.0)"]
[[package]]
name = "grpcio"
version = "1.73.1"
version = "1.74.0"
description = "HTTP/2-based RPC framework"
optional = false
python-versions = ">=3.9"
files = [
{file = "grpcio-1.73.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:2d70f4ddd0a823436c2624640570ed6097e40935c9194482475fe8e3d9754d55"},
{file = "grpcio-1.73.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:3841a8a5a66830261ab6a3c2a3dc539ed84e4ab019165f77b3eeb9f0ba621f26"},
{file = "grpcio-1.73.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:628c30f8e77e0258ab788750ec92059fc3d6628590fb4b7cea8c102503623ed7"},
{file = "grpcio-1.73.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:67a0468256c9db6d5ecb1fde4bf409d016f42cef649323f0a08a72f352d1358b"},
{file = "grpcio-1.73.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68b84d65bbdebd5926eb5c53b0b9ec3b3f83408a30e4c20c373c5337b4219ec5"},
{file = "grpcio-1.73.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c54796ca22b8349cc594d18b01099e39f2b7ffb586ad83217655781a350ce4da"},
{file = "grpcio-1.73.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:75fc8e543962ece2f7ecd32ada2d44c0c8570ae73ec92869f9af8b944863116d"},
{file = "grpcio-1.73.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6a6037891cd2b1dd1406b388660522e1565ed340b1fea2955b0234bdd941a862"},
{file = "grpcio-1.73.1-cp310-cp310-win32.whl", hash = "sha256:cce7265b9617168c2d08ae570fcc2af4eaf72e84f8c710ca657cc546115263af"},
{file = "grpcio-1.73.1-cp310-cp310-win_amd64.whl", hash = "sha256:6a2b372e65fad38842050943f42ce8fee00c6f2e8ea4f7754ba7478d26a356ee"},
{file = "grpcio-1.73.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:ba2cea9f7ae4bc21f42015f0ec98f69ae4179848ad744b210e7685112fa507a1"},
{file = "grpcio-1.73.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:d74c3f4f37b79e746271aa6cdb3a1d7e4432aea38735542b23adcabaaee0c097"},
{file = "grpcio-1.73.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5b9b1805a7d61c9e90541cbe8dfe0a593dfc8c5c3a43fe623701b6a01b01d710"},
{file = "grpcio-1.73.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3215f69a0670a8cfa2ab53236d9e8026bfb7ead5d4baabe7d7dc11d30fda967"},
{file = "grpcio-1.73.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc5eccfd9577a5dc7d5612b2ba90cca4ad14c6d949216c68585fdec9848befb1"},
{file = "grpcio-1.73.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:dc7d7fd520614fce2e6455ba89791458020a39716951c7c07694f9dbae28e9c0"},
{file = "grpcio-1.73.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:105492124828911f85127e4825d1c1234b032cb9d238567876b5515d01151379"},
{file = "grpcio-1.73.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:610e19b04f452ba6f402ac9aa94eb3d21fbc94553368008af634812c4a85a99e"},
{file = "grpcio-1.73.1-cp311-cp311-win32.whl", hash = "sha256:d60588ab6ba0ac753761ee0e5b30a29398306401bfbceffe7d68ebb21193f9d4"},
{file = "grpcio-1.73.1-cp311-cp311-win_amd64.whl", hash = "sha256:6957025a4608bb0a5ff42abd75bfbb2ed99eda29d5992ef31d691ab54b753643"},
{file = "grpcio-1.73.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:921b25618b084e75d424a9f8e6403bfeb7abef074bb6c3174701e0f2542debcf"},
{file = "grpcio-1.73.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:277b426a0ed341e8447fbf6c1d6b68c952adddf585ea4685aa563de0f03df887"},
{file = "grpcio-1.73.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:96c112333309493c10e118d92f04594f9055774757f5d101b39f8150f8c25582"},
{file = "grpcio-1.73.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f48e862aed925ae987eb7084409a80985de75243389dc9d9c271dd711e589918"},
{file = "grpcio-1.73.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83a6c2cce218e28f5040429835fa34a29319071079e3169f9543c3fbeff166d2"},
{file = "grpcio-1.73.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:65b0458a10b100d815a8426b1442bd17001fdb77ea13665b2f7dc9e8587fdc6b"},
{file = "grpcio-1.73.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:0a9f3ea8dce9eae9d7cb36827200133a72b37a63896e0e61a9d5ec7d61a59ab1"},
{file = "grpcio-1.73.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:de18769aea47f18e782bf6819a37c1c528914bfd5683b8782b9da356506190c8"},
{file = "grpcio-1.73.1-cp312-cp312-win32.whl", hash = "sha256:24e06a5319e33041e322d32c62b1e728f18ab8c9dbc91729a3d9f9e3ed336642"},
{file = "grpcio-1.73.1-cp312-cp312-win_amd64.whl", hash = "sha256:303c8135d8ab176f8038c14cc10d698ae1db9c480f2b2823f7a987aa2a4c5646"},
{file = "grpcio-1.73.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:b310824ab5092cf74750ebd8a8a8981c1810cb2b363210e70d06ef37ad80d4f9"},
{file = "grpcio-1.73.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:8f5a6df3fba31a3485096ac85b2e34b9666ffb0590df0cd044f58694e6a1f6b5"},
{file = "grpcio-1.73.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:052e28fe9c41357da42250a91926a3e2f74c046575c070b69659467ca5aa976b"},
{file = "grpcio-1.73.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c0bf15f629b1497436596b1cbddddfa3234273490229ca29561209778ebe182"},
{file = "grpcio-1.73.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ab860d5bfa788c5a021fba264802e2593688cd965d1374d31d2b1a34cacd854"},
{file = "grpcio-1.73.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:ad1d958c31cc91ab050bd8a91355480b8e0683e21176522bacea225ce51163f2"},
{file = "grpcio-1.73.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:f43ffb3bd415c57224c7427bfb9e6c46a0b6e998754bfa0d00f408e1873dcbb5"},
{file = "grpcio-1.73.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:686231cdd03a8a8055f798b2b54b19428cdf18fa1549bee92249b43607c42668"},
{file = "grpcio-1.73.1-cp313-cp313-win32.whl", hash = "sha256:89018866a096e2ce21e05eabed1567479713ebe57b1db7cbb0f1e3b896793ba4"},
{file = "grpcio-1.73.1-cp313-cp313-win_amd64.whl", hash = "sha256:4a68f8c9966b94dff693670a5cf2b54888a48a5011c5d9ce2295a1a1465ee84f"},
{file = "grpcio-1.73.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:b4adc97d2d7f5c660a5498bda978ebb866066ad10097265a5da0511323ae9f50"},
{file = "grpcio-1.73.1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:c45a28a0cfb6ddcc7dc50a29de44ecac53d115c3388b2782404218db51cb2df3"},
{file = "grpcio-1.73.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:10af9f2ab98a39f5b6c1896c6fc2036744b5b41d12739d48bed4c3e15b6cf900"},
{file = "grpcio-1.73.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:45cf17dcce5ebdb7b4fe9e86cb338fa99d7d1bb71defc78228e1ddf8d0de8cbb"},
{file = "grpcio-1.73.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c502c2e950fc7e8bf05c047e8a14522ef7babac59abbfde6dbf46b7a0d9c71e"},
{file = "grpcio-1.73.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6abfc0f9153dc4924536f40336f88bd4fe7bd7494f028675e2e04291b8c2c62a"},
{file = "grpcio-1.73.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ed451a0e39c8e51eb1612b78686839efd1a920666d1666c1adfdb4fd51680c0f"},
{file = "grpcio-1.73.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:07f08705a5505c9b5b0cbcbabafb96462b5a15b7236bbf6bbcc6b0b91e1cbd7e"},
{file = "grpcio-1.73.1-cp39-cp39-win32.whl", hash = "sha256:ad5c958cc3d98bb9d71714dc69f1c13aaf2f4b53e29d4cc3f1501ef2e4d129b2"},
{file = "grpcio-1.73.1-cp39-cp39-win_amd64.whl", hash = "sha256:42f0660bce31b745eb9d23f094a332d31f210dcadd0fc8e5be7e4c62a87ce86b"},
{file = "grpcio-1.73.1.tar.gz", hash = "sha256:7fce2cd1c0c1116cf3850564ebfc3264fba75d3c74a7414373f1238ea365ef87"},
{file = "grpcio-1.74.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:85bd5cdf4ed7b2d6438871adf6afff9af7096486fcf51818a81b77ef4dd30907"},
{file = "grpcio-1.74.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:68c8ebcca945efff9d86d8d6d7bfb0841cf0071024417e2d7f45c5e46b5b08eb"},
{file = "grpcio-1.74.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:e154d230dc1bbbd78ad2fdc3039fa50ad7ffcf438e4eb2fa30bce223a70c7486"},
{file = "grpcio-1.74.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8978003816c7b9eabe217f88c78bc26adc8f9304bf6a594b02e5a49b2ef9c11"},
{file = "grpcio-1.74.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3d7bd6e3929fd2ea7fbc3f562e4987229ead70c9ae5f01501a46701e08f1ad9"},
{file = "grpcio-1.74.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:136b53c91ac1d02c8c24201bfdeb56f8b3ac3278668cbb8e0ba49c88069e1bdc"},
{file = "grpcio-1.74.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:fe0f540750a13fd8e5da4b3eaba91a785eea8dca5ccd2bc2ffe978caa403090e"},
{file = "grpcio-1.74.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4e4181bfc24413d1e3a37a0b7889bea68d973d4b45dd2bc68bb766c140718f82"},
{file = "grpcio-1.74.0-cp310-cp310-win32.whl", hash = "sha256:1733969040989f7acc3d94c22f55b4a9501a30f6aaacdbccfaba0a3ffb255ab7"},
{file = "grpcio-1.74.0-cp310-cp310-win_amd64.whl", hash = "sha256:9e912d3c993a29df6c627459af58975b2e5c897d93287939b9d5065f000249b5"},
{file = "grpcio-1.74.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:69e1a8180868a2576f02356565f16635b99088da7df3d45aaa7e24e73a054e31"},
{file = "grpcio-1.74.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:8efe72fde5500f47aca1ef59495cb59c885afe04ac89dd11d810f2de87d935d4"},
{file = "grpcio-1.74.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:a8f0302f9ac4e9923f98d8e243939a6fb627cd048f5cd38595c97e38020dffce"},
{file = "grpcio-1.74.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f609a39f62a6f6f05c7512746798282546358a37ea93c1fcbadf8b2fed162e3"},
{file = "grpcio-1.74.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c98e0b7434a7fa4e3e63f250456eaef52499fba5ae661c58cc5b5477d11e7182"},
{file = "grpcio-1.74.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:662456c4513e298db6d7bd9c3b8df6f75f8752f0ba01fb653e252ed4a59b5a5d"},
{file = "grpcio-1.74.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3d14e3c4d65e19d8430a4e28ceb71ace4728776fd6c3ce34016947474479683f"},
{file = "grpcio-1.74.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1bf949792cee20d2078323a9b02bacbbae002b9e3b9e2433f2741c15bdeba1c4"},
{file = "grpcio-1.74.0-cp311-cp311-win32.whl", hash = "sha256:55b453812fa7c7ce2f5c88be3018fb4a490519b6ce80788d5913f3f9d7da8c7b"},
{file = "grpcio-1.74.0-cp311-cp311-win_amd64.whl", hash = "sha256:86ad489db097141a907c559988c29718719aa3e13370d40e20506f11b4de0d11"},
{file = "grpcio-1.74.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:8533e6e9c5bd630ca98062e3a1326249e6ada07d05acf191a77bc33f8948f3d8"},
{file = "grpcio-1.74.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:2918948864fec2a11721d91568effffbe0a02b23ecd57f281391d986847982f6"},
{file = "grpcio-1.74.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:60d2d48b0580e70d2e1954d0d19fa3c2e60dd7cbed826aca104fff518310d1c5"},
{file = "grpcio-1.74.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3601274bc0523f6dc07666c0e01682c94472402ac2fd1226fd96e079863bfa49"},
{file = "grpcio-1.74.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:176d60a5168d7948539def20b2a3adcce67d72454d9ae05969a2e73f3a0feee7"},
{file = "grpcio-1.74.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e759f9e8bc908aaae0412642afe5416c9f983a80499448fcc7fab8692ae044c3"},
{file = "grpcio-1.74.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e7c4389771855a92934b2846bd807fc25a3dfa820fd912fe6bd8136026b2707"},
{file = "grpcio-1.74.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cce634b10aeab37010449124814b05a62fb5f18928ca878f1bf4750d1f0c815b"},
{file = "grpcio-1.74.0-cp312-cp312-win32.whl", hash = "sha256:885912559974df35d92219e2dc98f51a16a48395f37b92865ad45186f294096c"},
{file = "grpcio-1.74.0-cp312-cp312-win_amd64.whl", hash = "sha256:42f8fee287427b94be63d916c90399ed310ed10aadbf9e2e5538b3e497d269bc"},
{file = "grpcio-1.74.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:2bc2d7d8d184e2362b53905cb1708c84cb16354771c04b490485fa07ce3a1d89"},
{file = "grpcio-1.74.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:c14e803037e572c177ba54a3e090d6eb12efd795d49327c5ee2b3bddb836bf01"},
{file = "grpcio-1.74.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f6ec94f0e50eb8fa1744a731088b966427575e40c2944a980049798b127a687e"},
{file = "grpcio-1.74.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:566b9395b90cc3d0d0c6404bc8572c7c18786ede549cdb540ae27b58afe0fb91"},
{file = "grpcio-1.74.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1ea6176d7dfd5b941ea01c2ec34de9531ba494d541fe2057c904e601879f249"},
{file = "grpcio-1.74.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:64229c1e9cea079420527fa8ac45d80fc1e8d3f94deaa35643c381fa8d98f362"},
{file = "grpcio-1.74.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:0f87bddd6e27fc776aacf7ebfec367b6d49cad0455123951e4488ea99d9b9b8f"},
{file = "grpcio-1.74.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:3b03d8f2a07f0fea8c8f74deb59f8352b770e3900d143b3d1475effcb08eec20"},
{file = "grpcio-1.74.0-cp313-cp313-win32.whl", hash = "sha256:b6a73b2ba83e663b2480a90b82fdae6a7aa6427f62bf43b29912c0cfd1aa2bfa"},
{file = "grpcio-1.74.0-cp313-cp313-win_amd64.whl", hash = "sha256:fd3c71aeee838299c5887230b8a1822795325ddfea635edd82954c1eaa831e24"},
{file = "grpcio-1.74.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:4bc5fca10aaf74779081e16c2bcc3d5ec643ffd528d9e7b1c9039000ead73bae"},
{file = "grpcio-1.74.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:6bab67d15ad617aff094c382c882e0177637da73cbc5532d52c07b4ee887a87b"},
{file = "grpcio-1.74.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:655726919b75ab3c34cdad39da5c530ac6fa32696fb23119e36b64adcfca174a"},
{file = "grpcio-1.74.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a2b06afe2e50ebfd46247ac3ba60cac523f54ec7792ae9ba6073c12daf26f0a"},
{file = "grpcio-1.74.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f251c355167b2360537cf17bea2cf0197995e551ab9da6a0a59b3da5e8704f9"},
{file = "grpcio-1.74.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f7b5882fb50632ab1e48cb3122d6df55b9afabc265582808036b6e51b9fd6b7"},
{file = "grpcio-1.74.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:834988b6c34515545b3edd13e902c1acdd9f2465d386ea5143fb558f153a7176"},
{file = "grpcio-1.74.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:22b834cef33429ca6cc28303c9c327ba9a3fafecbf62fae17e9a7b7163cc43ac"},
{file = "grpcio-1.74.0-cp39-cp39-win32.whl", hash = "sha256:7d95d71ff35291bab3f1c52f52f474c632db26ea12700c2ff0ea0532cb0b5854"},
{file = "grpcio-1.74.0-cp39-cp39-win_amd64.whl", hash = "sha256:ecde9ab49f58433abe02f9ed076c7b5be839cf0153883a6d23995937a82392fa"},
{file = "grpcio-1.74.0.tar.gz", hash = "sha256:80d1f4fbb35b0742d3e3d3bb654b7381cd5f015f8497279a1e9c21ba623e01b1"},
]
[package.extras]
protobuf = ["grpcio-tools (>=1.73.1)"]
protobuf = ["grpcio-tools (>=1.74.0)"]
[[package]]
name = "h11"
@ -520,19 +531,19 @@ zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "huggingface-hub"
version = "0.33.4"
version = "0.34.1"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "huggingface_hub-0.33.4-py3-none-any.whl", hash = "sha256:09f9f4e7ca62547c70f8b82767eefadd2667f4e116acba2e3e62a5a81815a7bb"},
{file = "huggingface_hub-0.33.4.tar.gz", hash = "sha256:6af13478deae120e765bfd92adad0ae1aec1ad8c439b46f23058ad5956cbca0a"},
{file = "huggingface_hub-0.34.1-py3-none-any.whl", hash = "sha256:60d843dcb7bc335145b20e7d2f1dfe93910f6787b2b38a936fb772ce2a83757c"},
{file = "huggingface_hub-0.34.1.tar.gz", hash = "sha256:6978ed89ef981de3c78b75bab100a214843be1cc9d24f8e9c0dc4971808ef1b1"},
]
[package.dependencies]
filelock = "*"
fsspec = ">=2023.5.0"
hf-xet = {version = ">=1.1.2,<2.0.0", markers = "platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"arm64\" or platform_machine == \"aarch64\""}
hf-xet = {version = ">=1.1.3,<2.0.0", markers = "platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"arm64\" or platform_machine == \"aarch64\""}
packaging = ">=20.9"
pyyaml = ">=5.1"
requests = "*"
@ -540,16 +551,16 @@ tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3"
[package.extras]
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "authlib (>=1.3.2)", "fastapi", "gradio (>=4.0.0)", "httpx", "itsdangerous", "jedi", "libcst (==1.4.0)", "mypy (==1.15.0)", "mypy (>=1.14.1,<1.15.0)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "authlib (>=1.3.2)", "fastapi", "gradio (>=4.0.0)", "httpx", "itsdangerous", "jedi", "libcst (>=1.4.0)", "mypy (==1.15.0)", "mypy (>=1.14.1,<1.15.0)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
cli = ["InquirerPy (==0.3.4)"]
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "authlib (>=1.3.2)", "fastapi", "gradio (>=4.0.0)", "httpx", "itsdangerous", "jedi", "libcst (==1.4.0)", "mypy (==1.15.0)", "mypy (>=1.14.1,<1.15.0)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "authlib (>=1.3.2)", "fastapi", "gradio (>=4.0.0)", "httpx", "itsdangerous", "jedi", "libcst (>=1.4.0)", "mypy (==1.15.0)", "mypy (>=1.14.1,<1.15.0)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
hf-transfer = ["hf-transfer (>=0.1.4)"]
hf-xet = ["hf-xet (>=1.1.2,<2.0.0)"]
inference = ["aiohttp"]
mcp = ["aiohttp", "mcp (>=1.8.0)", "typer"]
oauth = ["authlib (>=1.3.2)", "fastapi", "httpx", "itsdangerous"]
quality = ["libcst (==1.4.0)", "mypy (==1.15.0)", "mypy (>=1.14.1,<1.15.0)", "ruff (>=0.9.0)"]
quality = ["libcst (>=1.4.0)", "mypy (==1.15.0)", "mypy (>=1.14.1,<1.15.0)", "ruff (>=0.9.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"]
tensorflow-testing = ["keras (<3.0)", "tensorflow"]
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "authlib (>=1.3.2)", "fastapi", "gradio (>=4.0.0)", "httpx", "itsdangerous", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
@ -1046,13 +1057,13 @@ files = [
[[package]]
name = "openai"
version = "1.95.1"
version = "1.97.1"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.8"
files = [
{file = "openai-1.95.1-py3-none-any.whl", hash = "sha256:8bbdfeceef231b1ddfabbc232b179d79f8b849aab5a7da131178f8d10e0f162f"},
{file = "openai-1.95.1.tar.gz", hash = "sha256:f089b605282e2a2b6776090b4b46563ac1da77f56402a222597d591e2dcc1086"},
{file = "openai-1.97.1-py3-none-any.whl", hash = "sha256:4e96bbdf672ec3d44968c9ea39d2c375891db1acc1794668d8149d5fa6000606"},
{file = "openai-1.97.1.tar.gz", hash = "sha256:a744b27ae624e3d4135225da9b1c89c107a2a7e5bc4c93e5b7b5214772ce7a4e"},
]
[package.dependencies]
@ -1521,16 +1532,17 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests
[[package]]
name = "pytest-asyncio"
version = "1.0.0"
version = "1.1.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.9"
files = [
{file = "pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3"},
{file = "pytest_asyncio-1.0.0.tar.gz", hash = "sha256:d15463d13f4456e1ead2594520216b225a16f781e144f8fdf6c5bb4667c48b3f"},
{file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"},
{file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"},
]
[package.dependencies]
backports-asyncio-runner = {version = ">=1.1,<2", markers = "python_version < \"3.11\""}
pytest = ">=8.2,<9"
[package.extras]
@ -2044,18 +2056,18 @@ telegram = ["requests"]
[[package]]
name = "transformers"
version = "4.53.2"
version = "4.54.0"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false
python-versions = ">=3.9.0"
files = [
{file = "transformers-4.53.2-py3-none-any.whl", hash = "sha256:db8f4819bb34f000029c73c3c557e7d06fc1b8e612ec142eecdae3947a9c78bf"},
{file = "transformers-4.53.2.tar.gz", hash = "sha256:6c3ed95edfb1cba71c4245758f1b4878c93bf8cde77d076307dacb2cbbd72be2"},
{file = "transformers-4.54.0-py3-none-any.whl", hash = "sha256:c96e607f848625965b76c677b2c2576f2c7b7097c1c5292b281919d90675a25e"},
{file = "transformers-4.54.0.tar.gz", hash = "sha256:843da4d66a573cef3d1b2e7a1d767e77da054621e69d9f3faff761e55a1f8203"},
]
[package.dependencies]
filelock = "*"
huggingface-hub = ">=0.30.0,<1.0"
huggingface-hub = ">=0.34.0,<1.0"
numpy = ">=1.17"
packaging = ">=20.0"
pyyaml = ">=5.1"
@ -2067,15 +2079,15 @@ tqdm = ">=4.27"
[package.extras]
accelerate = ["accelerate (>=0.26.0)"]
all = ["Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.6.1,<0.7)", "librosa", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.1)", "torchaudio", "torchvision"]
all = ["Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.6.1,<0.7)", "librosa", "mistral-common[opencv] (>=1.6.3)", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (!=1.0.18,<=1.0.19)", "tokenizers (>=0.21,<0.22)", "torch (>=2.1)", "torchaudio", "torchvision"]
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
benchmark = ["optimum-benchmark (>=0.3.0)"]
codecarbon = ["codecarbon (>=2.8.1)"]
deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
dev = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.6.1,<0.7)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "pandas (<2.3.0)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.1)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)"]
dev-tensorflow = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "pandas (<2.3.0)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
dev-torch = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "kenlm", "kernels (>=0.6.1,<0.7)", "libcst", "librosa", "nltk (<=3.8.1)", "num2words", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "pandas (<2.3.0)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.1)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)"]
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
dev = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.6.1,<0.7)", "libcst", "librosa", "mistral-common[opencv] (>=1.6.3)", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "pandas (<2.3.0)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (!=1.0.18,<=1.0.19)", "tokenizers (>=0.21,<0.22)", "torch (>=2.1)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)"]
dev-tensorflow = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "onnxconverter-common", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "pandas (<2.3.0)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
dev-torch = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "kenlm", "kernels (>=0.6.1,<0.7)", "libcst", "librosa", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "num2words", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "pandas (<2.3.0)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (!=1.0.18,<=1.0.19)", "tokenizers (>=0.21,<0.22)", "torch (>=2.1)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)"]
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
ftfy = ["ftfy"]
@ -2083,6 +2095,7 @@ hf-xet = ["hf_xet"]
hub-kernels = ["kernels (>=0.6.1,<0.7)"]
integrations = ["kernels (>=0.6.1,<0.7)", "optuna", "ray[tune] (>=2.7.0)", "sigopt"]
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)"]
mistral-common = ["mistral-common[opencv] (>=1.6.3)"]
modelcreation = ["cookiecutter (==1.7.3)"]
natten = ["natten (>=0.14.6,<0.15.0)"]
num2words = ["num2words"]
@ -2090,27 +2103,27 @@ onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
open-telemetry = ["opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk"]
optuna = ["optuna"]
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "libcst", "pandas (<2.3.0)", "rich", "ruff (==0.11.2)", "urllib3 (<2.0.0)"]
quality = ["GitPython (<3.1.19)", "datasets (>=2.15.0)", "libcst", "pandas (<2.3.0)", "rich", "ruff (==0.11.2)", "urllib3 (<2.0.0)"]
ray = ["ray[tune] (>=2.7.0)"]
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
retrieval = ["datasets (>=2.15.0)", "faiss-cpu"]
ruff = ["ruff (==0.11.2)"]
sagemaker = ["sagemaker (>=2.31.0)"]
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
serving = ["accelerate (>=0.26.0)", "fastapi", "openai", "pydantic (>=2)", "starlette", "torch (>=2.1)", "uvicorn"]
sigopt = ["sigopt"]
sklearn = ["scikit-learn"]
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.11.2)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
tiktoken = ["blobfile", "tiktoken"]
timm = ["timm (<=1.0.11)"]
timm = ["timm (!=1.0.18,<=1.0.19)"]
tokenizers = ["tokenizers (>=0.21,<0.22)"]
torch = ["accelerate (>=0.26.0)", "torch (>=2.1)"]
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
torchhub = ["filelock", "huggingface-hub (>=0.30.0,<1.0)", "importlib_metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.1)", "tqdm (>=4.27)"]
torchhub = ["filelock", "huggingface-hub (>=0.34.0,<1.0)", "importlib_metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.1)", "tqdm (>=4.27)"]
video = ["av"]
vision = ["Pillow (>=10.0.1,<=15.0)"]

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw_modelserver"
version = "0.3.5"
version = "0.3.7"
description = "A model server for serving models"
authors = ["Katanemo Labs, Inc <info@katanemo.com>"]
license = "Apache 2.0"

View file

@ -72,7 +72,7 @@ def start_server(port=51000, foreground=False):
if foreground:
process = subprocess.Popen(
[
"python",
sys.executable,
"-m",
"uvicorn",
"src.main:app",
@ -85,7 +85,7 @@ def start_server(port=51000, foreground=False):
else:
process = subprocess.Popen(
[
"python",
sys.executable,
"-m",
"uvicorn",
"src.main:app",