mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
add support for v1/messages and transformations (#558)
* pushing draft PR * transformations are working. Now need to add some tests next * updated tests and added necessary response transformations for Anthropics' message response object * fixed bugs for integration tests * fixed doc tests * fixed serialization issues with enums on response * adding some debug logs to help * fixed issues with non-streaming responses * updated the stream_context to update response bytes * the serialized bytes length must be set in the response side * fixed the debug statement that was causing the integration tests for wasm to fail * fixing json parsing errors * intentionally removing the headers * making sure that we convert the raw bytes to the correct provider type upstream * fixing non-streaming responses to tranform correctly * /v1/messages works with transformations to and from /v1/chat/completions * updating the CLI and demos to support anthropic vs. claude * adding the anthropic key to the preference based routing tests * fixed test cases and added more structured logs * fixed integration tests and cleaned up logs * added python client tests for anthropic and openai * cleaned up logs and fixed issue with connectivity for llm gateway in weather forecast demo * fixing the tests. python dependency order was broken * updated the openAI client to fix demos * removed the raw response debug statement * fixed the dup cloning issue and cleaned up the ProviderRequestType enum and traits * fixing logs * moved away from string literals to consts * fixed streaming from Anthropic Client to OpenAI * removed debug statement that would likely trip up integration tests * fixed integration tests for llm_gateway * cleaned up test cases and removed unnecessary crates * fixing comments from PR * fixed bug whereby we were sending an OpenAIChatCompletions request object to llm_gateway even though the request may have been AnthropicMessages --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local> Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-9.local> Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-10.local> Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-41.local> Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-136.local>
This commit is contained in:
parent
bb71d041a0
commit
fb0581fd39
38 changed files with 2842 additions and 919 deletions
|
|
@ -4,6 +4,8 @@ use bytes::Bytes;
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
|
|
@ -22,66 +24,61 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
|||
.boxed()
|
||||
}
|
||||
|
||||
pub async fn chat_completions(
|
||||
pub async fn chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
llm_provider_endpoint: String,
|
||||
full_qualified_llm_provider_url: String,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
|
||||
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
let chat_request_parsed = serde_json::from_slice::<serde_json::Value>(&chat_request_bytes)
|
||||
.inspect_err(|err| {
|
||||
warn!(
|
||||
"Failed to parse request body as JSON: err: {}, str: {}",
|
||||
err,
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
)
|
||||
})
|
||||
.unwrap_or_else(|_| {
|
||||
warn!(
|
||||
"Failed to parse request body as JSON: {}",
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
serde_json::Value::Null
|
||||
});
|
||||
// Clone metadata for routing and remove archgw_preference_config from original
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
if chat_request_parsed == serde_json::Value::Null {
|
||||
warn!("Request body is not valid JSON");
|
||||
let err_msg = "Request body is not valid JSON".to_string();
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!("Removed archgw_preference_config from metadata");
|
||||
}
|
||||
|
||||
let chat_completion_request: ChatCompletionsRequest =
|
||||
serde_json::from_value(chat_request_parsed.clone()).unwrap();
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// remove metadata from the request
|
||||
let mut chat_request_user_preferences_removed = chat_request_parsed;
|
||||
if let Some(metadata) = chat_request_user_preferences_removed.get_mut("metadata") {
|
||||
debug!("Removing metadata from request");
|
||||
if let Some(m) = metadata.as_object_mut() {
|
||||
m.remove("archgw_preference_config");
|
||||
debug!("Removed archgw_preference_config from metadata");
|
||||
}
|
||||
|
||||
// if metadata is empty, remove it
|
||||
if metadata.as_object().map_or(false, |m| m.is_empty()) {
|
||||
debug!("Removing empty metadata from request");
|
||||
chat_request_user_preferences_removed
|
||||
.as_object_mut()
|
||||
.map(|m| m.remove("metadata"));
|
||||
}
|
||||
}
|
||||
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
|
||||
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
|
||||
match ProviderRequestType::try_from((client_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(ProviderRequestType::MessagesRequest(_)) => {
|
||||
// This should not happen after conversion to OpenAI format
|
||||
warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
|
||||
let err_msg = "Request conversion failed".to_string();
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
},
|
||||
Err(err) => {
|
||||
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
|
||||
let err_msg = format!("Failed to convert request: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"arch-router request received: {}",
|
||||
&serde_json::to_string(&chat_completion_request).unwrap()
|
||||
"[BRIGHTSTAFF -> ARCH_ROUTER] REQ: {}",
|
||||
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
|
||||
);
|
||||
|
||||
let trace_parent = request_headers
|
||||
|
|
@ -90,7 +87,7 @@ pub async fn chat_completions(
|
|||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
let usage_preferences_str: Option<String> =
|
||||
chat_completion_request.metadata.and_then(|metadata| {
|
||||
routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
|
|
@ -101,7 +98,7 @@ pub async fn chat_completions(
|
|||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
let latest_message_for_log =
|
||||
chat_completion_request
|
||||
chat_completions_request_for_arch_router
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
|
|
@ -126,7 +123,7 @@ pub async fn chat_completions(
|
|||
|
||||
let model_name = match router_service
|
||||
.determine_route(
|
||||
&chat_completion_request.messages,
|
||||
&chat_completions_request_for_arch_router.messages,
|
||||
trace_parent.clone(),
|
||||
usage_preferences,
|
||||
)
|
||||
|
|
@ -137,9 +134,9 @@ pub async fn chat_completions(
|
|||
None => {
|
||||
debug!(
|
||||
"No route determined, using default model from request: {}",
|
||||
chat_completion_request.model
|
||||
chat_completions_request_for_arch_router.model
|
||||
);
|
||||
chat_completion_request.model.clone()
|
||||
chat_completions_request_for_arch_router.model.clone()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
|
|
@ -151,8 +148,8 @@ pub async fn chat_completions(
|
|||
};
|
||||
|
||||
debug!(
|
||||
"sending request to llm provider: {}, with model hint: {}",
|
||||
llm_provider_endpoint, model_name
|
||||
"[BRIGHTSTAFF -> ARCH_ROUTER] URL: {}, Model Hint: {}",
|
||||
full_qualified_llm_provider_url, model_name
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
|
|
@ -166,17 +163,13 @@ pub async fn chat_completions(
|
|||
header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
let chat_request_parsed_bytes =
|
||||
serde_json::to_string(&chat_request_user_preferences_removed).unwrap();
|
||||
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(llm_provider_endpoint)
|
||||
.post(full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(chat_request_parsed_bytes)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
use brightstaff::handlers::chat_completions::chat_completions;
|
||||
use brightstaff::handlers::chat_completions::chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::Configuration;
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH};
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::server::conn::http1;
|
||||
|
|
@ -67,10 +68,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
&serde_json::to_string(arch_config.as_ref()).unwrap()
|
||||
);
|
||||
|
||||
let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
|
||||
.unwrap_or_else(|_| "http://localhost:12001/v1/chat/completions".to_string());
|
||||
let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT")
|
||||
.unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
info!("llm provider endpoint: {}", llm_provider_endpoint);
|
||||
info!("llm provider url: {}", llm_provider_url);
|
||||
info!("listening on http://{}", bind_address);
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
|
||||
|
|
@ -88,7 +89,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
arch_config.llm_providers.clone(),
|
||||
llm_provider_endpoint.clone(),
|
||||
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
));
|
||||
|
|
@ -99,19 +100,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let service = service_fn(move |req| {
|
||||
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, "/v1/chat/completions") => {
|
||||
chat_completions(req, router_service, llm_provider_endpoint)
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
|
||||
chat(req, router_service, fully_qualified_url)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -149,8 +149,8 @@ pub struct EmbeddingProviver {
|
|||
pub enum LlmProviderType {
|
||||
#[serde(rename = "arch")]
|
||||
Arch,
|
||||
#[serde(rename = "claude")]
|
||||
Claude,
|
||||
#[serde(rename = "anthropic")]
|
||||
Anthropic,
|
||||
#[serde(rename = "deepseek")]
|
||||
Deepseek,
|
||||
#[serde(rename = "groq")]
|
||||
|
|
@ -167,7 +167,7 @@ impl Display for LlmProviderType {
|
|||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
LlmProviderType::Arch => write!(f, "arch"),
|
||||
LlmProviderType::Claude => write!(f, "claude"),
|
||||
LlmProviderType::Anthropic => write!(f, "anthropic"),
|
||||
LlmProviderType::Deepseek => write!(f, "deepseek"),
|
||||
LlmProviderType::Groq => write!(f, "groq"),
|
||||
LlmProviderType::Gemini => write!(f, "gemini"),
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
|||
pub const MESSAGES_KEY: &str = "messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
pub const HEALTHZ_PATH: &str = "/healthz";
|
||||
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message";
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ pub fn get_llm_provider(
|
|||
return provider;
|
||||
}
|
||||
|
||||
|
||||
if llm_providers.default().is_some() {
|
||||
return llm_providers.default().unwrap();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use log::debug;
|
|||
|
||||
#[allow(dead_code)]
|
||||
pub fn token_count(model_name: &str, text: &str) -> Result<usize, String> {
|
||||
debug!("getting token count model={}", model_name);
|
||||
debug!("TOKENIZER: computing token count for model={}", model_name);
|
||||
//HACK: add support for tokenizing mistral and other models
|
||||
//filed issue https://github.com/katanemo/arch/issues/222
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
use crate::providers::response::TokenUsage;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use crate::{MESSAGES_PATH};
|
||||
|
||||
// Enum for all supported Anthropic APIs
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
|
@ -17,13 +22,13 @@ pub enum AnthropicApi {
|
|||
impl ApiDefinition for AnthropicApi {
|
||||
fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
AnthropicApi::Messages => "/v1/messages",
|
||||
AnthropicApi::Messages => MESSAGES_PATH,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
match endpoint {
|
||||
"/v1/messages" => Some(AnthropicApi::Messages),
|
||||
MESSAGES_PATH => Some(AnthropicApi::Messages),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -186,6 +191,19 @@ pub enum MessagesContentBlock {
|
|||
},
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<MessagesContentBlock> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
.filter_map(|block| match block {
|
||||
MessagesContentBlock::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesImageSource {
|
||||
|
|
@ -220,6 +238,15 @@ pub enum MessagesMessageContent {
|
|||
Blocks(Vec<MessagesContentBlock>),
|
||||
}
|
||||
|
||||
impl ExtractText for MessagesMessageContent {
|
||||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessagesMessageContent::Single(text) => text.clone(),
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessagesSystemPrompt {
|
||||
|
|
@ -369,6 +396,121 @@ impl MessagesRequest {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for MessagesRequest {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenUsage for MessagesResponse {
|
||||
fn completion_tokens(&self) -> usize {
|
||||
self.usage.output_tokens as usize
|
||||
}
|
||||
fn prompt_tokens(&self) -> usize {
|
||||
self.usage.input_tokens as usize
|
||||
}
|
||||
fn total_tokens(&self) -> usize {
|
||||
(self.usage.input_tokens + self.usage.output_tokens) as usize
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for MessagesResponse {
|
||||
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||
Some(self)
|
||||
}
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderRequest for MessagesRequest {
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: String) {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
self.stream.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
let mut text_parts = Vec::new();
|
||||
|
||||
// Include system prompt if present
|
||||
if let Some(system) = &self.system {
|
||||
match system {
|
||||
MessagesSystemPrompt::Single(s) => text_parts.push(s.clone()),
|
||||
MessagesSystemPrompt::Blocks(blocks) => {
|
||||
for block in blocks {
|
||||
if let MessagesContentBlock::Text { text } = block {
|
||||
text_parts.push(text.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract text from all messages
|
||||
for message in &self.messages {
|
||||
match &message.content {
|
||||
MessagesMessageContent::Single(text) => text_parts.push(text.clone()),
|
||||
MessagesMessageContent::Blocks(blocks) => {
|
||||
for block in blocks {
|
||||
if let MessagesContentBlock::Text { text } = block {
|
||||
text_parts.push(text.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text_parts.join(" ")
|
||||
}
|
||||
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
// Find the most recent user message
|
||||
for message in self.messages.iter().rev() {
|
||||
if message.role == MessagesRole::User {
|
||||
match &message.content {
|
||||
MessagesMessageContent::Single(text) => return Some(text.clone()),
|
||||
MessagesMessageContent::Blocks(blocks) => {
|
||||
for block in blocks {
|
||||
if let MessagesContentBlock::Text { text } = block {
|
||||
return Some(text.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
serde_json::to_vec(self).map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to serialize MessagesRequest: {}", e),
|
||||
source: Some(Box::new(e)),
|
||||
})
|
||||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
if let Some(ref mut metadata) = self.metadata {
|
||||
metadata.remove(key).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MessagesResponse {
|
||||
pub fn api_type() -> AnthropicApi {
|
||||
AnthropicApi::Messages
|
||||
|
|
@ -381,6 +523,54 @@ impl MessagesStreamEvent {
|
|||
}
|
||||
}
|
||||
|
||||
impl MessagesRole {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
MessagesRole::User => "user",
|
||||
MessagesRole::Assistant => "assistant",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Implement ProviderStreamResponse for MessagesStreamEvent
|
||||
impl ProviderStreamResponse for MessagesStreamEvent {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
|
||||
if let MessagesContentDelta::TextDelta { text } = delta {
|
||||
Some(text)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
matches!(self, MessagesStreamEvent::MessageStop)
|
||||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessagesStreamEvent::MessageStart { message } => Some(message.role.as_str()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn event_type(&self) -> Option<&str> {
|
||||
Some(match self {
|
||||
MessagesStreamEvent::MessageStart { .. } => "message_start",
|
||||
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
|
||||
MessagesStreamEvent::MessageDelta { .. } => "message_delta",
|
||||
MessagesStreamEvent::MessageStop => "message_stop",
|
||||
MessagesStreamEvent::Ping => "ping",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -878,13 +1068,13 @@ mod tests {
|
|||
let api = AnthropicApi::Messages;
|
||||
|
||||
// Test trait methods
|
||||
assert_eq!(api.endpoint(), "/v1/messages");
|
||||
assert_eq!(api.endpoint(), MESSAGES_PATH);
|
||||
assert!(api.supports_streaming());
|
||||
assert!(api.supports_tools());
|
||||
assert!(api.supports_vision());
|
||||
|
||||
// Test from_endpoint trait method
|
||||
let found_api = AnthropicApi::from_endpoint("/v1/messages");
|
||||
let found_api = AnthropicApi::from_endpoint(MESSAGES_PATH);
|
||||
assert_eq!(found_api, Some(AnthropicApi::Messages));
|
||||
|
||||
let not_found = AnthropicApi::from_endpoint("/v1/unknown");
|
||||
|
|
|
|||
|
|
@ -1,110 +1,9 @@
|
|||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
|
||||
// Re-export all types for convenience
|
||||
pub use anthropic::*;
|
||||
pub use openai::*;
|
||||
|
||||
/// Common trait that all API definitions must implement
|
||||
///
|
||||
/// This trait ensures consistency across different AI provider API definitions
|
||||
/// and makes it easy to add new providers like Gemini, Claude, etc.
|
||||
///
|
||||
/// Note: This is different from the `ApiProvider` enum in `clients::endpoints`
|
||||
/// which represents provider identification, while this trait defines API capabilities.
|
||||
///
|
||||
/// # Benefits
|
||||
///
|
||||
/// - **Consistency**: All API providers implement the same interface
|
||||
/// - **Extensibility**: Easy to add new providers without breaking existing code
|
||||
/// - **Type Safety**: Compile-time guarantees that all providers implement required methods
|
||||
/// - **Discoverability**: Clear documentation of what capabilities each API supports
|
||||
///
|
||||
/// # Example implementation for a new provider:
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use serde::{Deserialize, Serialize};
|
||||
/// use super::ApiDefinition;
|
||||
///
|
||||
/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
/// pub enum GeminiApi {
|
||||
/// GenerateContent,
|
||||
/// ChatCompletions,
|
||||
/// }
|
||||
///
|
||||
/// impl GeminiApi {
|
||||
/// pub fn endpoint(&self) -> &'static str {
|
||||
/// match self {
|
||||
/// GeminiApi::GenerateContent => "/v1/models/gemini-pro:generateContent",
|
||||
/// GeminiApi::ChatCompletions => "/v1/models/gemini-pro:chat",
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
/// match endpoint {
|
||||
/// "/v1/models/gemini-pro:generateContent" => Some(GeminiApi::GenerateContent),
|
||||
/// "/v1/models/gemini-pro:chat" => Some(GeminiApi::ChatCompletions),
|
||||
/// _ => None,
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// pub fn supports_streaming(&self) -> bool {
|
||||
/// match self {
|
||||
/// GeminiApi::GenerateContent => true,
|
||||
/// GeminiApi::ChatCompletions => true,
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// pub fn supports_tools(&self) -> bool {
|
||||
/// match self {
|
||||
/// GeminiApi::GenerateContent => true,
|
||||
/// GeminiApi::ChatCompletions => false,
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// pub fn supports_vision(&self) -> bool {
|
||||
/// match self {
|
||||
/// GeminiApi::GenerateContent => true,
|
||||
/// GeminiApi::ChatCompletions => false,
|
||||
/// }
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// impl ApiDefinition for GeminiApi {
|
||||
/// fn endpoint(&self) -> &'static str {
|
||||
/// self.endpoint()
|
||||
/// }
|
||||
///
|
||||
/// fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
/// Self::from_endpoint(endpoint)
|
||||
/// }
|
||||
///
|
||||
/// fn supports_streaming(&self) -> bool {
|
||||
/// self.supports_streaming()
|
||||
/// }
|
||||
///
|
||||
/// fn supports_tools(&self) -> bool {
|
||||
/// self.supports_tools()
|
||||
/// }
|
||||
///
|
||||
/// fn supports_vision(&self) -> bool {
|
||||
/// self.supports_vision()
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// // Now you can use generic code that works with any API:
|
||||
/// fn print_api_info<T: ApiDefinition>(api: &T) {
|
||||
/// println!("Endpoint: {}", api.endpoint());
|
||||
/// println!("Supports streaming: {}", api.supports_streaming());
|
||||
/// println!("Supports tools: {}", api.supports_tools());
|
||||
/// println!("Supports vision: {}", api.supports_vision());
|
||||
/// }
|
||||
///
|
||||
/// // Works with both OpenAI and Anthropic (and future Gemini)
|
||||
/// print_api_info(&OpenAIApi::ChatCompletions);
|
||||
/// print_api_info(&AnthropicApi::Messages);
|
||||
/// print_api_info(&GeminiApi::GenerateContent);
|
||||
/// ```
|
||||
|
||||
pub trait ApiDefinition {
|
||||
/// Returns the endpoint path for this API
|
||||
fn endpoint(&self) -> &'static str;
|
||||
|
|
@ -132,6 +31,7 @@ pub trait ApiDefinition {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH};
|
||||
|
||||
#[test]
|
||||
fn test_generic_api_functionality() {
|
||||
|
|
@ -150,8 +50,8 @@ mod tests {
|
|||
fn test_api_detection_from_endpoints() {
|
||||
// Test that we can detect APIs from endpoints using the trait
|
||||
let endpoints = vec![
|
||||
"/v1/chat/completions",
|
||||
"/v1/messages",
|
||||
CHAT_COMPLETIONS_PATH,
|
||||
MESSAGES_PATH,
|
||||
"/v1/unknown"
|
||||
];
|
||||
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ use std::collections::HashMap;
|
|||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
|
||||
|
||||
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::{ExtractText};
|
||||
use crate::{CHAT_COMPLETIONS_PATH};
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI API ENUMERATION
|
||||
|
|
@ -28,13 +28,13 @@ pub enum OpenAIApi {
|
|||
impl ApiDefinition for OpenAIApi {
|
||||
fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => "/v1/chat/completions",
|
||||
OpenAIApi::ChatCompletions => CHAT_COMPLETIONS_PATH,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
match endpoint {
|
||||
"/v1/chat/completions" => Some(OpenAIApi::ChatCompletions),
|
||||
CHAT_COMPLETIONS_PATH => Some(OpenAIApi::ChatCompletions),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -81,7 +81,7 @@ pub struct ChatCompletionsRequest {
|
|||
// Maximum tokens in the response has been deprecated, but we keep it for compatibility
|
||||
pub max_tokens: Option<u32>,
|
||||
pub modalities: Option<Vec<String>>,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
pub metadata: Option<HashMap<String, Value>>,
|
||||
pub n: Option<u32>,
|
||||
pub presence_penalty: Option<f32>,
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
|
|
@ -174,6 +174,28 @@ pub enum MessageContent {
|
|||
Parts(Vec<ContentPart>),
|
||||
}
|
||||
|
||||
// Content Extraction
|
||||
impl ExtractText for MessageContent {
|
||||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.extract_text()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<ContentPart> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for MessageContent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
|
@ -328,6 +350,7 @@ pub struct ChatCompletionsResponse {
|
|||
pub choices: Vec<Choice>,
|
||||
pub usage: Usage,
|
||||
pub system_fingerprint: Option<String>,
|
||||
pub service_tier: Option<String>,
|
||||
}
|
||||
|
||||
/// Finish reason for completion
|
||||
|
|
@ -576,6 +599,18 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
source: Some(Box::new(e)),
|
||||
})
|
||||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
if let Some(ref mut metadata) = self.metadata {
|
||||
metadata.remove(key).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of ProviderResponse for ChatCompletionsResponse
|
||||
|
|
@ -593,68 +628,6 @@ impl ProviderResponse for ChatCompletionsResponse {
|
|||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI SSE STREAMING ITERATOR
|
||||
// ============================================================================
|
||||
|
||||
/// OpenAI-specific SSE streaming iterator
|
||||
/// Handles OpenAI's specific SSE format and ChatCompletionsStreamResponse parsing
|
||||
pub struct OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
sse_stream: SseStreamIter<I>,
|
||||
}
|
||||
|
||||
impl<I> OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(sse_stream: SseStreamIter<I>) -> Self {
|
||||
Self { sse_stream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.sse_stream.lines {
|
||||
let line = line.as_ref();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line[6..]; // Remove "data: " prefix
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Skip ping messages (usually from other providers, but handle gracefully)
|
||||
if data == r#"{"type": "ping"}"# {
|
||||
continue;
|
||||
}
|
||||
|
||||
// OpenAI-specific parsing of ChatCompletionsStreamResponse
|
||||
match serde_json::from_str::<ChatCompletionsStreamResponse>(data) {
|
||||
Ok(response) => return Some(Ok(Box::new(response))),
|
||||
Err(e) => return Some(Err(Box::new(
|
||||
OpenAIStreamError::InvalidStreamingData(format!("Error parsing OpenAI streaming data: {}, data: {}", e, data))
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
|
||||
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
|
|
@ -680,6 +653,10 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
|||
Role::Tool => "tool",
|
||||
}))
|
||||
}
|
||||
|
||||
fn event_type(&self) -> Option<&str> {
|
||||
None // OpenAI doesn't use event types in SSE
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -982,13 +959,13 @@ mod tests {
|
|||
let api = OpenAIApi::ChatCompletions;
|
||||
|
||||
// Test trait methods
|
||||
assert_eq!(api.endpoint(), "/v1/chat/completions");
|
||||
assert_eq!(api.endpoint(), CHAT_COMPLETIONS_PATH);
|
||||
assert!(api.supports_streaming());
|
||||
assert!(api.supports_tools());
|
||||
assert!(api.supports_vision());
|
||||
|
||||
// Test from_endpoint
|
||||
let found_api = OpenAIApi::from_endpoint("/v1/chat/completions");
|
||||
let found_api = OpenAIApi::from_endpoint(CHAT_COMPLETIONS_PATH);
|
||||
assert_eq!(found_api, Some(OpenAIApi::ChatCompletions));
|
||||
|
||||
let not_found = OpenAIApi::from_endpoint("/v1/unknown");
|
||||
|
|
@ -1139,4 +1116,84 @@ mod tests {
|
|||
let invalid_result: Result<ToolChoice, _> = serde_json::from_value(json!("invalid"));
|
||||
assert!(invalid_result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_response_with_service_tier() {
|
||||
// Test that ChatCompletionsResponse can deserialize OpenAI responses with service_tier field
|
||||
let json_response = r#"{
|
||||
"id": "chatcmpl-CAJc2Df6QCc7Mv3RP0Cf2xlbDV1x2",
|
||||
"object": "chat.completion",
|
||||
"created": 1756574706,
|
||||
"model": "gpt-4o-2024-08-06",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Test response content",
|
||||
"annotations": []
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 65,
|
||||
"completion_tokens": 184,
|
||||
"total_tokens": 249,
|
||||
"prompt_tokens_details": {
|
||||
"cached_tokens": 0,
|
||||
"audio_tokens": 0
|
||||
},
|
||||
"completion_tokens_details": {
|
||||
"reasoning_tokens": 0,
|
||||
"audio_tokens": 0,
|
||||
"accepted_prediction_tokens": 0,
|
||||
"rejected_prediction_tokens": 0
|
||||
}
|
||||
},
|
||||
"service_tier": "default",
|
||||
"system_fingerprint": "fp_f33640a400"
|
||||
}"#;
|
||||
|
||||
let response: ChatCompletionsResponse = serde_json::from_str(json_response).unwrap();
|
||||
|
||||
assert_eq!(response.id, "chatcmpl-CAJc2Df6QCc7Mv3RP0Cf2xlbDV1x2");
|
||||
assert_eq!(response.object, "chat.completion");
|
||||
assert_eq!(response.created, 1756574706);
|
||||
assert_eq!(response.model, "gpt-4o-2024-08-06");
|
||||
assert_eq!(response.service_tier, Some("default".to_string()));
|
||||
assert_eq!(response.system_fingerprint, Some("fp_f33640a400".to_string()));
|
||||
assert_eq!(response.choices.len(), 1);
|
||||
assert_eq!(response.usage.prompt_tokens, 65);
|
||||
assert_eq!(response.usage.completion_tokens, 184);
|
||||
assert_eq!(response.usage.total_tokens, 249);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_response_without_service_tier() {
|
||||
// Test that ChatCompletionsResponse can deserialize responses without service_tier (backward compatibility)
|
||||
let json_response = r#"{
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-4",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Test response"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}"#;
|
||||
|
||||
let response: ChatCompletionsResponse = serde_json::from_str(json_response).unwrap();
|
||||
|
||||
assert_eq!(response.id, "chatcmpl-123");
|
||||
assert_eq!(response.service_tier, None); // Should be None when not present
|
||||
assert_eq!(response.system_fingerprint, None);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,12 +6,13 @@
|
|||
//! # Examples
|
||||
//!
|
||||
//! ```rust
|
||||
//! use hermesllm::clients::endpoints::{is_supported_endpoint, supported_endpoints};
|
||||
//! use hermesllm::clients::endpoints::supported_endpoints;
|
||||
//!
|
||||
//! // Check if we support an endpoint
|
||||
//! assert!(is_supported_endpoint("/v1/chat/completions"));
|
||||
//! assert!(is_supported_endpoint("/v1/messages"));
|
||||
//! assert!(!is_supported_endpoint("/v1/unknown"));
|
||||
//! use hermesllm::clients::endpoints::SupportedAPIs;
|
||||
//! assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
|
||||
//! assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
|
||||
//! assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
|
||||
//!
|
||||
//! // Get all supported endpoints
|
||||
//! let endpoints = supported_endpoints();
|
||||
|
|
@ -20,23 +21,81 @@
|
|||
//! assert!(endpoints.contains(&"/v1/messages"));
|
||||
//! ```
|
||||
|
||||
use crate::apis::{AnthropicApi, OpenAIApi, ApiDefinition};
|
||||
use crate::{apis::{AnthropicApi, ApiDefinition, OpenAIApi}, ProviderId};
|
||||
use std::fmt;
|
||||
|
||||
/// Check if the given endpoint path is supported
|
||||
pub fn is_supported_endpoint(endpoint: &str) -> bool {
|
||||
// Try OpenAI APIs
|
||||
if OpenAIApi::from_endpoint(endpoint).is_some() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Try Anthropic APIs
|
||||
if AnthropicApi::from_endpoint(endpoint).is_some() {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
/// Unified enum representing all supported API endpoints across providers
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SupportedAPIs {
|
||||
OpenAIChatCompletions(OpenAIApi),
|
||||
AnthropicMessagesAPI(AnthropicApi),
|
||||
}
|
||||
|
||||
impl fmt::Display for SupportedAPIs {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => write!(f, "OpenAI API ({})", api.endpoint()),
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => write!(f, "Anthropic API ({})", api.endpoint()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SupportedAPIs {
|
||||
/// Create a SupportedApi from an endpoint path
|
||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||
return Some(SupportedAPIs::OpenAIChatCompletions(openai_api));
|
||||
}
|
||||
|
||||
if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) {
|
||||
return Some(SupportedAPIs::AnthropicMessagesAPI(anthropic_api));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the endpoint path for this API
|
||||
pub fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => api.endpoint(),
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => api.endpoint(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str) -> String {
|
||||
let default_endpoint = "/v1/chat/completions".to_string();
|
||||
match self {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
|
||||
match provider_id {
|
||||
ProviderId::Anthropic => "/v1/messages".to_string(),
|
||||
_ => default_endpoint,
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai{}", request_path)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/v1beta/openai/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
_ => default_endpoint,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Get all supported endpoint paths
|
||||
pub fn supported_endpoints() -> Vec<&'static str> {
|
||||
let mut endpoints = Vec::new();
|
||||
|
|
@ -74,15 +133,15 @@ mod tests {
|
|||
#[test]
|
||||
fn test_is_supported_endpoint() {
|
||||
// OpenAI endpoints
|
||||
assert!(is_supported_endpoint("/v1/chat/completions"));
|
||||
assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
|
||||
|
||||
// Anthropic endpoints
|
||||
assert!(is_supported_endpoint("/v1/messages"));
|
||||
assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
|
||||
|
||||
// Unsupported endpoints
|
||||
assert!(!is_supported_endpoint("/v1/unknown"));
|
||||
assert!(!is_supported_endpoint("/v2/chat"));
|
||||
assert!(!is_supported_endpoint(""));
|
||||
assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
|
||||
assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some());
|
||||
assert!(!SupportedAPIs::from_endpoint("").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,6 @@ pub mod endpoints;
|
|||
|
||||
// Re-export the main items for easier access
|
||||
pub use lib::*;
|
||||
pub use endpoints::{is_supported_endpoint, supported_endpoints, identify_provider};
|
||||
pub use endpoints::{SupportedAPIs, identify_provider};
|
||||
|
||||
// Note: transformer module contains TryFrom trait implementations that are automatically available
|
||||
|
|
|
|||
|
|
@ -44,8 +44,6 @@
|
|||
|
||||
use serde_json::Value;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
// Import centralized types
|
||||
use crate::apis::*;
|
||||
use super::TransformError;
|
||||
|
||||
|
|
@ -61,7 +59,7 @@ const DEFAULT_MAX_TOKENS: u32 = 4096;
|
|||
// ============================================================================
|
||||
|
||||
/// Trait for extracting text content from various types
|
||||
trait ExtractText {
|
||||
pub trait ExtractText {
|
||||
fn extract_text(&self) -> String;
|
||||
}
|
||||
|
||||
|
|
@ -213,6 +211,7 @@ impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
|
|||
choices: vec![choice],
|
||||
usage,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -541,40 +540,6 @@ impl Into<Role> for MessagesRole {
|
|||
}
|
||||
}
|
||||
|
||||
// Content Extraction
|
||||
impl ExtractText for MessageContent {
|
||||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.extract_text()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<ContentPart> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<MessagesContentBlock> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
.filter_map(|block| match block {
|
||||
MessagesContentBlock::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Content Utilities
|
||||
impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
|
||||
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError> {
|
||||
|
|
|
|||
|
|
@ -4,12 +4,16 @@
|
|||
pub mod providers;
|
||||
pub mod apis;
|
||||
pub mod clients;
|
||||
|
||||
// Re-export important types and traits
|
||||
pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError};
|
||||
pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, ProviderResponseError, TokenUsage};
|
||||
pub use providers::response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, SseEvent, SseStreamIter};
|
||||
pub use providers::id::ProviderId;
|
||||
pub use providers::adapters::{has_compatible_api, supported_apis};
|
||||
|
||||
|
||||
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
|
@ -23,72 +27,51 @@ mod tests {
|
|||
assert_eq!(ProviderId::from("arch"), ProviderId::Arch);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_api_compatibility() {
|
||||
assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions"));
|
||||
assert!(!has_compatible_api(&ProviderId::OpenAI, "/v1/embeddings"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_supported_apis() {
|
||||
let apis = supported_apis(&ProviderId::OpenAI);
|
||||
assert!(apis.contains(&"/v1/chat/completions"));
|
||||
|
||||
// Test that provider supports the expected API endpoints
|
||||
assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_request_parsing() {
|
||||
// Test with a sample JSON request
|
||||
let json_request = r#"{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: Result<ProviderRequestType, std::io::Error> = ProviderRequestType::try_from(json_request.as_bytes());
|
||||
assert!(result.is_ok());
|
||||
|
||||
let request = result.unwrap();
|
||||
assert_eq!(request.model(), "gpt-4");
|
||||
assert_eq!(request.get_recent_user_message(), Some("Hello!".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_streaming_response() {
|
||||
// Test streaming response parsing with sample SSE data
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
let client_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
|
||||
let mut streaming_response = result.unwrap();
|
||||
// Test the new simplified architecture - create SseStreamIter directly
|
||||
let sse_iter = SseStreamIter::try_from(sse_data.as_bytes());
|
||||
assert!(sse_iter.is_ok());
|
||||
|
||||
// Test that we can iterate over chunks - it's just an iterator now!
|
||||
let first_chunk = streaming_response.next();
|
||||
assert!(first_chunk.is_some());
|
||||
let mut streaming_iter = sse_iter.unwrap();
|
||||
|
||||
let chunk_result = first_chunk.unwrap();
|
||||
assert!(chunk_result.is_ok());
|
||||
// Test that we can iterate over SseEvents
|
||||
let first_event = streaming_iter.next();
|
||||
assert!(first_event.is_some());
|
||||
|
||||
let chunk = chunk_result.unwrap();
|
||||
assert_eq!(chunk.content_delta(), Some("Hello"));
|
||||
assert!(!chunk.is_final());
|
||||
let sse_event = first_event.unwrap();
|
||||
|
||||
// Test that stream ends properly
|
||||
let final_chunk = streaming_response.next();
|
||||
assert!(final_chunk.is_none());
|
||||
// Test SseEvent properties
|
||||
assert!(!sse_event.is_done());
|
||||
assert!(sse_event.data.as_ref().unwrap().contains("Hello"));
|
||||
|
||||
// Test that we can parse the event into a provider stream response
|
||||
let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api));
|
||||
if let Err(e) = &transformed_event {
|
||||
println!("Transform error: {:?}", e);
|
||||
}
|
||||
assert!(transformed_event.is_ok());
|
||||
|
||||
let transformed_event = transformed_event.unwrap();
|
||||
let provider_response = transformed_event.provider_response();
|
||||
assert!(provider_response.is_ok());
|
||||
|
||||
let stream_response = provider_response.unwrap();
|
||||
assert_eq!(stream_response.content_delta(), Some("Hello"));
|
||||
assert!(!stream_response.is_final());
|
||||
|
||||
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
|
||||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,39 +0,0 @@
|
|||
use crate::providers::id::ProviderId;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AdapterType {
|
||||
OpenAICompatible,
|
||||
// Future: Claude, Gemini, etc.
|
||||
}
|
||||
|
||||
/// Provider adapter configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProviderConfig {
|
||||
pub supported_apis: &'static [&'static str],
|
||||
pub adapter_type: AdapterType,
|
||||
}
|
||||
|
||||
/// Check if provider has compatible API
|
||||
pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool {
|
||||
let config = get_provider_config(provider_id);
|
||||
config.supported_apis.iter().any(|&supported| supported == api_path)
|
||||
}
|
||||
|
||||
/// Get supported APIs for provider
|
||||
pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> {
|
||||
let config = get_provider_config(provider_id);
|
||||
config.supported_apis.to_vec()
|
||||
}
|
||||
|
||||
/// Get provider configuration
|
||||
pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig {
|
||||
match provider_id {
|
||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||
ProviderConfig {
|
||||
supported_apis: &["/v1/chat/completions"],
|
||||
adapter_type: AdapterType::OpenAICompatible,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
use std::fmt::Display;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::{OpenAIApi, AnthropicApi};
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
|
|
@ -8,7 +10,7 @@ pub enum ProviderId {
|
|||
Deepseek,
|
||||
Groq,
|
||||
Gemini,
|
||||
Claude,
|
||||
Anthropic,
|
||||
GitHub,
|
||||
Arch,
|
||||
}
|
||||
|
|
@ -21,7 +23,7 @@ impl From<&str> for ProviderId {
|
|||
"deepseek" => ProviderId::Deepseek,
|
||||
"groq" => ProviderId::Groq,
|
||||
"gemini" => ProviderId::Gemini,
|
||||
"claude" => ProviderId::Claude,
|
||||
"anthropic" => ProviderId::Anthropic,
|
||||
"github" => ProviderId::GitHub,
|
||||
"arch" => ProviderId::Arch,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
|
|
@ -29,6 +31,21 @@ impl From<&str> for ProviderId {
|
|||
}
|
||||
}
|
||||
|
||||
impl ProviderId {
|
||||
/// Given a client API, return the compatible upstream API for this provider
|
||||
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs {
|
||||
match (self, client_api) {
|
||||
// Claude/Anthropic providers natively support Anthropic APIs
|
||||
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
|
||||
(ProviderId::Anthropic, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ProviderId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
|
@ -37,7 +54,7 @@ impl Display for ProviderId {
|
|||
ProviderId::Deepseek => write!(f, "Deepseek"),
|
||||
ProviderId::Groq => write!(f, "Groq"),
|
||||
ProviderId::Gemini => write!(f, "Gemini"),
|
||||
ProviderId::Claude => write!(f, "Claude"),
|
||||
ProviderId::Anthropic => write!(f, "Anthropic"),
|
||||
ProviderId::GitHub => write!(f, "GitHub"),
|
||||
ProviderId::Arch => write!(f, "Arch"),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,9 +6,7 @@
|
|||
pub mod id;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod adapters;
|
||||
|
||||
pub use id::ProviderId;
|
||||
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ;
|
||||
pub use response::{ProviderResponseType, ProviderStreamResponseIter, ProviderResponse, ProviderStreamResponse, TokenUsage };
|
||||
pub use adapters::*;
|
||||
pub use response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage };
|
||||
|
|
|
|||
|
|
@ -1,41 +1,17 @@
|
|||
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use super::{ProviderId, get_provider_config, AdapterType};
|
||||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
|
||||
use serde_json::Value;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::collections::HashMap;
|
||||
#[derive(Clone)]
|
||||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
//MessagesRequest(MessagesRequest),
|
||||
MessagesRequest(MessagesRequest),
|
||||
//add more request types here
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
// if passing bytes without provider id we assume the request is in OpenAI format
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(provider_id);
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
// Future: handle other adapter types like Claude
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ProviderRequest: Send + Sync {
|
||||
/// Extract the model name from the request
|
||||
fn model(&self) -> &str;
|
||||
|
|
@ -54,46 +30,129 @@ pub trait ProviderRequest: Send + Sync {
|
|||
|
||||
/// Convert the request to bytes for transmission
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>>;
|
||||
|
||||
/// Remove a metadata key from the request and return true if the key was present
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool;
|
||||
}
|
||||
|
||||
impl ProviderRequest for ProviderRequestType {
|
||||
fn model(&self) -> &str {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.model(),
|
||||
Self::MessagesRequest(r) => r.model(),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: String) {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.set_model(model),
|
||||
Self::MessagesRequest(r) => r.set_model(model),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.is_streaming(),
|
||||
Self::MessagesRequest(r) => r.is_streaming(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.extract_messages_text(),
|
||||
Self::MessagesRequest(r) => r.extract_messages_text(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.get_recent_user_message(),
|
||||
Self::MessagesRequest(r) => r.get_recent_user_message(),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.to_bytes(),
|
||||
Self::MessagesRequest(r) => r.to_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.metadata(),
|
||||
Self::MessagesRequest(r) => r.metadata(),
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.remove_metadata_key(key),
|
||||
Self::MessagesRequest(r) => r.remove_metadata_key(key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the client API from a byte slice.
|
||||
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, client_api): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
match client_api {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Conversion from one ProviderRequestType to a different ProviderRequestType (SupportedAPIs)
|
||||
impl TryFrom<(ProviderRequestType, &SupportedAPIs)> for ProviderRequestType {
|
||||
type Error = ProviderRequestError;
|
||||
|
||||
fn try_from((request, upstream_api): (ProviderRequestType, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
match (request, upstream_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let messages_req = MessagesRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
(ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(messages_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Error types for provider operations
|
||||
#[derive(Debug)]
|
||||
|
|
@ -113,3 +172,194 @@ impl Error for ProviderRequestError {
|
|||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::anthropic::AnthropicApi::Messages;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest};
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_openai_request_from_bytes() {
|
||||
let req = json!({
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &api));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.messages.len(), 2);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_request_from_bytes_with_endpoint() {
|
||||
let req = json!({
|
||||
"model": "claude-3-sonnet",
|
||||
"system": "You are a helpful assistant",
|
||||
"max_tokens": 100,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let endpoint = SupportedAPIs::AnthropicMessagesAPI(Messages);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderRequestType::MessagesRequest(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet");
|
||||
assert_eq!(r.messages.len(), 1);
|
||||
},
|
||||
_ => panic!("Expected MessagesRequest variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_request_from_bytes_with_endpoint() {
|
||||
let req = json!({
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.messages.len(), 2);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_request_from_bytes_wrong_endpoint() {
|
||||
let req = json!({
|
||||
"model": "claude-3-sonnet",
|
||||
"system": "You are a helpful assistant",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
// Intentionally use OpenAI endpoint for Anthropic payload
|
||||
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
// Should parse as ChatCompletionsRequest, not error
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet");
|
||||
assert_eq!(r.messages.len(), 1);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_messages_to_v1_chat_completions_roundtrip() {
|
||||
let anthropic_req = AnthropicMessagesRequest {
|
||||
model: "claude-3-sonnet".to_string(),
|
||||
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single("You are a helpful assistant".to_string())),
|
||||
messages: vec![
|
||||
crate::apis::anthropic::MessagesMessage {
|
||||
role: crate::apis::anthropic::MessagesRole::User,
|
||||
content: crate::apis::anthropic::MessagesMessageContent::Single("Hello!".to_string()),
|
||||
}
|
||||
],
|
||||
max_tokens: 128,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
top_k: None,
|
||||
stream: Some(false),
|
||||
stop_sequences: Some(vec!["\n".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()).expect("Anthropic->OpenAI conversion failed");
|
||||
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req).expect("OpenAI->Anthropic conversion failed");
|
||||
|
||||
assert_eq!(anthropic_req.model, anthropic_req2.model);
|
||||
// Compare system prompt text if present
|
||||
assert_eq!(
|
||||
anthropic_req.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }),
|
||||
anthropic_req2.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None })
|
||||
);
|
||||
assert_eq!(anthropic_req.messages[0].role, anthropic_req2.messages[0].role);
|
||||
// Compare message content text if present
|
||||
assert_eq!(
|
||||
anthropic_req.messages[0].content.extract_text(),
|
||||
anthropic_req2.messages[0].content.extract_text()
|
||||
);
|
||||
assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
|
||||
|
||||
let openai_req = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
max_tokens: Some(128),
|
||||
stream: Some(false),
|
||||
stop: Some(vec!["\n".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()).expect("OpenAI->Anthropic conversion failed");
|
||||
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req).expect("Anthropic->OpenAI conversion failed");
|
||||
|
||||
assert_eq!(openai_req.model, openai_req2.model);
|
||||
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
|
||||
assert_eq!(openai_req.messages[0].content.extract_text(), openai_req2.messages[0].content.extract_text());
|
||||
assert_eq!(openai_req.max_tokens, openai_req2.max_tokens);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,76 +1,37 @@
|
|||
use crate::providers::id::ProviderId;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::convert::TryFrom;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::apis::OpenAISseIter;
|
||||
use crate::providers::id::ProviderId;
|
||||
use crate::providers::adapters::{get_provider_config, AdapterType};
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::anthropic::MessagesResponse;
|
||||
|
||||
/// Trait for token usage information
|
||||
pub trait TokenUsage {
|
||||
fn completion_tokens(&self) -> usize;
|
||||
fn prompt_tokens(&self) -> usize;
|
||||
fn total_tokens(&self) -> usize;
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProviderResponseType {
|
||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
//MessagesResponse(MessagesResponse),
|
||||
MessagesResponse(MessagesResponse),
|
||||
}
|
||||
|
||||
pub enum ProviderStreamResponseIter {
|
||||
ChatCompletionsStream(OpenAISseIter<std::vec::IntoIter<String>>),
|
||||
//MessagesStream(AnthropicSseIter<std::vec::IntoIter<String>>),
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProviderStreamResponseType {
|
||||
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
|
||||
MessagesStreamEvent(MessagesStreamEvent),
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(&provider_id);
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(chat_completions_response))
|
||||
}
|
||||
// Future: handle other adapter types like Claude
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], &ProviderId)> for ProviderStreamResponseIter {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(provider_id);
|
||||
|
||||
// Parse SSE (Server-Sent Events) streaming data - protocol layer
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
// Delegate to OpenAI-specific iterator implementation
|
||||
let sse_container = SseStreamIter::new(lines.into_iter());
|
||||
let iter = crate::apis::openai::OpenAISseIter::new(sse_container);
|
||||
Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter))
|
||||
}
|
||||
// Future: AdapterType::Claude => {
|
||||
// let sse_container = SseStreamIter::new(lines.into_iter());
|
||||
// let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container);
|
||||
// Ok(ProviderStreamResponseIter::MessagesStream(iter))
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl Iterator for ProviderStreamResponseIter {
|
||||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
match self {
|
||||
ProviderStreamResponseIter::ChatCompletionsStream(iter) => iter.next(),
|
||||
// Future: ProviderStreamResponseIter::MessagesStream(iter) => iter.next(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
/// Get usage information if available - returns dynamic trait object
|
||||
fn usage(&self) -> Option<&dyn TokenUsage>;
|
||||
|
|
@ -81,6 +42,22 @@ pub trait ProviderResponse: Send + Sync {
|
|||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for ProviderResponseType {
|
||||
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(),
|
||||
ProviderResponseType::MessagesResponse(resp) => resp.usage(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
|
||||
ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ProviderStreamResponse: Send + Sync {
|
||||
/// Get the content delta for this chunk
|
||||
fn content_delta(&self) -> Option<&str>;
|
||||
|
|
@ -90,16 +67,313 @@ pub trait ProviderStreamResponse: Send + Sync {
|
|||
|
||||
/// Get role information if available
|
||||
fn role(&self) -> Option<&str>;
|
||||
|
||||
/// Get event type for SSE streaming (used by Anthropic)
|
||||
fn event_type(&self) -> Option<&str>;
|
||||
}
|
||||
|
||||
impl ProviderStreamResponse for ProviderStreamResponseType {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
match self {
|
||||
ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(),
|
||||
ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.content_delta(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
match self {
|
||||
ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(),
|
||||
ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.is_final(),
|
||||
}
|
||||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
match self {
|
||||
ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(),
|
||||
ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.role(),
|
||||
}
|
||||
}
|
||||
|
||||
fn event_type(&self) -> Option<&str> {
|
||||
match self {
|
||||
ProviderStreamResponseType::ChatCompletionsStreamResponse(_resp) => None, // OpenAI doesn't use event types
|
||||
ProviderStreamResponseType::MessagesStreamEvent(resp) => resp.event_type(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE EVENT CONTAINER
|
||||
// ============================================================================
|
||||
|
||||
/// Represents a single Server-Sent Event with the complete wire format
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SseEvent {
|
||||
#[serde(rename = "data")]
|
||||
pub data: Option<String>, // The JSON payload after "data: "
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub event: Option<String>, // Optional event type (e.g., "message_start", "content_block_delta")
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
}
|
||||
|
||||
impl SseEvent {
|
||||
/// Check if this event represents the end of the stream
|
||||
pub fn is_done(&self) -> bool {
|
||||
self.data == Some("[DONE]".into())
|
||||
}
|
||||
|
||||
/// Check if this event should be skipped during processing
|
||||
/// This includes ping messages and other provider-specific events that don't contain content
|
||||
pub fn should_skip(&self) -> bool {
|
||||
// Skip ping messages (commonly used by providers for connection keep-alive)
|
||||
self.data == Some(r#"{"type": "ping"}"#.into())
|
||||
}
|
||||
|
||||
/// Check if this is an event-only SSE event (no data payload)
|
||||
pub fn is_event_only(&self) -> bool {
|
||||
self.event.is_some() && self.data.is_none()
|
||||
}
|
||||
|
||||
/// Get the parsed provider response if available
|
||||
pub fn provider_response(&self) -> Result<&dyn ProviderStreamResponse, std::io::Error> {
|
||||
self.provider_stream_response.as_ref()
|
||||
.map(|resp| resp as &dyn ProviderStreamResponse)
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::new(std::io::ErrorKind::NotFound, "Provider response not found")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl FromStr for SseEvent {
|
||||
type Err = SseParseError;
|
||||
|
||||
fn from_str(line: &str) -> Result<Self, Self::Err> {
|
||||
if line.starts_with("data: ") {
|
||||
let data: String = line[6..].to_string(); // Remove "data: " prefix
|
||||
if data.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty data field is not a valid SSE event".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(SseEvent {
|
||||
data: Some(data),
|
||||
event: None,
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else if line.starts_with("event: ") { //used by Anthropic
|
||||
let event_type = line[7..].to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty event field is not a valid SSE event".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(SseEvent {
|
||||
data: None,
|
||||
event: Some(event_type),
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else {
|
||||
Err(SseParseError {
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", line),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SseEvent {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.sse_transform_buffer)
|
||||
}
|
||||
}
|
||||
|
||||
// Into implementation to convert SseEvent to bytes for response buffer
|
||||
impl Into<Vec<u8>> for SseEvent {
|
||||
fn into(self) -> Vec<u8> {
|
||||
format!("{}\n\n", self.sse_transform_buffer).into_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// --- Response transformation logic for client API compatibility ---
|
||||
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let upstream_api = provider_id.compatible_api_for_client(client_api);
|
||||
match (&upstream_api, client_api) {
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::MessagesResponse(resp))
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to OpenAI ChatCompletions format using the transformer
|
||||
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
|
||||
}
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to Anthropic Messages format using the transformer
|
||||
let messages_resp: MessagesResponse = openai_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
Ok(ProviderResponseType::MessagesResponse(messages_resp))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream response transformation logic for client API compatibility
|
||||
impl TryFrom<(&[u8], &SupportedAPIs, &SupportedAPIs)> for ProviderStreamResponseType {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((bytes, client_api, upstream_api): (&[u8], &SupportedAPIs, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
match (upstream_api, client_api) {
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(resp))
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?;
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(resp))
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let anthropic_resp: crate::apis::anthropic::MessagesStreamEvent = serde_json::from_slice(bytes)?;
|
||||
|
||||
// Transform to OpenAI ChatCompletions stream format using the transformer
|
||||
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = anthropic_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::ChatCompletionsStreamResponse(chat_resp))
|
||||
}
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?;
|
||||
|
||||
// Transform to Anthropic Messages stream format using the transformer
|
||||
let messages_resp: crate::apis::anthropic::MessagesStreamEvent = openai_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::MessagesStreamEvent(messages_resp))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TryFrom implementation to convert raw bytes to SseEvent with parsed provider response
|
||||
impl TryFrom<(SseEvent, &SupportedAPIs, &SupportedAPIs)> for SseEvent {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIs, &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
// Create a new transformed event based on the original
|
||||
let mut transformed_event = sse_event;
|
||||
|
||||
// If not [DONE] and has data, parse the data as a provider stream response (business logic layer)
|
||||
if !transformed_event.is_done() && transformed_event.data.is_some() {
|
||||
let data_str = transformed_event.data.as_ref().unwrap();
|
||||
let data_bytes = data_str.as_bytes();
|
||||
let transformed_response = ProviderStreamResponseType::try_from((data_bytes, client_api, upstream_api))?;
|
||||
let transformed_json = serde_json::to_string(&transformed_response)?;
|
||||
transformed_event.sse_transform_buffer = format!("data: {}\n\n", transformed_json);
|
||||
transformed_event.provider_stream_response = Some(transformed_response);
|
||||
}
|
||||
|
||||
match (client_api, upstream_api) {
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
// No transformation needed
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
// No transformation needed
|
||||
}
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
if let Some(provider_response) = &transformed_event.provider_stream_response {
|
||||
if let Some(event_type) = provider_response.event_type() {
|
||||
// This ensures the required Anthropic sequence: MessageStart → ContentBlockStart → ContentBlockDelta(s)
|
||||
if event_type == "message_start" {
|
||||
let content_block_start_json = serde_json::json!({
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {
|
||||
"type": "text",
|
||||
"text": ""
|
||||
}
|
||||
});
|
||||
// Format as proper SSE: MessageStart first, then ContentBlockStart
|
||||
transformed_event.sse_transform_buffer = format!(
|
||||
"event: {}\n{}\nevent: content_block_start\ndata: {}\n\n",
|
||||
event_type,
|
||||
transformed_event.sse_transform_buffer,
|
||||
content_block_start_json,
|
||||
);
|
||||
} else if event_type == "message_delta" {
|
||||
let content_block_stop_json = serde_json::json!({
|
||||
"type": "content_block_stop",
|
||||
"index": 0
|
||||
});
|
||||
// Format as proper SSE: ContentBlockStop first, then MessageDelta
|
||||
transformed_event.sse_transform_buffer = format!(
|
||||
"event: content_block_stop\ndata: {}\n\nevent: {}\n{}",
|
||||
content_block_stop_json,
|
||||
event_type,
|
||||
transformed_event.sse_transform_buffer
|
||||
);
|
||||
} else {
|
||||
transformed_event.sse_transform_buffer = format!("event: {}\n{}", event_type, transformed_event.sse_transform_buffer);
|
||||
}
|
||||
}
|
||||
// If event_type is None, we just keep the data line as-is without an event line
|
||||
// This handles cases where the transformation might not produce a valid event type
|
||||
}
|
||||
}
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
if transformed_event.is_event_only() && transformed_event.event.is_some() {
|
||||
transformed_event.sse_transform_buffer = format!("\n"); // suppress the event upstream for OpenAI
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(transformed_event)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SseParseError {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl fmt::Display for SseParseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "SSE parse error: {}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for SseParseError {}
|
||||
|
||||
// ============================================================================
|
||||
// GENERIC SSE STREAMING ITERATOR (Container Only)
|
||||
// ============================================================================
|
||||
|
||||
/// Generic SSE (Server-Sent Events) streaming iterator container
|
||||
/// This is just a simple wrapper - actual Iterator implementation is delegated to provider-specific modules
|
||||
/// Parses raw SSE lines into SseEvent objects
|
||||
pub struct SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
|
|
@ -118,35 +392,45 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
// TryFrom implementation to parse bytes into SseStreamIter
|
||||
impl TryFrom<&[u8]> for SseStreamIter<std::vec::IntoIter<String>> {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
impl ProviderResponse for ProviderResponseType {
|
||||
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(),
|
||||
// Future: ProviderResponseType::MessagesResponse(resp) => resp.usage(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
|
||||
// Future: ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
|
||||
}
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
Ok(SseStreamIter::new(lines.into_iter()))
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Send + Sync for the enum to match the original trait requirements
|
||||
unsafe impl Send for ProviderStreamResponseIter {}
|
||||
unsafe impl Sync for ProviderStreamResponseIter {}
|
||||
impl<I> Iterator for SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
type Item = SseEvent;
|
||||
|
||||
/// Trait for token usage information
|
||||
pub trait TokenUsage {
|
||||
fn completion_tokens(&self) -> usize;
|
||||
fn prompt_tokens(&self) -> usize;
|
||||
fn total_tokens(&self) -> usize;
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.lines {
|
||||
let line_str = line.as_ref();
|
||||
|
||||
// Try to parse as either data: or event: line
|
||||
if let Ok(event) = line_str.parse::<SseEvent>() {
|
||||
// For data: lines, check if this is the [DONE] marker - if so, end the stream
|
||||
if event.data.is_some() && event.is_done() {
|
||||
return None;
|
||||
}
|
||||
// For data: lines, skip events that should be filtered at the transport layer
|
||||
if event.data.is_some() && event.should_skip() {
|
||||
continue;
|
||||
}
|
||||
return Some(event);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ProviderResponseError {
|
||||
pub message: String,
|
||||
|
|
@ -165,3 +449,331 @@ impl Error for ProviderResponseError {
|
|||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::providers::id::ProviderId;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_openai_response_from_bytes() {
|
||||
let resp = json!({
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-4",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": { "role": "assistant", "content": "Hello!" },
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 },
|
||||
"system_fingerprint": null
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::ChatCompletionsResponse(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.choices.len(), 1);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsResponse variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_response_from_bytes() {
|
||||
let resp = json!({
|
||||
"id": "msg_01ABC123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{ "type": "text", "text": "Hello! How can I help you today?" }
|
||||
],
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"stop_reason": "end_turn",
|
||||
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Anthropic));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::MessagesResponse(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(r.content.len(), 1);
|
||||
},
|
||||
_ => panic!("Expected MessagesResponse variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_response_from_bytes_with_openai_provider() {
|
||||
// OpenAI provider receives OpenAI response but client expects Anthropic format
|
||||
// Upstream API = OpenAI, Client API = Anthropic -> parse OpenAI, convert to Anthropic
|
||||
let resp = json!({
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-4",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": { "role": "assistant", "content": "Hello! How can I help you today?" },
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 }
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::MessagesResponse(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.usage.input_tokens, 10);
|
||||
assert_eq!(r.usage.output_tokens, 25);
|
||||
},
|
||||
_ => panic!("Expected MessagesResponse variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_response_from_bytes_with_claude_provider() {
|
||||
// Claude provider using OpenAI-compatible API returns OpenAI format response
|
||||
// Client API = OpenAI, Provider = Anthropic -> Anthropic returns OpenAI format via their compatible API
|
||||
let resp = json!({
|
||||
"id": "chatcmpl-01ABC123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 25,
|
||||
"total_tokens": 35
|
||||
}
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Anthropic));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::ChatCompletionsResponse(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(r.usage.prompt_tokens, 10);
|
||||
assert_eq!(r.usage.completion_tokens, 25);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsResponse variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_parsing() {
|
||||
// Test valid SSE data line
|
||||
let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n";
|
||||
let event: Result<SseEvent, _> = line.parse();
|
||||
assert!(event.is_ok());
|
||||
let event = event.unwrap();
|
||||
assert_eq!(event.data, Some("{\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n".to_string()));
|
||||
|
||||
// Test conversion back to line using Display trait
|
||||
let wire_format = event.to_string();
|
||||
assert_eq!(wire_format, "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n");
|
||||
|
||||
// Test [DONE] marker - should be valid SSE event
|
||||
let done_line = "data: [DONE]";
|
||||
let done_result: Result<SseEvent, _> = done_line.parse();
|
||||
assert!(done_result.is_ok());
|
||||
let done_event = done_result.unwrap();
|
||||
assert_eq!(done_event.data, Some("[DONE]".to_string()));
|
||||
assert!(done_event.is_done()); // Test the helper method
|
||||
|
||||
// Test non-DONE event
|
||||
assert!(!event.is_done());
|
||||
|
||||
// Test empty data - should return error
|
||||
let empty_line = "data: ";
|
||||
let empty_result: Result<SseEvent, _> = empty_line.parse();
|
||||
assert!(empty_result.is_err());
|
||||
|
||||
// Test non-data line - should return error
|
||||
let comment_line = ": this is a comment";
|
||||
let comment_result: Result<SseEvent, _> = comment_line.parse();
|
||||
assert!(comment_result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_serde() {
|
||||
// Test serialization and deserialization with serde
|
||||
let event = SseEvent {
|
||||
data: Some(r#"{"id":"test","object":"chat.completion.chunk"}"#.to_string()),
|
||||
event: None,
|
||||
raw_line: r#"data: {"id":"test","object":"chat.completion.chunk"}
|
||||
|
||||
"#.to_string(),
|
||||
sse_transform_buffer: r#"data: {"id":"test","object":"chat.completion.chunk"}
|
||||
|
||||
"#.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
|
||||
// Test JSON serialization - raw_line should be skipped
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains("test"));
|
||||
assert!(json.contains("chat.completion.chunk"));
|
||||
assert!(!json.contains("raw_line")); // Should be excluded from serialization
|
||||
|
||||
// Test JSON deserialization
|
||||
let deserialized: SseEvent = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized.data, event.data);
|
||||
assert_eq!(deserialized.raw_line, ""); // Should be empty since it's skipped
|
||||
|
||||
// Test round trip for data field only
|
||||
assert_eq!(event.data, deserialized.data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_should_skip() {
|
||||
// Test ping message should be skipped
|
||||
let ping_event = SseEvent {
|
||||
data: Some(r#"{"type": "ping"}"#.to_string()),
|
||||
event: None,
|
||||
raw_line: r#"data: {"type": "ping"}"#.to_string(),
|
||||
sse_transform_buffer: r#"data: {"type": "ping"}"#.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
assert!(ping_event.should_skip());
|
||||
assert!(!ping_event.is_done());
|
||||
|
||||
// Test normal event should not be skipped
|
||||
let normal_event = SseEvent {
|
||||
data: Some(r#"{"id": "test", "object": "chat.completion.chunk"}"#.to_string()),
|
||||
event: Some("content_block_delta".to_string()),
|
||||
raw_line: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
|
||||
sse_transform_buffer: r#"data: {"id": "test", "object": "chat.completion.chunk"}"#.to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
assert!(!normal_event.should_skip());
|
||||
assert!(!normal_event.is_done());
|
||||
|
||||
// Test [DONE] event should not be skipped (but is handled separately)
|
||||
let done_event = SseEvent {
|
||||
data: Some("[DONE]".to_string()),
|
||||
event: None,
|
||||
raw_line: "data: [DONE]".to_string(),
|
||||
sse_transform_buffer: "data: [DONE]".to_string(),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
assert!(!done_event.should_skip());
|
||||
assert!(done_event.is_done());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_stream_iter_filters_ping_messages() {
|
||||
// Create test data with ping messages mixed in
|
||||
let test_lines = vec![
|
||||
"data: {\"id\": \"msg1\", \"object\": \"chat.completion.chunk\"}".to_string(),
|
||||
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
|
||||
"data: {\"id\": \"msg2\", \"object\": \"chat.completion.chunk\"}".to_string(),
|
||||
"data: {\"type\": \"ping\"}".to_string(), // This should be filtered out
|
||||
"data: [DONE]".to_string(), // This should end the stream
|
||||
];
|
||||
|
||||
let mut iter = SseStreamIter::new(test_lines.into_iter());
|
||||
|
||||
// First event should be msg1 (ping filtered out)
|
||||
let event1 = iter.next().unwrap();
|
||||
assert!(event1.data.as_ref().unwrap().contains("msg1"));
|
||||
assert!(!event1.should_skip());
|
||||
|
||||
// Second event should be msg2 (ping filtered out)
|
||||
let event2 = iter.next().unwrap();
|
||||
assert!(event2.data.as_ref().unwrap().contains("msg2"));
|
||||
assert!(!event2.should_skip());
|
||||
|
||||
// Iterator should end at [DONE] (no more events)
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_stream_iter_handles_anthropic_events() {
|
||||
// Create test data with Anthropic-style event/data pairs
|
||||
let test_lines = vec![
|
||||
"event: message_start".to_string(),
|
||||
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\"}}".to_string(),
|
||||
"event: content_block_delta".to_string(),
|
||||
"data: {\"type\":\"content_block_delta\",\"delta\":{\"text\":\"Hello\"}}".to_string(),
|
||||
"data: [DONE]".to_string(),
|
||||
];
|
||||
|
||||
let mut iter = SseStreamIter::new(test_lines.into_iter());
|
||||
|
||||
// First event should be the event: line
|
||||
let event1 = iter.next().unwrap();
|
||||
assert!(event1.is_event_only());
|
||||
assert_eq!(event1.event, Some("message_start".to_string()));
|
||||
assert_eq!(event1.data, None);
|
||||
|
||||
// Second event should be the data: line
|
||||
let event2 = iter.next().unwrap();
|
||||
assert!(!event2.is_event_only());
|
||||
assert_eq!(event2.event, None);
|
||||
assert!(event2.data.as_ref().unwrap().contains("message_start"));
|
||||
|
||||
// Third event should be another event: line
|
||||
let event3 = iter.next().unwrap();
|
||||
assert!(event3.is_event_only());
|
||||
assert_eq!(event3.event, Some("content_block_delta".to_string()));
|
||||
|
||||
// Fourth event should be the content delta data
|
||||
let event4 = iter.next().unwrap();
|
||||
assert!(!event4.is_event_only());
|
||||
assert!(event4.data.as_ref().unwrap().contains("Hello"));
|
||||
|
||||
// Iterator should end at [DONE]
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_stream_response_event_type() {
|
||||
use crate::apis::anthropic::{MessagesStreamEvent, MessagesContentDelta};
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
|
||||
// Test Anthropic event type
|
||||
let anthropic_event = MessagesStreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: MessagesContentDelta::TextDelta { text: "Hello".to_string() },
|
||||
};
|
||||
let provider_type = ProviderStreamResponseType::MessagesStreamEvent(anthropic_event);
|
||||
assert_eq!(provider_type.event_type(), Some("content_block_delta"));
|
||||
|
||||
// Test OpenAI event type (should be None)
|
||||
let openai_event = ChatCompletionsStreamResponse {
|
||||
id: "test".to_string(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created: 123456789,
|
||||
model: "gpt-4".to_string(),
|
||||
choices: vec![],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
};
|
||||
let provider_type = ProviderStreamResponseType::ChatCompletionsStreamResponse(openai_event);
|
||||
assert_eq!(provider_type.event_type(), None);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,7 +89,6 @@ impl RootContext for FilterContext {
|
|||
);
|
||||
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(
|
||||
self.llm_providers
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -31,14 +31,15 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
|||
)
|
||||
.returning(None)
|
||||
.expect_log(
|
||||
Some(LogLevel::Debug),
|
||||
Some("request received: llm provider hint: default, selected provider: open-ai-gpt-4"),
|
||||
Some(LogLevel::Info),
|
||||
None, // Dynamic request ID - could be context_id or x-request-id
|
||||
)
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-llm-provider"),
|
||||
Some("openai"),
|
||||
)
|
||||
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-api-key"))
|
||||
.expect_replace_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("Authorization"),
|
||||
|
|
@ -193,10 +194,7 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
|
|||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(
|
||||
Some(LogLevel::Trace),
|
||||
Some("||| create_http_context called with context_id: 2 |||"),
|
||||
)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -211,15 +209,19 @@ fn llm_gateway_successful_request_to_open_ai_chat_completions() {
|
|||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID - REQUEST_BODY_CHUNK
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - CLIENT_REQUEST_RECEIVED
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID - CLIENT_REQUEST_PAYLOAD
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - MODEL_RESOLUTION
|
||||
.expect_log(Some(LogLevel::Debug), Some("TOKENIZER: computing token count for model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - TOKEN_COUNT
|
||||
.expect_metric_record("input_sequence_length", 21)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - RATELIMIT_CHECK
|
||||
.expect_log(Some(LogLevel::Debug), Some("Checking limit for provider=gpt-4, with selector=Header { key: \"selector-key\", value: \"selector-value\" }, consuming tokens=21"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - UPSTREAM_TRANSFORM
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID - UPSTREAM_REQUEST_PAYLOAD
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -263,15 +265,19 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
|
|||
incomplete_chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID - REQUEST_BODY_CHUNK
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(incomplete_chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 13"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - CLIENT_REQUEST_RECEIVED
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID - CLIENT_REQUEST_PAYLOAD
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - MODEL_RESOLUTION
|
||||
.expect_log(Some(LogLevel::Debug), Some("TOKENIZER: computing token count for model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - TOKEN_COUNT
|
||||
.expect_metric_record("input_sequence_length", 13)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=13"#))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - RATELIMIT_CHECK
|
||||
.expect_log(Some(LogLevel::Debug), Some("Checking limit for provider=gpt-4, with selector=Header { key: \"selector-key\", value: \"selector-value\" }, consuming tokens=13"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID - RATELIMIT_CHECK
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] UPSTREAM_REQUEST_PAYLOAD: {\"messages\":[{\"role\":\"system\",\"content\":\"Compose a poem that explains the concept of recursion in programming.\"}],\"model\":\"gpt-4\"}"))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -322,16 +328,18 @@ fn llm_gateway_request_ratelimited() {
|
|||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] CLIENT_REQUEST_PAYLOAD: {\"messages\": [{\"role\": \"system\",\"content\": \"You are a helpful poetic assistant!, skilled in explaining complex programming concepts with creative flair. Be sure to be concise and to the point.\"},{\"role\": \"user\",\"content\": \"Compose a poem that explains the concept of recursion in programming. Compose a poem that explains the concept of recursion in programming. Compose a poem that explains the concept of recursion in programming. And also summarize it how a 4th graded would understand it. Compose a poem that explains the concept of recursion in programming. And also summarize it how a 4th graded would understand it.\"}],\"model\": \"gpt-4\"}"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("TOKENIZER: computing token count for model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Info), None)// Dynamic request ID)
|
||||
.expect_metric_record("input_sequence_length", 107)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Checking limit for provider=gpt-4, with selector=Header { key: \"selector-key\", value: \"selector-value\" }, consuming tokens=107"))
|
||||
.expect_log(Some(LogLevel::Warn), Some(r#"server error occurred: exceeded limit provider=gpt-4, selector=Header { key: "selector-key", value: "selector-value" }, tokens_used=107"#))
|
||||
.expect_send_local_response(
|
||||
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
|
||||
|
|
@ -376,16 +384,21 @@ fn llm_gateway_request_not_ratelimited() {
|
|||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||
// Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] CLIENT_REQUEST_PAYLOAD: {\"model\":\"gpt-1\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},{\"role\":\"user\",\"content\":\"Compose a poem that explains the concept of recursion in programming.\"}]}"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("TOKENIZER: computing token count for model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Checking limit for provider=gpt-4, with selector=Header { key: \"selector-key\", value: \"selector-value\" }, consuming tokens=29"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] UPSTREAM_REQUEST_PAYLOAD: {\"messages\":[{\"role\":\"system\",\"content\":\"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},{\"role\":\"user\",\"content\":\"Compose a poem that explains the concept of recursion in programming.\"}],\"model\":\"gpt-4\"}"))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -423,16 +436,20 @@ fn llm_gateway_override_model_name() {
|
|||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] CLIENT_REQUEST_PAYLOAD: {\"model\":\"gpt-1\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},{\"role\":\"user\",\"content\":\"Compose a poem that explains the concept of recursion in programming.\"}]}"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("TOKENIZER: computing token count for model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Checking limit for provider=gpt-4, with selector=Header { key: \"selector-key\", value: \"selector-value\" }, consuming tokens=29"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] UPSTREAM_REQUEST_PAYLOAD: {\"messages\":[{\"role\":\"system\",\"content\":\"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},{\"role\":\"user\",\"content\":\"Compose a poem that explains the concept of recursion in programming.\"}],\"model\":\"gpt-4\"}"))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -470,19 +487,23 @@ fn llm_gateway_override_use_default_model() {
|
|||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] CLIENT_REQUEST_PAYLOAD: {\"model\":\"gpt-1\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},{\"role\":\"user\",\"content\":\"Compose a poem that explains the concept of recursion in programming.\"}]}"))
|
||||
.expect_log(
|
||||
Some(LogLevel::Info),
|
||||
Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"),
|
||||
None // Dynamic request ID,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("TOKENIZER: computing token count for model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Checking limit for provider=gpt-4, with selector=Header { key: \"selector-key\", value: \"selector-value\" }, consuming tokens=29"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] UPSTREAM_REQUEST_PAYLOAD: {\"messages\":[{\"role\":\"system\",\"content\":\"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},{\"role\":\"user\",\"content\":\"Compose a poem that explains the concept of recursion in programming.\"}],\"model\":\"gpt-4\"}"))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
@ -520,16 +541,21 @@ fn llm_gateway_override_use_model_name_none() {
|
|||
chat_completions_request_body.len() as i32,
|
||||
true,
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), None) // Dynamic request ID)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): none, model selected: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
// Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] CLIENT_REQUEST_PAYLOAD: {\"model\":\"none\",\"messages\":[{\"role\":\"system\",\"content\":\"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},{\"role\":\"user\",\"content\":\"Compose a poem that explains the concept of recursion in programming.\"}]}"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("TOKENIZER: computing token count for model=gpt-4"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("Checking limit for provider=gpt-4, with selector=Header { key: \"selector-key\", value: \"selector-value\" }, consuming tokens=29"))
|
||||
.expect_log(Some(LogLevel::Info), None) // Dynamic request ID)
|
||||
.expect_log(Some(LogLevel::Debug), Some("[ARCHGW_REQ_ID:NO_REQUEST_ID] UPSTREAM_REQUEST_PAYLOAD: {\"messages\":[{\"role\":\"system\",\"content\":\"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"},{\"role\":\"user\",\"content\":\"Compose a poem that explains the concept of recursion in programming.\"}],\"model\":\"gpt-4\"}"))
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.unwrap();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue