mirror of
https://github.com/katanemo/plano.git
synced 2026-05-08 23:32:43 +02:00
updating the implementation of /v1/chat/completions to use the generi… (#548)
* updating the implementation of /v1/chat/completions to use the generic provider interfaces * saving changes, although we will need a small re-factor after this as well * more refactoring changes, getting close * more refactoring changes to avoid unecessary re-direction and duplication * more clean up * more refactoring * more refactoring to clean code and make stream_context.rs work * removing unecessary trait implemenations * some more clean-up * fixed bugs * fixing test cases, and making sure all references to the ChatCOmpletions* objects point to the new types * refactored changes to support enum dispatch * removed the dependency on try_streaming_from_bytes into a try_from trait implementation * updated readme based on new usage * updated code based on code review comments --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-2.local> Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-4.local>
This commit is contained in:
parent
1fdde8181a
commit
89ab51697a
22 changed files with 1044 additions and 972 deletions
|
|
@ -2,7 +2,13 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
|
||||
|
||||
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter};
|
||||
use super::ApiDefinition;
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -115,8 +121,8 @@ pub enum Role {
|
|||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Message {
|
||||
pub content: MessageContent,
|
||||
pub role: Role,
|
||||
pub content: MessageContent,
|
||||
pub name: Option<String>,
|
||||
/// Tool calls made by the assistant (only present for assistant role)
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
|
@ -124,8 +130,6 @@ pub struct Message {
|
|||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ResponseMessage {
|
||||
|
|
@ -170,6 +174,28 @@ pub enum MessageContent {
|
|||
Parts(Vec<ContentPart>),
|
||||
}
|
||||
|
||||
impl Display for MessageContent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
MessageContent::Text(text) => write!(f, "{}", text),
|
||||
MessageContent::Parts(parts) => {
|
||||
let text_parts: Vec<String> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.clone()),
|
||||
ContentPart::ImageUrl { .. } => {
|
||||
// skip image URLs or their data in text representation
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let combined_text = text_parts.join("\n");
|
||||
write!(f, "{}", combined_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Individual content part within a message (text or image)
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
|
|
@ -424,6 +450,239 @@ pub struct StreamOptions {
|
|||
pub include_usage: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelDetail {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: usize,
|
||||
pub owned_by: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ModelObject {
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Models {
|
||||
pub object: ModelObject,
|
||||
pub data: Vec<ModelDetail>,
|
||||
}
|
||||
|
||||
|
||||
// Error type for streaming operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OpenAIStreamError {
|
||||
#[error("JSON parsing error: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
#[error("UTF-8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("Invalid streaming data: {0}")]
|
||||
InvalidStreamingData(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
#[error("json error: {0}")]
|
||||
JsonParseError(#[from] serde_json::Error),
|
||||
#[error("utf8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
#[error("invalid streaming data err {source}, data: {data}")]
|
||||
InvalidStreamingData {
|
||||
source: serde_json::Error,
|
||||
data: String,
|
||||
},
|
||||
#[error("unsupported provider: {provider}")]
|
||||
UnsupportedProvider { provider: String },
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
/// Trait Implementations
|
||||
/// ===========================================================================
|
||||
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsRequest
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsResponse
|
||||
impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
||||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of TokenUsage for OpenAI Usage type
|
||||
impl TokenUsage for Usage {
|
||||
fn completion_tokens(&self) -> usize {
|
||||
self.completion_tokens as usize
|
||||
}
|
||||
|
||||
fn prompt_tokens(&self) -> usize {
|
||||
self.prompt_tokens as usize
|
||||
}
|
||||
|
||||
fn total_tokens(&self) -> usize {
|
||||
self.total_tokens as usize
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of ProviderRequest for ChatCompletionsRequest
|
||||
impl ProviderRequest for ChatCompletionsRequest {
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
fn set_model(&mut self, model: String) {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
fn is_streaming(&self) -> bool {
|
||||
self.stream.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn extract_messages_text(&self) -> String {
|
||||
self.messages.iter().fold(String::new(), |acc, m| {
|
||||
acc + " " + &match &m.content {
|
||||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Parts(parts) => parts.iter().map(|part| match part {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
|
||||
}).collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn get_recent_user_message(&self) -> Option<String> {
|
||||
self.messages.last().and_then(|msg| {
|
||||
match &msg.content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Parts(_) => None, // No user message in parts
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to serialize OpenAI request: {}", e),
|
||||
source: Some(Box::new(e)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of ProviderResponse for ChatCompletionsResponse
|
||||
impl ProviderResponse for ChatCompletionsResponse {
|
||||
fn usage(&self) -> Option<&dyn TokenUsage> {
|
||||
Some(&self.usage)
|
||||
}
|
||||
|
||||
fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> {
|
||||
Some((
|
||||
self.usage.prompt_tokens(),
|
||||
self.usage.completion_tokens(),
|
||||
self.usage.total_tokens(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI SSE STREAMING ITERATOR
|
||||
// ============================================================================
|
||||
|
||||
/// OpenAI-specific SSE streaming iterator
|
||||
/// Handles OpenAI's specific SSE format and ChatCompletionsStreamResponse parsing
|
||||
pub struct OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
sse_stream: SseStreamIter<I>,
|
||||
}
|
||||
|
||||
impl<I> OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(sse_stream: SseStreamIter<I>) -> Self {
|
||||
Self { sse_stream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.sse_stream.lines {
|
||||
let line = line.as_ref();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line[6..]; // Remove "data: " prefix
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Skip ping messages (usually from other providers, but handle gracefully)
|
||||
if data == r#"{"type": "ping"}"# {
|
||||
continue;
|
||||
}
|
||||
|
||||
// OpenAI-specific parsing of ChatCompletionsStreamResponse
|
||||
match serde_json::from_str::<ChatCompletionsStreamResponse>(data) {
|
||||
Ok(response) => return Some(Ok(Box::new(response))),
|
||||
Err(e) => return Some(Err(Box::new(
|
||||
OpenAIStreamError::InvalidStreamingData(format!("Error parsing OpenAI streaming data: {}, data: {}", e, data))
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
|
||||
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
self.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.as_deref())
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
self.choices
|
||||
.first()
|
||||
.map(|choice| choice.finish_reason.is_some())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
self.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.role.as_ref().map(|r| match r {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue