updating the implementation of /v1/chat/completions to use the generi… (#548)

* updating the implementation of /v1/chat/completions to use the generic provider interfaces

* saving changes, although we will need a small re-factor after this as well

* more refactoring changes, getting close

* more refactoring changes to avoid unecessary re-direction and duplication

* more clean up

* more refactoring

* more refactoring to clean code and make stream_context.rs work

* removing unecessary trait implemenations

* some more clean-up

* fixed bugs

* fixing test cases, and making sure all references to the ChatCOmpletions* objects point to the new types

* refactored changes to support enum dispatch

* removed the dependency on try_streaming_from_bytes into a try_from trait implementation

* updated readme based on new usage

* updated code based on code review comments

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-2.local>
Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local>
This commit is contained in:
Salman Paracha 2025-08-20 12:55:29 -07:00 committed by GitHub
parent 1fdde8181a
commit 89ab51697a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 1044 additions and 972 deletions

View file

@ -2,7 +2,13 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use std::fmt::Display;
use thiserror::Error;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter};
use super::ApiDefinition;
// ============================================================================
@ -115,8 +121,8 @@ pub enum Role {
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
pub content: MessageContent,
pub role: Role,
pub content: MessageContent,
pub name: Option<String>,
/// Tool calls made by the assistant (only present for assistant role)
pub tool_calls: Option<Vec<ToolCall>>,
@ -124,8 +130,6 @@ pub struct Message {
pub tool_call_id: Option<String>,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ResponseMessage {
@ -170,6 +174,28 @@ pub enum MessageContent {
Parts(Vec<ContentPart>),
}
impl Display for MessageContent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MessageContent::Text(text) => write!(f, "{}", text),
MessageContent::Parts(parts) => {
let text_parts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
ContentPart::ImageUrl { .. } => {
// skip image URLs or their data in text representation
None
}
})
.collect();
let combined_text = text_parts.join("\n");
write!(f, "{}", combined_text)
}
}
}
}
/// Individual content part within a message (text or image)
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
@ -424,6 +450,239 @@ pub struct StreamOptions {
pub include_usage: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDetail {
pub id: String,
pub object: String,
pub created: usize,
pub owned_by: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelObject {
#[serde(rename = "list")]
List,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Models {
pub object: ModelObject,
pub data: Vec<ModelDetail>,
}
// Error type for streaming operations
#[derive(Debug, thiserror::Error)]
pub enum OpenAIStreamError {
#[error("JSON parsing error: {0}")]
JsonError(#[from] serde_json::Error),
#[error("UTF-8 parsing error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
#[error("Invalid streaming data: {0}")]
InvalidStreamingData(String),
}
#[derive(Debug, Error)]
pub enum OpenAIError {
#[error("json error: {0}")]
JsonParseError(#[from] serde_json::Error),
#[error("utf8 parsing error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
#[error("invalid streaming data err {source}, data: {data}")]
InvalidStreamingData {
source: serde_json::Error,
data: String,
},
#[error("unsupported provider: {provider}")]
UnsupportedProvider { provider: String },
}
// ============================================================================
/// Trait Implementations
/// ===========================================================================
/// Parameterized conversion for ChatCompletionsRequest
impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIStreamError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
}
}
/// Parameterized conversion for ChatCompletionsResponse
impl TryFrom<&[u8]> for ChatCompletionsResponse {
type Error = OpenAIStreamError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
}
}
/// Implementation of TokenUsage for OpenAI Usage type
impl TokenUsage for Usage {
fn completion_tokens(&self) -> usize {
self.completion_tokens as usize
}
fn prompt_tokens(&self) -> usize {
self.prompt_tokens as usize
}
fn total_tokens(&self) -> usize {
self.total_tokens as usize
}
}
/// Implementation of ProviderRequest for ChatCompletionsRequest
impl ProviderRequest for ChatCompletionsRequest {
fn model(&self) -> &str {
&self.model
}
fn set_model(&mut self, model: String) {
self.model = model;
}
fn is_streaming(&self) -> bool {
self.stream.unwrap_or_default()
}
fn extract_messages_text(&self) -> String {
self.messages.iter().fold(String::new(), |acc, m| {
acc + " " + &match &m.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
ContentPart::Text { text } => text.clone(),
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
}).collect::<Vec<_>>().join(" ")
}
})
}
fn get_recent_user_message(&self) -> Option<String> {
self.messages.last().and_then(|msg| {
match &msg.content {
MessageContent::Text(text) => Some(text.clone()),
MessageContent::Parts(_) => None, // No user message in parts
}
})
}
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
message: format!("Failed to serialize OpenAI request: {}", e),
source: Some(Box::new(e)),
})
}
}
/// Implementation of ProviderResponse for ChatCompletionsResponse
impl ProviderResponse for ChatCompletionsResponse {
fn usage(&self) -> Option<&dyn TokenUsage> {
Some(&self.usage)
}
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
Some((
self.usage.prompt_tokens(),
self.usage.completion_tokens(),
self.usage.total_tokens(),
))
}
}
// ============================================================================
// OPENAI SSE STREAMING ITERATOR
// ============================================================================
/// OpenAI-specific SSE streaming iterator
/// Handles OpenAI's specific SSE format and ChatCompletionsStreamResponse parsing
pub struct OpenAISseIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
sse_stream: SseStreamIter<I>,
}
impl<I> OpenAISseIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
pub fn new(sse_stream: SseStreamIter<I>) -> Self {
Self { sse_stream }
}
}
impl<I> Iterator for OpenAISseIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
fn next(&mut self) -> Option<Self::Item> {
for line in &mut self.sse_stream.lines {
let line = line.as_ref();
if line.is_empty() {
continue;
}
if line.starts_with("data: ") {
let data = &line[6..]; // Remove "data: " prefix
if data == "[DONE]" {
return None;
}
// Skip ping messages (usually from other providers, but handle gracefully)
if data == r#"{"type": "ping"}"# {
continue;
}
// OpenAI-specific parsing of ChatCompletionsStreamResponse
match serde_json::from_str::<ChatCompletionsStreamResponse>(data) {
Ok(response) => return Some(Ok(Box::new(response))),
Err(e) => return Some(Err(Box::new(
OpenAIStreamError::InvalidStreamingData(format!("Error parsing OpenAI streaming data: {}, data: {}", e, data))
))),
}
}
}
None
}
}
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
fn content_delta(&self) -> Option<&str> {
self.choices
.first()
.and_then(|choice| choice.delta.content.as_deref())
}
fn is_final(&self) -> bool {
self.choices
.first()
.map(|choice| choice.finish_reason.is_some())
.unwrap_or(false)
}
fn role(&self) -> Option<&str> {
self.choices
.first()
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
}))
}
}
#[cfg(test)]
mod tests {
use super::*;