mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
removed the dependency on try_streaming_from_bytes into a try_from trait implementation
This commit is contained in:
parent
327b29ec6f
commit
6c1dc658cb
5 changed files with 94 additions and 72 deletions
|
|
@ -8,7 +8,7 @@ use thiserror::Error;
|
|||
|
||||
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, TokenUsage};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter};
|
||||
use super::ApiDefinition;
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -600,27 +600,31 @@ impl ProviderResponse for ChatCompletionsResponse {
|
|||
}
|
||||
}
|
||||
|
||||
/// SSE-based streaming iterator for OpenAI chat completions
|
||||
/// Implements ProviderStreamResponseIter directly
|
||||
pub struct SseChatCompletionIter<I>
|
||||
// ============================================================================
|
||||
// OPENAI SSE STREAMING ITERATOR
|
||||
// ============================================================================
|
||||
|
||||
/// OpenAI-specific SSE streaming iterator
|
||||
/// Handles OpenAI's specific SSE format and ChatCompletionsStreamResponse parsing
|
||||
pub struct OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
lines: I,
|
||||
sse_stream: SseStreamIter<I>,
|
||||
}
|
||||
|
||||
impl<I> SseChatCompletionIter<I>
|
||||
impl<I> OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(lines: I) -> Self {
|
||||
Self { lines }
|
||||
pub fn new(sse_stream: SseStreamIter<I>) -> Self {
|
||||
Self { sse_stream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for SseChatCompletionIter<I>
|
||||
impl<I> Iterator for OpenAISseIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
|
|
@ -628,7 +632,7 @@ where
|
|||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.lines {
|
||||
for line in &mut self.sse_stream.lines {
|
||||
let line = line.as_ref();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
|
|
@ -640,14 +644,16 @@ where
|
|||
return None;
|
||||
}
|
||||
|
||||
// Skip ping messages (usually from other providers, but handle gracefully)
|
||||
if data == r#"{"type": "ping"}"# {
|
||||
continue; // Skip ping messages - that is usually from anthropic
|
||||
continue;
|
||||
}
|
||||
|
||||
// OpenAI-specific parsing of ChatCompletionsStreamResponse
|
||||
match serde_json::from_str::<ChatCompletionsStreamResponse>(data) {
|
||||
Ok(response) => return Some(Ok(Box::new(response))),
|
||||
Err(e) => return Some(Err(Box::new(
|
||||
OpenAIStreamError::InvalidStreamingData(format!("Error parsing: {}, data: {}", e, data))
|
||||
OpenAIStreamError::InvalidStreamingData(format!("Error parsing OpenAI streaming data: {}, data: {}", e, data))
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
|
@ -656,15 +662,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<I> ProviderStreamResponseIter for SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator + Send + Sync,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
// Just marking that this type implements the trait - no additional methods needed
|
||||
}
|
||||
|
||||
|
||||
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
|
||||
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ pub mod clients;
|
|||
|
||||
// Re-export important types and traits
|
||||
pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError};
|
||||
pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, try_streaming_from_bytes};
|
||||
pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, ProviderResponseError, TokenUsage};
|
||||
pub use providers::id::ProviderId;
|
||||
pub use providers::adapters::{has_compatible_api, supported_apis};
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ mod tests {
|
|||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let result = try_streaming_from_bytes(sse_data.as_bytes(), &ProviderId::OpenAI);
|
||||
let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
|
||||
let mut streaming_response = result.unwrap();
|
||||
|
|
|
|||
|
|
@ -10,5 +10,5 @@ pub mod adapters;
|
|||
|
||||
pub use id::ProviderId;
|
||||
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ;
|
||||
pub use response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage };
|
||||
pub use response::{ProviderResponseType, ProviderStreamResponseIter, ProviderResponse, ProviderStreamResponse, TokenUsage };
|
||||
pub use adapters::*;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use std::error::Error;
|
|||
use std::fmt;
|
||||
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
use crate::apis::OpenAISseIter;
|
||||
use crate::providers::id::ProviderId;
|
||||
use crate::providers::adapters::{get_provider_config, AdapterType};
|
||||
|
||||
|
|
@ -11,9 +11,9 @@ pub enum ProviderResponseType {
|
|||
//MessagesResponse(MessagesResponse),
|
||||
}
|
||||
|
||||
pub enum ProviderStreamResponseType {
|
||||
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
|
||||
//MessagesStreamResponse(MessagesStreamMessage),
|
||||
pub enum ProviderStreamResponseIter {
|
||||
ChatCompletionsStream(OpenAISseIter<std::vec::IntoIter<String>>),
|
||||
//MessagesStream(AnthropicSseIter<std::vec::IntoIter<String>>),
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType {
|
||||
|
|
@ -31,6 +31,46 @@ impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<(&[u8], &ProviderId)> for ProviderStreamResponseIter {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let config = get_provider_config(provider_id);
|
||||
|
||||
// Parse SSE (Server-Sent Events) streaming data - protocol layer
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
// Delegate to OpenAI-specific iterator implementation
|
||||
let sse_container = SseStreamIter::new(lines.into_iter());
|
||||
let iter = crate::apis::openai::OpenAISseIter::new(sse_container);
|
||||
Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter))
|
||||
}
|
||||
// Future: AdapterType::Claude => {
|
||||
// let sse_container = SseStreamIter::new(lines.into_iter());
|
||||
// let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container);
|
||||
// Ok(ProviderStreamResponseIter::MessagesStream(iter))
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl Iterator for ProviderStreamResponseIter {
|
||||
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
match self {
|
||||
ProviderStreamResponseIter::ChatCompletionsStream(iter) => iter.next(),
|
||||
// Future: ProviderStreamResponseIter::MessagesStream(iter) => iter.next(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
/// Get usage information if available - returns dynamic trait object
|
||||
fn usage(&self) -> Option<&dyn TokenUsage>;
|
||||
|
|
@ -52,9 +92,30 @@ pub trait ProviderStreamResponse: Send + Sync {
|
|||
fn role(&self) -> Option<&str>;
|
||||
}
|
||||
|
||||
/// Trait for streaming response iterators
|
||||
pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync {
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// GENERIC SSE STREAMING ITERATOR (Container Only)
|
||||
// ============================================================================
|
||||
|
||||
/// Generic SSE (Server-Sent Events) streaming iterator container
|
||||
/// This is just a simple wrapper - actual Iterator implementation is delegated to provider-specific modules
|
||||
pub struct SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub lines: I,
|
||||
}
|
||||
|
||||
impl<I> SseStreamIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(lines: I) -> Self {
|
||||
Self { lines }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -74,28 +135,9 @@ impl ProviderResponse for ProviderResponseType {
|
|||
}
|
||||
}
|
||||
|
||||
impl ProviderStreamResponse for ProviderStreamResponseType {
|
||||
fn content_delta(&self) -> Option<&str> {
|
||||
match self {
|
||||
ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.content_delta(),
|
||||
// Future: ProviderStreamResponseType::MessagesStreamResponse(resp) => resp.content_delta(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_final(&self) -> bool {
|
||||
match self {
|
||||
ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.is_final(),
|
||||
// Future: ProviderStreamResponseType::MessagesStreamResponse(resp) => resp.is_final(),
|
||||
}
|
||||
}
|
||||
|
||||
fn role(&self) -> Option<&str> {
|
||||
match self {
|
||||
ProviderStreamResponseType::ChatCompletionsStreamResponse(resp) => resp.role(),
|
||||
// Future: ProviderStreamResponseType::MessagesStreamResponse(resp) => resp.role(),
|
||||
}
|
||||
}
|
||||
}
|
||||
// Implement Send + Sync for the enum to match the original trait requirements
|
||||
unsafe impl Send for ProviderStreamResponseIter {}
|
||||
unsafe impl Sync for ProviderStreamResponseIter {}
|
||||
|
||||
/// Trait for token usage information
|
||||
pub trait TokenUsage {
|
||||
|
|
@ -123,20 +165,3 @@ impl Error for ProviderResponseError {
|
|||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||
}
|
||||
}
|
||||
|
||||
/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object
|
||||
pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderStreamResponseIter>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let config = get_provider_config(provider_id);
|
||||
|
||||
match config.adapter_type {
|
||||
AdapterType::OpenAICompatible => {
|
||||
// Parse SSE (Server-Sent Events) streaming data
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||
let iter = crate::apis::openai::SseChatCompletionIter::new(lines.into_iter());
|
||||
|
||||
// Return the iterator directly - it implements ProviderStreamResponseIter
|
||||
Ok(Box::new(iter))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ use common::ratelimit::Header;
|
|||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::providers::response::ProviderStreamResponseIter;
|
||||
use hermesllm::{
|
||||
try_streaming_from_bytes, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse,
|
||||
ProviderResponseType,
|
||||
ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, ProviderResponseType,
|
||||
};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -572,7 +572,7 @@ impl HttpContext for StreamContext {
|
|||
// Since all providers use OpenAI-compatible streaming format
|
||||
let provider_id = self.get_provider_id();
|
||||
|
||||
match try_streaming_from_bytes(&body, &provider_id) {
|
||||
match ProviderStreamResponseIter::try_from((&body[..], &provider_id)) {
|
||||
Ok(mut streaming_response) => {
|
||||
// Process each streaming chunk
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue