diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index a0eaf393..c0e1212c 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -25,7 +25,7 @@ pub struct StreamContext { context_id: u32, metrics: Rc, ratelimit_selector: Option
, - streaming_response: Option, + streaming_response: bool, response_tokens: usize, is_chat_completions_request: bool, llm_providers: Rc, @@ -33,24 +33,13 @@ pub struct StreamContext { request_id: Option, } -#[derive(Debug)] -struct StreamingResponse { - bytes_read: usize, -} - -impl StreamingResponse { - fn new() -> Self { - StreamingResponse { bytes_read: 0 } - } -} - impl StreamContext { pub fn new(context_id: u32, metrics: Rc, llm_providers: Rc) -> Self { StreamContext { context_id, metrics, ratelimit_selector: None, - streaming_response: None, + streaming_response: false, response_tokens: 0, is_chat_completions_request: false, llm_providers, @@ -229,7 +218,7 @@ impl HttpContext for StreamContext { ); if deserialized_body.stream { - self.streaming_response = Some(StreamingResponse::new()); + self.streaming_response = true; } if deserialized_body.stream && deserialized_body.stream_options.is_none() { deserialized_body.stream_options = Some(StreamOptions { @@ -265,47 +254,42 @@ impl HttpContext for StreamContext { return Action::Continue; } - let body = match self.streaming_response.take() { - Some(mut streaming_response) => { - if end_of_stream && body_size == 0 { + let body = if self.streaming_response { + if end_of_stream && body_size == 0 { + return Action::Continue; + } + let chunk_start = 0; + let chunk_size = body_size; + debug!( + "streaming response reading, {}..{}", + chunk_start, chunk_size + ); + let streaming_chunk = match self.get_http_response_body(0, chunk_size) { + Some(chunk) => chunk, + None => { + warn!( + "response body empty, chunk_start: {}, chunk_size: {}", + chunk_start, chunk_size + ); return Action::Continue; } - let chunk_start = 0; - let chunk_size = body_size; - debug!("streaming respose reading, {}..{}", chunk_start, chunk_size); - let streaming_chunk = match self.get_http_response_body(0, chunk_size) { - Some(chunk) => chunk, - None => { - warn!( - "response body empy, chunk_start: {}, chunk_size: {}", - chunk_start, chunk_size - ); - return Action::Continue; - } - }; + }; - if streaming_chunk.len() != chunk_size { - warn!( - "chunk size mismatch: read: {} != requested: {}", - streaming_chunk.len(), - chunk_size - ); - } - streaming_response.bytes_read += chunk_size; - // n.b: this funky take and replace of the streaming_response struct is done to appease the borrow - // checker which wouldn't let us take a mut ref of streaming_response, and then a ref for - // `get_http_response_body` - self.streaming_response = Some(streaming_response); - streaming_chunk + if streaming_chunk.len() != chunk_size { + warn!( + "chunk size mismatch: read: {} != requested: {}", + streaming_chunk.len(), + chunk_size + ); } - None => { - debug!("non streaming response bytes read: 0:{}", body_size); - match self.get_http_response_body(0, body_size) { - Some(body) => body, - None => { - warn!("non streaming response body empty"); - return Action::Continue; - } + streaming_chunk + } else { + debug!("non streaming response bytes read: 0:{}", body_size); + match self.get_http_response_body(0, body_size) { + Some(body) => body, + None => { + warn!("non streaming response body empty"); + return Action::Continue; } } }; @@ -318,7 +302,7 @@ impl HttpContext for StreamContext { } }; - if self.streaming_response.is_some() { + if self.streaming_response { let chat_completions_chunk_response_events = match ChatCompletionStreamResponseServerEvents::try_from(body_utf8.as_str()) { Ok(response) => response,