diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 44ceb6c9..3b14a246 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -327,6 +327,7 @@ dependencies = [ "derivative", "duration-string", "governor", + "hermesllm", "hex", "log", "pretty_assertions", @@ -1075,7 +1076,6 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" name = "hermesllm" version = "0.1.0" dependencies = [ - "common", "serde", "serde_json", "serde_with", diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index d8c35140..4696b43b 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -18,6 +18,7 @@ serde_json = "1.0" hex = "0.4.3" urlencoding = "2.1.3" url = "2.5.4" +hermesllm = { version = "0.1.0", path = "../hermesllm" } [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs index 17f19ebb..78eb4097 100644 --- a/crates/common/src/errors.rs +++ b/crates/common/src/errors.rs @@ -1,6 +1,7 @@ use proxy_wasm::types::Status; use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit}; +use hermesllm::providers::openai::types::{OpenAIError}; #[derive(thiserror::Error, Debug)] pub enum ClientError { @@ -39,4 +40,6 @@ pub enum ServerError { BadRequest { why: String }, #[error("error in streaming response")] Streaming(#[from] ChatCompletionChunkResponseError), + #[error("error parsing openai message: {0}")] + OpenAIPError(#[from] OpenAIError), } diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index 5393a9ad..c7917f9a 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -4,7 +4,6 @@ version = "0.1.0" edition = "2021" [dependencies] -common = { version = "0.1.0", path = "../common" } serde = "1.0.219" serde_json = "1.0.140" serde_with = "3.12.0" diff --git a/crates/hermesllm/src/providers/openai/builder.rs b/crates/hermesllm/src/providers/openai/builder.rs new file mode 100644 index 00000000..43c4176f --- /dev/null +++ b/crates/hermesllm/src/providers/openai/builder.rs @@ -0,0 +1,113 @@ +use serde_json::Value; + +use crate::providers::openai::types::{ChatCompletionsRequest, Message, StreamOptions}; + +#[derive(Debug, Clone)] +pub struct OpenAIRequestBuilder { + model: String, + messages: Vec, + temperature: Option, + top_p: Option, + n: Option, + max_tokens: Option, + stream: Option, + stop: Option>, + presence_penalty: Option, + frequency_penalty: Option, + stream_options: Option, + tools: Option>, +} + +impl OpenAIRequestBuilder { + pub fn new(model: impl Into, messages: Vec) -> Self { + Self { + model: model.into(), + messages, + temperature: None, + top_p: None, + n: None, + max_tokens: None, + stream: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + stream_options: None, + tools: None, + } + } + + pub fn temperature(mut self, temperature: f32) -> Self { + self.temperature = Some(temperature); + self + } + + pub fn top_p(mut self, top_p: f32) -> Self { + self.top_p = Some(top_p); + self + } + + pub fn n(mut self, n: u32) -> Self { + self.n = Some(n); + self + } + + pub fn max_tokens(mut self, max_tokens: u32) -> Self { + self.max_tokens = Some(max_tokens); + self + } + + pub fn stream(mut self, stream: bool) -> Self { + self.stream = Some(stream); + self + } + + pub fn stop(mut self, stop: Vec) -> Self { + self.stop = Some(stop); + self + } + + pub fn presence_penalty(mut self, presence_penalty: f32) -> Self { + self.presence_penalty = Some(presence_penalty); + self + } + + pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self { + self.frequency_penalty = Some(frequency_penalty); + self + } + + pub fn stream_options(mut self, include_usage: bool) -> Self { + self.stream = Some(true); + self.stream_options = Some(StreamOptions { include_usage }); + self + } + + pub fn tools(mut self, tools: Vec) -> Self { + self.tools = Some(tools); + self + } + + pub fn build(self) -> Result { + let request = ChatCompletionsRequest { + model: self.model, + messages: self.messages, + temperature: self.temperature, + top_p: self.top_p, + n: self.n, + max_tokens: self.max_tokens, + stream: self.stream, + stop: self.stop, + presence_penalty: self.presence_penalty, + frequency_penalty: self.frequency_penalty, + stream_options: self.stream_options, + tools: self.tools, + }; + Ok(request) + } +} + +impl ChatCompletionsRequest { + pub fn builder(model: impl Into, messages: Vec) -> OpenAIRequestBuilder { + OpenAIRequestBuilder::new(model, messages) + } +} diff --git a/crates/hermesllm/src/providers/openai/mod.rs b/crates/hermesllm/src/providers/openai/mod.rs index b9c0b1a1..eb6a9726 100644 --- a/crates/hermesllm/src/providers/openai/mod.rs +++ b/crates/hermesllm/src/providers/openai/mod.rs @@ -1,30 +1,4 @@ pub mod types; pub mod builder; -use thiserror::Error; - -use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse}; - pub type OpenAIRequestBuilder = builder::OpenAIRequestBuilder; - -#[derive(Debug, Error)] -pub enum OpenAIError { - #[error("json error: {0}")] - JsonParseError(#[from] serde_json::Error), -} - -type Result = std::result::Result; - -impl TryFrom<&[u8]> for ChatCompletionsRequest { - type Error = OpenAIError; - fn try_from(bytes: &[u8]) -> Result { - serde_json::from_slice(bytes).map_err(OpenAIError::from) - } -} - -impl TryFrom<&[u8]> for ChatCompletionsResponse { - type Error = OpenAIError; - fn try_from(bytes: &[u8]) -> Result { - serde_json::from_slice(bytes).map_err(OpenAIError::from) - } -} diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs index a8c5b3fe..66f0dc29 100644 --- a/crates/hermesllm/src/providers/openai/types.rs +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -1,7 +1,21 @@ use std::fmt::Display; use serde::{Deserialize, Serialize}; +use serde_json::Value; use serde_with::skip_serializing_none; +use std::convert::TryFrom; +use std::str; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum OpenAIError { + #[error("json error: {0}")] + JsonParseError(#[from] serde_json::Error), + #[error("utf8 parsing error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), +} + +type Result = std::result::Result; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum MultiPartContentType { @@ -57,10 +71,9 @@ pub struct Message { pub content: Option, } - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct StreamOptions { - pub include_usage: bool, + pub include_usage: bool, } #[skip_serializing_none] @@ -77,8 +90,15 @@ pub struct ChatCompletionsRequest { pub presence_penalty: Option, pub frequency_penalty: Option, pub stream_options: Option, + pub tools: Option>, } +impl TryFrom<&[u8]> for ChatCompletionsRequest { + type Error = OpenAIError; + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(OpenAIError::from) + } +} #[skip_serializing_none] #[derive(Debug, Clone, Deserialize, Serialize)] @@ -90,6 +110,13 @@ pub struct ChatCompletionsResponse { pub usage: Option, } +impl TryFrom<&[u8]> for ChatCompletionsResponse { + type Error = OpenAIError; + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(OpenAIError::from) + } +} + #[skip_serializing_none] #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Choice { @@ -120,6 +147,59 @@ pub struct ChatCompletionStreamResponse { pub created: u64, pub model: String, pub choices: Vec, + pub usage: Option, +} + +pub struct SseChatCompletionIter +where + I: Iterator, + I::Item: AsRef, +{ + lines: I, +} + +impl SseChatCompletionIter +where + I: Iterator, + I::Item: AsRef, +{ + pub fn new(lines: I) -> Self { + Self { lines } + } +} + +impl Iterator for SseChatCompletionIter +where + I: Iterator, + I::Item: AsRef, +{ + type Item = Result; + + fn next(&mut self) -> Option { + for line in &mut self.lines { + let line = line.as_ref(); + if let Some(data) = line.strip_prefix("data: ") { + let data = data.trim(); + if data == "[DONE]" { + return None; + } + return Some( + serde_json::from_str::(data) + .map_err(OpenAIError::from), + ); + } + } + None + } +} + +impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter> { + type Error = OpenAIError; + + fn try_from(bytes: &'a [u8]) -> Result { + let s = std::str::from_utf8(bytes)?; + Ok(SseChatCompletionIter::new(s.lines())) + } } #[cfg(test)] @@ -191,4 +271,83 @@ mod tests { panic!("Expected MultiPartContent"); } } + + #[test] + fn test_sse_streaming() { + let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]} +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]} +data: [DONE]"#; + + let iter = SseChatCompletionIter::new(json_data.lines()); + + println!("Testing SSE Streaming"); + for item in iter { + match item { + Ok(response) => { + println!("Received response: {:?}", response); + if response.choices.is_empty() { + continue; + } + for choice in response.choices { + if let Some(content) = choice.delta.content { + println!("Content: {}", content); + } + } + } + Err(e) => { + println!("Error parsing JSON: {}", e); + return; + } + } + } + } + + #[test] + fn test_sse_streaming_try_from_bytes() { + let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]} +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]} +data: [DONE]"#; + + let iter = SseChatCompletionIter::try_from(json_data.as_bytes()) + .expect("Failed to create SSE iterator"); + + println!("Testing SSE Streaming"); + for item in iter { + match item { + Ok(response) => { + println!("Received response: {:?}", response); + if response.choices.is_empty() { + continue; + } + for choice in response.choices { + if let Some(content) = choice.delta.content { + println!("Content: {}", content); + } + } + } + Err(e) => { + println!("Error parsing JSON: {}", e); + return; + } + } + } + } + + #[test] + fn parse_chat_completions_request() { + const CHAT_COMPLETIONS_REQUEST: &str = r#" +{ + "model": "None", + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle" + } + ], + "stream": true +} "#; + + let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(CHAT_COMPLETIONS_REQUEST.as_bytes()) + .expect("Failed to parse ChatCompletionsRequest"); + } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 7f4620ba..d7d4ad23 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,5 +1,4 @@ use crate::metrics::Metrics; -use common::api::open_ai::ChatCompletionStreamResponseServerEvents; use common::configuration::{LlmProvider, LlmProviderType, Overrides}; use common::consts::{ ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, @@ -11,7 +10,7 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; -use hermesllm::providers::openai::types::ChatCompletionsRequest; +use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter}; use hermesllm::providers::openai::types::{ ChatCompletionsResponse, ContentType, Message, StreamOptions, }; @@ -285,23 +284,17 @@ impl HttpContext for StreamContext { } }; - // Deserialize body into spec. - // Currently OpenAI API. - let mut deserialized_body: ChatCompletionsRequest = - match serde_json::from_slice(&body_bytes) { - Ok(deserialized) => deserialized, - Err(e) => { - debug!( - "on_http_request_body: request body: {}", - String::from_utf8_lossy(&body_bytes) - ); - self.send_server_error( - ServerError::Deserialization(e), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }; + let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) { + Ok(deserialized) => deserialized, + Err(e) => { + debug!( + "on_http_request_body: request body: {}", + String::from_utf8_lossy(&body_bytes) + ); + self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST)); + return Action::Pause; + } + }; self.user_message = deserialized_body .messages @@ -541,58 +534,30 @@ impl HttpContext for StreamContext { } }; - let body_utf8 = match String::from_utf8(body) { - Ok(body_utf8) => body_utf8, - Err(e) => { - warn!("could not convert to utf8: {}", e); - return Action::Continue; - } - }; - if self.streaming_response { - if body_utf8 == "data: [DONE]\n" { - return Action::Continue; - } - let chat_completions_chunk_response_events = - match ChatCompletionStreamResponseServerEvents::try_from(body_utf8.as_str()) { - Ok(response) => response, + match SseChatCompletionIter::try_from(body.as_slice()) { + Ok(events) => events, Err(e) => { - warn!( - "invalid streaming response: body str: {}, {:?}", - body_utf8, e - ); + warn!("could not parse response: {}", e); return Action::Continue; } }; - if chat_completions_chunk_response_events.events.is_empty() { - warn!( - "couldn't parse any streaming events: body str: {}", - body_utf8 - ); - return Action::Continue; + for event in chat_completions_chunk_response_events { + match event { + Ok(event) => { + if let Some(usage) = event.usage.as_ref() { + self.response_tokens += usage.completion_tokens; + } + } + Err(e) => { + warn!("error in response event: {}", e); + continue; + } + } } - let model = chat_completions_chunk_response_events - .events - .first() - .unwrap() - .model - .clone(); - let tokens_str = chat_completions_chunk_response_events.to_string(); - - let token_count = - match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str()) - { - Ok(token_count) => token_count, - Err(e) => { - warn!("could not get token count: {:?}", e); - return Action::Continue; - } - }; - self.response_tokens += token_count; - // Compute TTFT if not already recorded if self.ttft_duration.is_none() { // if let Some(start_time) = self.start_time { @@ -616,23 +581,20 @@ impl HttpContext for StreamContext { } else { debug!("non streaming response"); let chat_completions_response: ChatCompletionsResponse = - match serde_json::from_str(body_utf8.as_str()) { + match serde_json::from_slice(body.as_slice()) { Ok(de) => de, Err(err) => { info!( - "non chat-completion compliant response received err: {}, body: {}", - err, body_utf8 + "non chat-completion compliant response received err: {}, body: {:?}", + err, + String::from_utf8(body) ); return Action::Continue; } }; - if chat_completions_response.usage.is_some() { - self.response_tokens += chat_completions_response - .usage - .as_ref() - .unwrap() - .completion_tokens; + if let Some(usage) = chat_completions_response.usage { + self.response_tokens += usage.completion_tokens; } } diff --git a/tests/rest/api_llm_gateway.rest b/tests/rest/api_llm_gateway.rest index 41fcffca..752d600b 100644 --- a/tests/rest/api_llm_gateway.rest +++ b/tests/rest/api_llm_gateway.rest @@ -75,3 +75,48 @@ x-arch-llm-provider-hint: gpt-3.5-turbo-0125 } ] } + +### llm gateway request with function calling (default target) +POST {{llm_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "stream": true, + "model": "None", + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get current weather at a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for", + "format": "City, State" + }, + "unit": { + "type": "string", + "description": "The unit to return the weather in.", + "enum": ["celsius", "fahrenheit"], + "default": "celsius" + }, + "days": { + "type": "string", + "description": "The number of days for the request." + } + }, + "required": ["location", "days"] + } + } + } + ] +}