Replaced the try_streaming_from_bytes helper with a TryFrom trait implementation

This commit is contained in:
Salman Paracha 2025-08-19 15:47:21 -07:00
parent 327b29ec6f
commit 6c1dc658cb
5 changed files with 94 additions and 72 deletions

View file

@ -8,7 +8,7 @@ use thiserror::Error;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, TokenUsage};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter};
use super::ApiDefinition;
// ============================================================================
@ -600,27 +600,31 @@ impl ProviderResponse for ChatCompletionsResponse {
}
}
/// SSE-based streaming iterator for OpenAI chat completions
/// Implements ProviderStreamResponseIter directly
pub struct SseChatCompletionIter<I>
// ============================================================================
// OPENAI SSE STREAMING ITERATOR
// ============================================================================
/// OpenAI-specific SSE streaming iterator
/// Handles OpenAI's specific SSE format and ChatCompletionsStreamResponse parsing
pub struct OpenAISseIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
lines: I,
sse_stream: SseStreamIter<I>,
}
impl<I> SseChatCompletionIter<I>
impl<I> OpenAISseIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
pub fn new(lines: I) -> Self {
Self { lines }
pub fn new(sse_stream: SseStreamIter<I>) -> Self {
Self { sse_stream }
}
}
impl<I> Iterator for SseChatCompletionIter<I>
impl<I> Iterator for OpenAISseIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
@ -628,7 +632,7 @@ where
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;
fn next(&mut self) -> Option<Self::Item> {
for line in &mut self.lines {
for line in &mut self.sse_stream.lines {
let line = line.as_ref();
if line.is_empty() {
continue;
@ -640,14 +644,16 @@ where
return None;
}
// Skip ping messages (usually from other providers, but handle gracefully)
if data == r#"{"type": "ping"}"# {
continue; // Skip ping messages - that is usually from anthropic
continue;
}
// OpenAI-specific parsing of ChatCompletionsStreamResponse
match serde_json::from_str::<ChatCompletionsStreamResponse>(data) {
Ok(response) => return Some(Ok(Box::new(response))),
Err(e) => return Some(Err(Box::new(
OpenAIStreamError::InvalidStreamingData(format!("Error parsing: {}, data: {}", e, data))
OpenAIStreamError::InvalidStreamingData(format!("Error parsing OpenAI streaming data: {}, data: {}", e, data))
))),
}
}
@ -656,15 +662,6 @@ where
}
}
impl<I> ProviderStreamResponseIter for SseChatCompletionIter<I>
where
I: Iterator + Send + Sync,
I::Item: AsRef<str>,
{
// Just marking that this type implements the trait - no additional methods needed
}
// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse
impl ProviderStreamResponse for ChatCompletionsStreamResponse {
fn content_delta(&self) -> Option<&str> {

View file

@ -7,7 +7,7 @@ pub mod clients;
// Re-export important types and traits
pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError};
pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderResponseError, TokenUsage, try_streaming_from_bytes};
pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, ProviderResponseError, TokenUsage};
pub use providers::id::ProviderId;
pub use providers::adapters::{has_compatible_api, supported_apis};
@ -71,7 +71,7 @@ mod tests {
data: [DONE]
"#;
let result = try_streaming_from_bytes(sse_data.as_bytes(), &ProviderId::OpenAI);
let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &ProviderId::OpenAI));
assert!(result.is_ok());
let mut streaming_response = result.unwrap();

View file

@ -10,5 +10,5 @@ pub mod adapters;
pub use id::ProviderId;
pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ;
pub use response::{ProviderResponseType, ProviderStreamResponseType, ProviderResponse, ProviderStreamResponse, TokenUsage };
pub use response::{ProviderResponseType, ProviderStreamResponseIter, ProviderResponse, ProviderStreamResponse, TokenUsage };
pub use adapters::*;

View file

@ -2,7 +2,7 @@ use std::error::Error;
use std::fmt;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::openai::ChatCompletionsStreamResponse;
use crate::apis::OpenAISseIter;
use crate::providers::id::ProviderId;
use crate::providers::adapters::{get_provider_config, AdapterType};
@ -11,9 +11,9 @@ pub enum ProviderResponseType {
//MessagesResponse(MessagesResponse),
}
pub enum ProviderStreamResponseType {
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
//MessagesStreamResponse(MessagesStreamMessage),
pub enum ProviderStreamResponseIter {
ChatCompletionsStream(OpenAISseIter<std::vec::IntoIter<String>>),
//MessagesStream(AnthropicSseIter<std::vec::IntoIter<String>>),
}
impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType {
@ -31,6 +31,46 @@ impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType {
}
}
}
impl TryFrom<(&[u8], &ProviderId)> for ProviderStreamResponseIter {
    type Error = Box<dyn std::error::Error + Send + Sync>;

    /// Builds a streaming-response iterator from a raw SSE (Server-Sent Events) buffer.
    ///
    /// The buffer is decoded as UTF-8 and split into owned lines up front; the
    /// adapter type configured for `provider_id` then selects which
    /// provider-specific iterator wraps those lines.
    ///
    /// # Errors
    ///
    /// Returns an error when `bytes` is not valid UTF-8.
    fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
        let provider_config = get_provider_config(provider_id);

