mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
updated tests and added necessary response transformations for Anthropics' message response object
This commit is contained in:
parent
e73a9eb61c
commit
9f3a6f71a3
5 changed files with 509 additions and 55 deletions
|
|
@ -1,10 +1,14 @@
|
|||
use crate::providers::response::TokenUsage;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderStreamResponse, SseStreamIter};
|
||||
use crate::clients::transformer::ExtractText;
|
||||
|
||||
// Enum for all supported Anthropic APIs
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
|
@ -187,6 +191,19 @@ pub enum MessagesContentBlock {
|
|||
},
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<MessagesContentBlock> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
.filter_map(|block| match block {
|
||||
MessagesContentBlock::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessagesImageSource {
|
||||
|
|
@ -221,6 +238,15 @@ pub enum MessagesMessageContent {
|
|||
Blocks(Vec<MessagesContentBlock>),
|
||||
}
|
||||
|
||||
impl ExtractText for MessagesMessageContent {
|
||||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessagesMessageContent::Single(text) => text.clone(),
|
||||
MessagesMessageContent::Blocks(parts) => parts.extract_text()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessagesSystemPrompt {
|
||||
|
|
@ -378,6 +404,27 @@ impl TryFrom<&[u8]> for MessagesRequest {
|
|||
}
|
||||
}
|
||||
|
||||
impl TokenUsage for MessagesResponse {
|
||||
fn completion_tokens(&self) -> usize {
|
||||
self.usage.output_tokens as usize
|
||||
}
|
||||
fn prompt_tokens(&self) -> usize {
|
||||
self.usage.input_tokens as usize
|
||||
}
|
||||
fn total_tokens(&self) -> usize {
|
||||
(self.usage.input_tokens + self.usage.output_tokens) as usize
|
||||
}
|
||||
}
|
||||
|
||||
impl MessagesResponse {
|
||||
pub fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||
Some(self)
|
||||
}
|
||||
pub fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
Some((self.usage.input_tokens as usize, self.usage.output_tokens as usize, (self.usage.input_tokens + self.usage.output_tokens) as usize))
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderRequest for MessagesRequest {
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
|
|
@ -464,6 +511,91 @@ impl MessagesStreamEvent {
|
|||
}
|
||||
}
|
||||
|
||||
impl MessagesRole {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
MessagesRole::User => "user",
|
||||
MessagesRole::Assistant => "assistant",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Anthropic SSE streaming iterator for MessagesStreamEvent
|
||||
pub struct AnthropicSseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
sse_stream: SseStreamIter<I>,
|
||||
}
|
||||
|
||||
impl<I> AnthropicSseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(sse_stream: SseStreamIter<I>) -> Self {
|
||||
Self { sse_stream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for AnthropicSseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn Error + Send + Sync>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.sse_stream.lines {
|
||||
let line = line.as_ref();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line[6..];
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
// Anthropic-specific parsing of MessagesStreamEvent
|
||||
match serde_json::from_str::<MessagesStreamEvent>(data) {
|
||||
Ok(event) => return Some(Ok(Box::new(event))),
|
||||
Err(e) => return Some(Err(Box::new(e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// Implement ProviderStreamResponse for MessagesStreamEvent
|
||||
impl ProviderStreamResponse for MessagesStreamEvent {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
|
||||
if let MessagesContentDelta::TextDelta { text } = delta {
|
||||
Some(text)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
matches!(self, MessagesStreamEvent::MessageStop)
|
||||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessagesStreamEvent::MessageStart { message } => Some(message.role.as_str()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -5,11 +5,10 @@ use std::collections::HashMap;
|
|||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
|
||||
|
||||
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter};
|
||||
use super::ApiDefinition;
|
||||
use crate::clients::transformer::{ExtractText};
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI API ENUMERATION
|
||||
|
|
@ -174,6 +173,28 @@ pub enum MessageContent {
|
|||
Parts(Vec<ContentPart>),
|
||||
}
|
||||
|
||||
// Content Extraction
|
||||
impl ExtractText for MessageContent {
|
||||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.extract_text()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<ContentPart> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for MessageContent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
|
|
|||
|
|
@ -44,8 +44,6 @@
|
|||
|
||||
use serde_json::Value;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
// Import centralized types
|
||||
use crate::apis::*;
|
||||
use super::TransformError;
|
||||
|
||||
|
|
@ -61,7 +59,7 @@ const DEFAULT_MAX_TOKENS: u32 = 4096;
|
|||
// ============================================================================
|
||||
|
||||
/// Trait for extracting text content from various types
|
||||
trait ExtractText {
|
||||
pub trait ExtractText {
|
||||
fn extract_text(&self) -> String;
|
||||
}
|
||||
|
||||
|
|
@ -541,40 +539,6 @@ impl Into<Role> for MessagesRole {
|
|||
}
|
||||
}
|
||||
|
||||
// Content Extraction
|
||||
impl ExtractText for MessageContent {
|
||||
fn extract_text(&self) -> String {
|
||||
match self {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.extract_text()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<ContentPart> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtractText for Vec<MessagesContentBlock> {
|
||||
fn extract_text(&self) -> String {
|
||||
self.iter()
|
||||
.filter_map(|block| match block {
|
||||
MessagesContentBlock::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Content Utilities
|
||||
impl ContentUtils<ToolCall> for Vec<MessagesContentBlock> {
|
||||
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError> {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::clients::endpoints::SupportedApi;
|
||||
|
|
@ -125,3 +124,195 @@ impl Error for ProviderRequestError {
|
|||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ...existing code...
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::endpoints::SupportedApi;
|
||||
use crate::apis::anthropic::AnthropicApi::Messages;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest};
|
||||
use crate::clients::transformer::ExtractText;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_openai_request_from_bytes() {
|
||||
let req = json!({
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let result = ProviderRequestType::try_from(bytes.as_slice());
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.messages.len(), 2);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_request_from_bytes_with_endpoint() {
|
||||
let req = json!({
|
||||
"model": "claude-3-sonnet",
|
||||
"system": "You are a helpful assistant",
|
||||
"max_tokens": 100,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let endpoint = SupportedApi::Anthropic(Messages);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderRequestType::MessagesRequest(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet");
|
||||
assert_eq!(r.messages.len(), 1);
|
||||
},
|
||||
_ => panic!("Expected MessagesRequest variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_request_from_bytes_with_endpoint() {
|
||||
let req = json!({
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let endpoint = SupportedApi::OpenAI(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.messages.len(), 2);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_request_from_bytes_wrong_endpoint() {
|
||||
let req = json!({
|
||||
"model": "claude-3-sonnet",
|
||||
"system": "You are a helpful assistant",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
// Intentionally use OpenAI endpoint for Anthropic payload
|
||||
let endpoint = SupportedApi::OpenAI(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
// Should parse as ChatCompletionsRequest, not error
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderRequestType::ChatCompletionsRequest(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet");
|
||||
assert_eq!(r.messages.len(), 1);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsRequest variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_messages_to_v1_chat_completions_roundtrip() {
|
||||
let anthropic_req = AnthropicMessagesRequest {
|
||||
model: "claude-3-sonnet".to_string(),
|
||||
system: Some(crate::apis::anthropic::MessagesSystemPrompt::Single("You are a helpful assistant".to_string())),
|
||||
messages: vec![
|
||||
crate::apis::anthropic::MessagesMessage {
|
||||
role: crate::apis::anthropic::MessagesRole::User,
|
||||
content: crate::apis::anthropic::MessagesMessageContent::Single("Hello!".to_string()),
|
||||
}
|
||||
],
|
||||
max_tokens: 128,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
top_k: None,
|
||||
stream: Some(false),
|
||||
stop_sequences: Some(vec!["\n".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let openai_req = ChatCompletionsRequest::try_from(anthropic_req.clone()).expect("Anthropic->OpenAI conversion failed");
|
||||
let anthropic_req2 = AnthropicMessagesRequest::try_from(openai_req).expect("OpenAI->Anthropic conversion failed");
|
||||
|
||||
assert_eq!(anthropic_req.model, anthropic_req2.model);
|
||||
// Compare system prompt text if present
|
||||
assert_eq!(
|
||||
anthropic_req.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None }),
|
||||
anthropic_req2.system.as_ref().and_then(|s| match s { crate::apis::anthropic::MessagesSystemPrompt::Single(t) => Some(t), _ => None })
|
||||
);
|
||||
assert_eq!(anthropic_req.messages[0].role, anthropic_req2.messages[0].role);
|
||||
// Compare message content text if present
|
||||
assert_eq!(
|
||||
anthropic_req.messages[0].content.extract_text(),
|
||||
anthropic_req2.messages[0].content.extract_text()
|
||||
);
|
||||
assert_eq!(anthropic_req.max_tokens, anthropic_req2.max_tokens);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_v1_chat_completions_to_v1_messages_roundtrip() {
|
||||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
|
||||
|
||||
let openai_req = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are a helpful assistant".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
max_tokens: Some(128),
|
||||
stream: Some(false),
|
||||
stop: Some(vec!["\n".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let anthropic_req = AnthropicMessagesRequest::try_from(openai_req.clone()).expect("OpenAI->Anthropic conversion failed");
|
||||
let openai_req2 = ChatCompletionsRequest::try_from(anthropic_req).expect("Anthropic->OpenAI conversion failed");
|
||||
|
||||
assert_eq!(openai_req.model, openai_req2.model);
|
||||
assert_eq!(openai_req.messages[0].role, openai_req2.messages[0].role);
|
||||
assert_eq!(openai_req.messages[0].content.extract_text(), openai_req2.messages[0].content.extract_text());
|
||||
assert_eq!(openai_req.max_tokens, openai_req2.max_tokens);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,15 +9,19 @@ use crate::apis::OpenAISseIter;
|
|||
use crate::clients::endpoints::SupportedApi;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use crate::apis::anthropic::MessagesResponse;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub enum ProviderResponseType {
|
||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
//MessagesResponse(MessagesResponse),
|
||||
MessagesResponse(MessagesResponse),
|
||||
}
|
||||
|
||||
use crate::apis::anthropic::AnthropicSseIter;
|
||||
|
||||
pub enum ProviderStreamResponseIter {
|
||||
ChatCompletionsStream(OpenAISseIter<std::vec::IntoIter<String>>),
|
||||
//MessagesStream(AnthropicSseIter<std::vec::IntoIter<String>>),
|
||||
MessagesStream(AnthropicSseIter<std::vec::IntoIter<String>>),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -33,11 +37,21 @@ impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderResponseType {
|
|||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
|
||||
}
|
||||
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
|
||||
// If you add a MessagesResponse variant, return it here. For now, just error or serialize as needed.
|
||||
Err(std::io::Error::new(std::io::ErrorKind::Other, "Anthropic response variant not implemented"))
|
||||
(SupportedApi::Anthropic(_), SupportedApi::Anthropic(_)) => {
|
||||
let resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::MessagesResponse(resp))
|
||||
}
|
||||
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
|
||||
let resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::MessagesResponse(resp))
|
||||
}
|
||||
(SupportedApi::Anthropic(_), SupportedApi::OpenAI(_)) => {
|
||||
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
|
||||
}
|
||||
_ => Err(std::io::Error::new(std::io::ErrorKind::Other, "Unsupported response transformation")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -55,11 +69,27 @@ impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderStreamResponseIter
|
|||
let iter = crate::apis::openai::OpenAISseIter::new(sse_container);
|
||||
Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter))
|
||||
}
|
||||
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
|
||||
// TODO: Implement streaming transformation from OpenAI to Anthropic
|
||||
Err("Anthropic streaming response variant not implemented".into())
|
||||
(SupportedApi::Anthropic(_), SupportedApi::Anthropic(_)) => {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter());
|
||||
let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container);
|
||||
Ok(ProviderStreamResponseIter::MessagesStream(iter))
|
||||
}
|
||||
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter());
|
||||
let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container);
|
||||
Ok(ProviderStreamResponseIter::MessagesStream(iter))
|
||||
}
|
||||
(SupportedApi::Anthropic(_), SupportedApi::OpenAI(_)) => {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter());
|
||||
let iter = crate::apis::openai::OpenAISseIter::new(sse_container);
|
||||
Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter))
|
||||
}
|
||||
_ => Err("Unsupported streaming response transformation".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -70,12 +100,10 @@ impl Iterator for ProviderStreamResponseIter {
|
|||
fn next(&mut self) -> Option<Self::Item> {
|
||||
match self {
|
||||
ProviderStreamResponseIter::ChatCompletionsStream(iter) => iter.next(),
|
||||
// Future: ProviderStreamResponseIter::MessagesStream(iter) => iter.next(),
|
||||
ProviderStreamResponseIter::MessagesStream(iter) => iter.next(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
/// Get usage information if available - returns dynamic trait object
|
||||
fn usage(&self) -> Option<&dyn TokenUsage>;
|
||||
|
|
@ -128,14 +156,14 @@ impl ProviderResponse for ProviderResponseType {
|
|||
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(),
|
||||
// Future: ProviderResponseType::MessagesResponse(resp) => resp.usage(),
|
||||
ProviderResponseType::MessagesResponse(resp) => resp.usage(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
|
||||
// Future: ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
|
||||
ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -170,3 +198,121 @@ impl Error for ProviderResponseError {
|
|||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::endpoints::SupportedApi;
|
||||
use crate::providers::id::ProviderId;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_openai_response_from_bytes() {
|
||||
let resp = json!({
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-4",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": { "role": "assistant", "content": "Hello!" },
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 },
|
||||
"system_fingerprint": null
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedApi::OpenAI(OpenAIApi::ChatCompletions), &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::ChatCompletionsResponse(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.choices.len(), 1);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsResponse variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_response_from_bytes() {
|
||||
let resp = json!({
|
||||
"id": "msg_01ABC123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{ "type": "text", "text": "Hello! How can I help you today?" }
|
||||
],
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"stop_reason": "end_turn",
|
||||
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedApi::Anthropic(AnthropicApi::Messages), &ProviderId::Claude));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::MessagesResponse(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(r.content.len(), 1);
|
||||
},
|
||||
_ => panic!("Expected MessagesResponse variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_response_from_bytes_with_openai_provider() {
|
||||
// Simulate Anthropic response with OpenAI provider (should parse as MessagesResponse)
|
||||
let resp = json!({
|
||||
"id": "msg_01ABC123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{ "type": "text", "text": "Hello! How can I help you today?" }
|
||||
],
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"stop_reason": "end_turn",
|
||||
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedApi::Anthropic(AnthropicApi::Messages), &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::MessagesResponse(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
},
|
||||
_ => panic!("Expected MessagesResponse variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_response_from_bytes_with_claude_provider() {
|
||||
// Simulate OpenAI response with Claude provider (should parse as ChatCompletionsResponse)
|
||||
let resp = json!({
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-4",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": { "role": "assistant", "content": "Hello!" },
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 },
|
||||
"system_fingerprint": null
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedApi::OpenAI(OpenAIApi::ChatCompletions), &ProviderId::Claude));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::ChatCompletionsResponse(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsResponse variant"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue