use hermesllm in arch gateway for both stream and non stream messages

This commit is contained in:
Adil Hafeez 2025-06-04 16:19:45 -07:00
parent 670907145a
commit 0c7aa132ee
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
9 changed files with 357 additions and 101 deletions

View file

@ -0,0 +1,113 @@
use serde_json::Value;
use crate::providers::openai::types::{ChatCompletionsRequest, Message, StreamOptions};
#[derive(Debug, Clone)]
pub struct OpenAIRequestBuilder {
model: String,
messages: Vec<Message>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<u32>,
max_tokens: Option<u32>,
stream: Option<bool>,
stop: Option<Vec<String>>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
stream_options: Option<StreamOptions>,
tools: Option<Vec<Value>>,
}
impl OpenAIRequestBuilder {
pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
Self {
model: model.into(),
messages,
temperature: None,
top_p: None,
n: None,
max_tokens: None,
stream: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
stream_options: None,
tools: None,
}
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn n(mut self, n: u32) -> Self {
self.n = Some(n);
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.stream = Some(stream);
self
}
pub fn stop(mut self, stop: Vec<String>) -> Self {
self.stop = Some(stop);
self
}
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
self.presence_penalty = Some(presence_penalty);
self
}
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.frequency_penalty = Some(frequency_penalty);
self
}
pub fn stream_options(mut self, include_usage: bool) -> Self {
self.stream = Some(true);
self.stream_options = Some(StreamOptions { include_usage });
self
}
pub fn tools(mut self, tools: Vec<Value>) -> Self {
self.tools = Some(tools);
self
}
pub fn build(self) -> Result<ChatCompletionsRequest, &'static str> {
let request = ChatCompletionsRequest {
model: self.model,
messages: self.messages,
temperature: self.temperature,
top_p: self.top_p,
n: self.n,
max_tokens: self.max_tokens,
stream: self.stream,
stop: self.stop,
presence_penalty: self.presence_penalty,
frequency_penalty: self.frequency_penalty,
stream_options: self.stream_options,
tools: self.tools,
};
Ok(request)
}
}
impl ChatCompletionsRequest {
pub fn builder(model: impl Into<String>, messages: Vec<Message>) -> OpenAIRequestBuilder {
OpenAIRequestBuilder::new(model, messages)
}
}

View file

@ -1,30 +1,4 @@
pub mod types;
pub mod builder;
use thiserror::Error;
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
pub type OpenAIRequestBuilder = builder::OpenAIRequestBuilder;
#[derive(Debug, Error)]
pub enum OpenAIError {
#[error("json error: {0}")]
JsonParseError(#[from] serde_json::Error),
}
type Result<T> = std::result::Result<T, OpenAIError>;
impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(OpenAIError::from)
}
}
impl TryFrom<&[u8]> for ChatCompletionsResponse {
type Error = OpenAIError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(OpenAIError::from)
}
}

View file

@ -1,7 +1,21 @@
use std::fmt::Display;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_with::skip_serializing_none;
use std::convert::TryFrom;
use std::str;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum OpenAIError {
#[error("json error: {0}")]
JsonParseError(#[from] serde_json::Error),
#[error("utf8 parsing error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
}
type Result<T> = std::result::Result<T, OpenAIError>;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MultiPartContentType {
@ -57,10 +71,9 @@ pub struct Message {
pub content: Option<ContentType>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
pub include_usage: bool,
}
#[skip_serializing_none]
@ -77,8 +90,15 @@ pub struct ChatCompletionsRequest {
pub presence_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
pub stream_options: Option<StreamOptions>,
pub tools: Option<Vec<Value>>,
}
impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(OpenAIError::from)
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -90,6 +110,13 @@ pub struct ChatCompletionsResponse {
pub usage: Option<Usage>,
}
impl TryFrom<&[u8]> for ChatCompletionsResponse {
type Error = OpenAIError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(OpenAIError::from)
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
@ -120,6 +147,59 @@ pub struct ChatCompletionStreamResponse {
pub created: u64,
pub model: String,
pub choices: Vec<StreamChoice>,
pub usage: Option<Usage>,
}
pub struct SseChatCompletionIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
lines: I,
}
impl<I> SseChatCompletionIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
pub fn new(lines: I) -> Self {
Self { lines }
}
}
impl<I> Iterator for SseChatCompletionIter<I>
where
I: Iterator,
I::Item: AsRef<str>,
{
type Item = Result<ChatCompletionStreamResponse>;
fn next(&mut self) -> Option<Self::Item> {
for line in &mut self.lines {
let line = line.as_ref();
if let Some(data) = line.strip_prefix("data: ") {
let data = data.trim();
if data == "[DONE]" {
return None;
}
return Some(
serde_json::from_str::<ChatCompletionStreamResponse>(data)
.map_err(OpenAIError::from),
);
}
}
None
}
}
impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter<str::Lines<'a>> {
type Error = OpenAIError;
fn try_from(bytes: &'a [u8]) -> Result<Self> {
let s = std::str::from_utf8(bytes)?;
Ok(SseChatCompletionIter::new(s.lines()))
}
}
#[cfg(test)]
@ -191,4 +271,83 @@ mod tests {
panic!("Expected MultiPartContent");
}
}
#[test]
fn test_sse_streaming() {
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
data: [DONE]"#;
let iter = SseChatCompletionIter::new(json_data.lines());
println!("Testing SSE Streaming");
for item in iter {
match item {
Ok(response) => {
println!("Received response: {:?}", response);
if response.choices.is_empty() {
continue;
}
for choice in response.choices {
if let Some(content) = choice.delta.content {
println!("Content: {}", content);
}
}
}
Err(e) => {
println!("Error parsing JSON: {}", e);
return;
}
}
}
}
#[test]
fn test_sse_streaming_try_from_bytes() {
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
data: [DONE]"#;
let iter = SseChatCompletionIter::try_from(json_data.as_bytes())
.expect("Failed to create SSE iterator");
println!("Testing SSE Streaming");
for item in iter {
match item {
Ok(response) => {
println!("Received response: {:?}", response);
if response.choices.is_empty() {
continue;
}
for choice in response.choices {
if let Some(content) = choice.delta.content {
println!("Content: {}", content);
}
}
}
Err(e) => {
println!("Error parsing JSON: {}", e);
return;
}
}
}
}
#[test]
fn parse_chat_completions_request() {
const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
"model": "None",
"messages": [
{
"role": "user",
"content": "how is the weather in seattle"
}
],
"stream": true
} "#;
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(CHAT_COMPLETIONS_REQUEST.as_bytes())
.expect("Failed to parse ChatCompletionsRequest");
}
}