        // Protocol layer: decode the payload and split it into individual SSE lines.
        let text = std::str::from_utf8(bytes)?;
        let owned_lines: Vec<String> = text.lines().map(str::to_owned).collect();

        match provider_config.adapter_type {
            AdapterType::OpenAICompatible => {
                // Parsing of each line is left to the OpenAI-specific iterator.
                let container = SseStreamIter::new(owned_lines.into_iter());
                Ok(ProviderStreamResponseIter::ChatCompletionsStream(
                    crate::apis::openai::OpenAISseIter::new(container),
                ))
            }
            // Future: AdapterType::Claude => {
            //     let sse_container = SseStreamIter::new(lines.into_iter());
            //     let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container);
            //     Ok(ProviderStreamResponseIter::MessagesStream(iter))
            // }
        }
    }
}
impl Iterator for ProviderStreamResponseIter {
    type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;

    /// Yields the next item by forwarding to the iterator held in the active variant.
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::ChatCompletionsStream(inner) => inner.next(),
            // Future: Self::MessagesStream(inner) => inner.next(),
        }
    }
}
pub trait ProviderResponse: Send + Sync {
/// Get usage information if available - returns dynamic trait object
fn usage(&self) -> Option<&dyn TokenUsage>;
@ -52,9 +92,30 @@ pub trait ProviderStreamResponse: Send + Sync {
fn role(&self) -> Option<&str>;
}
/// Trait for streaming response iterators
pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync {
// ============================================================================
// GENERIC SSE STREAMING ITERATOR (Container Only)
// ============================================================================
/// Thin, provider-agnostic holder for the lines of an SSE (Server-Sent Events) payload.
///
/// This type does not implement `Iterator` itself; it only owns the line source in
/// the public `lines` field so that a provider-specific wrapper can consume it.
pub struct SseStreamIter<I>
where
    I: Iterator,
    I::Item: AsRef<str>,
{
    /// Underlying source of SSE lines.
    pub lines: I,
}

impl<I> SseStreamIter<I>
where
    I: Iterator,
    I::Item: AsRef<str>,
{
    /// Wraps `lines` without reading anything from it.
    pub fn new(lines: I) -> Self {
        SseStreamIter { lines }
    }
}
@ -74,28 +135,9 @@ impl ProviderResponse for ProviderResponseType {
}
}
// Every accessor forwards to the response wrapped by the active variant.
impl ProviderStreamResponse for ProviderStreamResponseType {
    /// Returns the content delta reported by the wrapped response, if any.
    fn content_delta(&self) -> Option<&str> {
        match self {
            Self::ChatCompletionsStreamResponse(inner) => inner.content_delta(),
            // Future: Self::MessagesStreamResponse(inner) => inner.content_delta(),
        }
    }

    /// Returns whether the wrapped response reports itself as final.
    fn is_final(&self) -> bool {
        match self {
            Self::ChatCompletionsStreamResponse(inner) => inner.is_final(),
            // Future: Self::MessagesStreamResponse(inner) => inner.is_final(),
        }
    }

    /// Returns the role reported by the wrapped response, if any.
    fn role(&self) -> Option<&str> {
        match self {
            Self::ChatCompletionsStreamResponse(inner) => inner.role(),
            // Future: Self::MessagesStreamResponse(inner) => inner.role(),
        }
    }
}
// Implement Send + Sync for the enum to match the original trait requirements
unsafe impl Send for ProviderStreamResponseIter {}
unsafe impl Sync for ProviderStreamResponseIter {}
/// Trait for token usage information
pub trait TokenUsage {
@ -123,20 +165,3 @@ impl Error for ProviderResponseError {
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
}
}
/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object
pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderStreamResponseIter>, Box<dyn std::error::Error + Send + Sync>> {
let config = get_provider_config(provider_id);
match config.adapter_type {
AdapterType::OpenAICompatible => {
// Parse SSE (Server-Sent Events) streaming data
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
let iter = crate::apis::openai::SseChatCompletionIter::new(lines.into_iter());
// Return the iterator directly - it implements ProviderStreamResponseIter
Ok(Box::new(iter))
}
}
}

View file

@ -10,9 +10,9 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::providers::response::ProviderStreamResponseIter;
use hermesllm::{
try_streaming_from_bytes, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse,
ProviderResponseType,
ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, ProviderResponseType,
};
use http::StatusCode;
use log::{debug, info, warn};
@ -572,7 +572,7 @@ impl HttpContext for StreamContext {
// Since all providers use OpenAI-compatible streaming format
let provider_id = self.get_provider_id();
match try_streaming_from_bytes(&body, &provider_id) {
match ProviderStreamResponseIter::try_from((&body[..], &provider_id)) {
Ok(mut streaming_response) => {
// Process each streaming chunk
while let Some(chunk_result) = streaming_response.next() {