diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index 867836a0..53ec8e74 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -80,6 +80,7 @@ properties: - groq - mistral - openai + - gemini access_key: type: string model: diff --git a/arch/supervisord.conf b/arch/supervisord.conf index 7ef06b49..c538ae66 100644 --- a/arch/supervisord.conf +++ b/arch/supervisord.conf @@ -2,14 +2,14 @@ nodaemon=true [program:brightstaff] -command=sh -c "/app/brightstaff 2>&1 | tee /var/log/brightstaff.log" +command=sh -c "RUST_LOG=trace /app/brightstaff 2>&1 | tee /var/log/brightstaff.log" stdout_logfile=/dev/stdout redirect_stderr=true stdout_logfile_maxbytes=0 stderr_logfile_maxbytes=0 [program:envoy] -command=/bin/sh -c "python /app/config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:info 2>&1 | tee /var/log//envoy.log" +command=/bin/sh -c "python /app/config_generator.py && envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && envoy -c /etc/envoy.env_sub.yaml --component-log-level wasm:trace 2>&1 | tee /var/log//envoy.log" stdout_logfile=/dev/stdout redirect_stderr=true stdout_logfile_maxbytes=0 diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 3018f679..0dbd0b70 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -158,6 +158,8 @@ pub enum LlmProviderType { Mistral, #[serde(rename = "openai")] OpenAI, + #[serde(rename = "gemini")] + Gemini, } impl Display for LlmProviderType { @@ -167,6 +169,7 @@ impl Display for LlmProviderType { LlmProviderType::Claude => write!(f, "claude"), LlmProviderType::Deepseek => write!(f, "deepseek"), LlmProviderType::Groq => write!(f, "groq"), + LlmProviderType::Gemini => write!(f, "gemini"), LlmProviderType::Mistral => write!(f, "mistral"), LlmProviderType::OpenAI => write!(f, "openai"), } diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs index 6f4e38d7..170eec42 100644 --- a/crates/hermesllm/src/providers/openai/types.rs +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -494,4 +494,38 @@ data: [DONE] "Hello! How can I assist you today? Whether you have a question, need information, or just want to chat about something, I'm here to help. What would you like to talk about?" ); } + + #[test] + fn stream_chunk_parse_gemini() { + const CHUNK_RESPONSE: &str = r#"data: {"choices":[{"delta":{"content":":**\n\n* **Chief Executive:** T"#; + + let iter = SseChatCompletionIter::try_from(CHUNK_RESPONSE.as_bytes()); + + assert!(iter.is_ok(), "Failed to create SSE iterator"); + let iter: SseChatCompletionIter> = iter.unwrap(); + + let all_text: Vec = iter + .map(|item| { + let response = item.expect("Failed to parse response"); + response + .choices + .into_iter() + .filter_map(|choice| choice.delta.content) + .map(|content| content.to_string()) + .collect::() + }) + .collect(); + + assert_eq!( + all_text.len(), + 1, + "Expected 8 chunks of text, but got {}", + all_text.len() + ); + + assert_eq!( + all_text.join(""), + "Hello! How can I assist you today? Whether you have a question, need information, or just want to chat about something, I'm here to help. What would you like to talk about?" + ); + } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index bf40f337..389f6252 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -16,7 +16,7 @@ use hermesllm::providers::openai::types::{ }; use hermesllm::Provider; use http::StatusCode; -use log::{debug, info, warn}; +use log::{debug, info, trace, warn}; use proxy_wasm::hostcalls::get_current_time; use proxy_wasm::traits::*; use proxy_wasm::types::*; @@ -31,8 +31,9 @@ pub struct StreamContext { metrics: Rc, ratelimit_selector: Option
, streaming_response: bool, + streaming_buffer: Option>, response_tokens: usize, - is_chat_completions_request: bool, + // is_chat_completions_request: bool, llm_providers: Rc, llm_provider: Option>, request_id: Option, @@ -61,7 +62,7 @@ impl StreamContext { ratelimit_selector: None, streaming_response: false, response_tokens: 0, - is_chat_completions_request: false, + // is_chat_completions_request: false, llm_providers, llm_provider: None, request_id: None, @@ -72,6 +73,7 @@ impl StreamContext { user_message: None, traces_queue, request_body_sent_time: None, + streaming_buffer: None, } } fn llm_provider(&self) -> &LlmProvider { @@ -90,15 +92,30 @@ impl StreamContext { provider_hint, )); - if self.llm_provider.as_ref().unwrap().provider_interface == LlmProviderType::Groq { - if let Some(path) = self.get_http_request_header(":path") { - if path.starts_with("/v1/") { - let new_path = format!("/openai{}", path); - self.set_http_request_header(":path", Some(new_path.as_str())); + match self.llm_provider.as_ref().unwrap().provider_interface { + LlmProviderType::Groq => { + if let Some(path) = self.get_http_request_header(":path") { + if path.starts_with("/v1/") { + let new_path = format!("/openai{}", path); + self.set_http_request_header(":path", Some(new_path.as_str())); + } } } + LlmProviderType::Gemini => { + if let Some(path) = self.get_http_request_header(":path") { + if path == "/v1/chat/completions" { + self.set_http_request_header( + ":path", + Some("/v1beta/openai/chat/completions"), + ); + } + } + } + _ => {} } + if self.llm_provider.as_ref().unwrap().provider_interface == LlmProviderType::Groq {} + debug!( "request received: llm provider hint: {}, selected llm: {}, model: {}", self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER) @@ -242,8 +259,8 @@ impl HttpContext for StreamContext { self.delete_content_length_header(); self.save_ratelimit_header(); - let request_path = self.get_http_request_header(":path").unwrap_or_default(); - self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(&request_path.as_str()); + // let request_path = self.get_http_request_header(":path").unwrap_or_default(); + // self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(&request_path.as_str()); self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); @@ -392,10 +409,10 @@ impl HttpContext for StreamContext { Action::Continue } - fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + fn on_http_response_headers(&mut self, _num_headers: usize, end_of_stream: bool) -> Action { debug!( "on_http_response_headers [S={}] end_stream={}", - self.context_id, _end_of_stream + self.context_id, end_of_stream ); self.set_property( @@ -417,10 +434,10 @@ impl HttpContext for StreamContext { return Action::Continue; } - if !self.is_chat_completions_request { - info!("on_http_response_body: non-chatcompletion request"); - return Action::Continue; - } + // if !self.is_chat_completions_request { + // info!("on_http_response_body: non-chatcompletion request"); + // return Action::Continue; + // } let current_time = get_current_time().unwrap(); if end_of_stream && body_size == 0 { @@ -542,18 +559,66 @@ impl HttpContext for StreamContext { } }; + if log::log_enabled!(log::Level::Trace) { + trace!( + "response data (converted to utf8): {}", + String::from_utf8_lossy(&body) + ); + } + let llm_provider_str = self.llm_provider().provider_interface.to_string(); let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); if self.streaming_response { - let chat_completions_chunk_response_events = - match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) { - Ok(events) => events, - Err(e) => { - warn!("could not parse response: {}", e); - return Action::Continue; - } - }; + // check if body ends with a valid SSE event + if !body.ends_with(b"\n\n") { + if end_of_stream { + warn!("streaming response body does not end with a valid SSE event, but end of stream is true"); + self.send_server_error( + ServerError::LogicError( + "streaming response body does not end with a valid SSE event" + .to_string(), + ), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Continue; + } + + // buffer the body until we have a complete SSE event + debug!("streaming response body does not end with a valid SSE event, buffering the body"); + self.streaming_buffer + .get_or_insert_with(Vec::new) + .extend_from_slice(&body); + // we need to wait for the next chunk to complete the SSE event + return Action::Pause; + } + + // if streaming_buffer is Some, it means we have buffered data from previous chunks + // otherwise we can process the body directly + + let sse_event_buffer = match self.streaming_buffer.take() { + Some(buffer) => { + debug!("streaming response body has buffered data, prepending it to the current chunk"); + let mut complete_body = buffer; + complete_body.extend_from_slice(&body); + complete_body + } + None => { + debug!("no buffered data, processing the current chunk directly"); + body + } + }; + + let chat_completions_chunk_response_events = match SseChatCompletionIter::try_from(( + sse_event_buffer.as_slice(), + &hermes_llm_provider, + )) { + Ok(events) => events, + Err(e) => { + warn!("could not parse response: {}", e); + return Action::Continue; + } + }; for event in chat_completions_chunk_response_events { match event { diff --git a/demos/use_cases/llm_routing/arch_config.yaml b/demos/use_cases/llm_routing/arch_config.yaml index 0d38335e..46acdc07 100644 --- a/demos/use_cases/llm_routing/arch_config.yaml +++ b/demos/use_cases/llm_routing/arch_config.yaml @@ -45,5 +45,10 @@ llm_providers: provider_interface: groq model: llama-3.1-8b-instant + - name: gemini + access_key: $GEMINI_API_KEY + provider_interface: gemini + model: gemini-1.5-pro-latest + tracing: random_sampling: 100