diff --git a/crates/hermesllm/README.md b/crates/hermesllm/README.md index 3237d905..6f6b99e1 100644 --- a/crates/hermesllm/README.md +++ b/crates/hermesllm/README.md @@ -43,7 +43,7 @@ let request = ProviderRequestType::try_from((request_bytes.as_bytes(), &Provider // Access request properties println!("Model: {}", request.model()); -println!("User message: {:?}", request.extract_user_message()); +println!("User message: {:?}", request.get_recent_user_message()); println!("Is streaming: {}", request.is_streaming()); ``` diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index bca5be27..2471fc35 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -549,13 +549,6 @@ impl ProviderRequest for ChatCompletionsRequest { self.stream.unwrap_or_default() } - fn set_streaming_options(&mut self) { - self.stream = Some(true); - if self.stream_options.is_none() { - self.stream_options = Some(StreamOptions { include_usage: Some(true) }); - } - } - fn extract_messages_text(&self) -> String { self.messages.iter().fold(String::new(), |acc, m| { acc + " " + &match &m.content { @@ -568,7 +561,7 @@ impl ProviderRequest for ChatCompletionsRequest { }) } - fn extract_user_message(&self) -> Option<String> { + fn get_recent_user_message(&self) -> Option<String> { self.messages.last().and_then(|msg| { match &msg.content { MessageContent::Text(text) => Some(text.clone()), diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index bb2863d6..b4ad9932 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -60,7 +60,7 @@ mod tests { let request = result.unwrap(); assert_eq!(request.model(), "gpt-4"); - assert_eq!(request.extract_user_message(), Some("Hello!".to_string())); + assert_eq!(request.get_recent_user_message(), Some("Hello!".to_string())); } #[test] diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 577be2b7..1eb39416 100644 --- 
a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -46,14 +46,11 @@ pub trait ProviderRequest: Send + Sync { /// Check if this is a streaming request fn is_streaming(&self) -> bool; - /// Set streaming options (e.g., include_usage) - fn set_streaming_options(&mut self); - /// Extract text content from messages for token counting fn extract_messages_text(&self) -> String; /// Extract the user message for tracing/logging purposes - fn extract_user_message(&self) -> Option<String>; + fn get_recent_user_message(&self) -> Option<String>; /// Convert the request to bytes for transmission fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>; @@ -78,21 +75,15 @@ impl ProviderRequest for ProviderRequestType { } } - fn set_streaming_options(&mut self) { - match self { - Self::ChatCompletionsRequest(r) => r.set_streaming_options(), - } - } - fn extract_messages_text(&self) -> String { match self { Self::ChatCompletionsRequest(r) => r.extract_messages_text(), } } - fn extract_user_message(&self) -> Option<String> { + fn get_recent_user_message(&self) -> Option<String> { match self { - Self::ChatCompletionsRequest(r) => r.extract_user_message(), + Self::ChatCompletionsRequest(r) => r.get_recent_user_message(), } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 96354fbf..6b2c5f15 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -356,7 +356,7 @@ impl HttpContext for StreamContext { deserialized_body.set_model(resolved_model.clone()); // Extract user message for tracing - self.user_message = deserialized_body.extract_user_message(); + self.user_message = deserialized_body.get_recent_user_message(); info!( "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}", @@ -368,11 +368,6 @@ impl HttpContext for StreamContext { // Use provider interface for streaming detection and setup self.streaming_response = 
deserialized_body.is_streaming(); - // Set streaming options if needed - if self.streaming_response { - deserialized_body.set_streaming_options(); - } - // Use provider interface for text extraction (after potential mutation) let input_tokens_str = deserialized_body.extract_messages_text(); // enforce ratelimits on ingress @@ -385,9 +380,6 @@ impl HttpContext for StreamContext { return Action::Continue; } - let llm_provider_str = self.llm_provider().provider_interface.to_string(); - let _hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str()); - // Convert chat completion request to llm provider specific request using provider interface let deserialized_body_bytes = match deserialized_body.to_bytes() { Ok(bytes) => bytes, @@ -562,17 +554,9 @@ impl HttpContext for StreamContext { ); } - let llm_provider_str = self.llm_provider().provider_interface.to_string(); - let _provider_id = ProviderId::from(llm_provider_str.as_str()); - if self.streaming_response { debug!("processing streaming response"); - - // Parse streaming response using OpenAI-compatible format - // Since all providers use OpenAI-compatible streaming format - let provider_id = self.get_provider_id(); - - match ProviderStreamResponseIter::try_from((&body[..], &provider_id)) { + match ProviderStreamResponseIter::try_from((&body[..], &self.get_provider_id())) { Ok(mut streaming_response) => { // Process each streaming chunk while let Some(chunk_result) = streaming_response.next() {