mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
use hermesllm in arch gateway for both stream and non stream messages
This commit is contained in:
parent
670907145a
commit
0c7aa132ee
9 changed files with 357 additions and 101 deletions
113
crates/hermesllm/src/providers/openai/builder.rs
Normal file
113
crates/hermesllm/src/providers/openai/builder.rs
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
use serde_json::Value;
|
||||
|
||||
use crate::providers::openai::types::{ChatCompletionsRequest, Message, StreamOptions};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpenAIRequestBuilder {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
temperature: Option<f32>,
|
||||
top_p: Option<f32>,
|
||||
n: Option<u32>,
|
||||
max_tokens: Option<u32>,
|
||||
stream: Option<bool>,
|
||||
stop: Option<Vec<String>>,
|
||||
presence_penalty: Option<f32>,
|
||||
frequency_penalty: Option<f32>,
|
||||
stream_options: Option<StreamOptions>,
|
||||
tools: Option<Vec<Value>>,
|
||||
}
|
||||
|
||||
impl OpenAIRequestBuilder {
|
||||
pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
messages,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
n: None,
|
||||
max_tokens: None,
|
||||
stream: None,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
stream_options: None,
|
||||
tools: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||
self.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||
self.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn n(mut self, n: u32) -> Self {
|
||||
self.n = Some(n);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
|
||||
self.max_tokens = Some(max_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stream(mut self, stream: bool) -> Self {
|
||||
self.stream = Some(stream);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop(mut self, stop: Vec<String>) -> Self {
|
||||
self.stop = Some(stop);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
|
||||
self.presence_penalty = Some(presence_penalty);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
|
||||
self.frequency_penalty = Some(frequency_penalty);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stream_options(mut self, include_usage: bool) -> Self {
|
||||
self.stream = Some(true);
|
||||
self.stream_options = Some(StreamOptions { include_usage });
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tools(mut self, tools: Vec<Value>) -> Self {
|
||||
self.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<ChatCompletionsRequest, &'static str> {
|
||||
let request = ChatCompletionsRequest {
|
||||
model: self.model,
|
||||
messages: self.messages,
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
n: self.n,
|
||||
max_tokens: self.max_tokens,
|
||||
stream: self.stream,
|
||||
stop: self.stop,
|
||||
presence_penalty: self.presence_penalty,
|
||||
frequency_penalty: self.frequency_penalty,
|
||||
stream_options: self.stream_options,
|
||||
tools: self.tools,
|
||||
};
|
||||
Ok(request)
|
||||
}
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
pub fn builder(model: impl Into<String>, messages: Vec<Message>) -> OpenAIRequestBuilder {
|
||||
OpenAIRequestBuilder::new(model, messages)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,30 +1,4 @@
|
|||
pub mod types;
|
||||
pub mod builder;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
|
||||
pub type OpenAIRequestBuilder = builder::OpenAIRequestBuilder;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
#[error("json error: {0}")]
|
||||
JsonParseError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
||||
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,21 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
use std::convert::TryFrom;
|
||||
use std::str;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
#[error("json error: {0}")]
|
||||
JsonParseError(#[from] serde_json::Error),
|
||||
#[error("utf8 parsing error: {0}")]
|
||||
Utf8Error(#[from] std::str::Utf8Error),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum MultiPartContentType {
|
||||
|
|
@ -57,10 +71,9 @@ pub struct Message {
|
|||
pub content: Option<ContentType>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
pub include_usage: bool,
|
||||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
|
|
@ -77,8 +90,15 @@ pub struct ChatCompletionsRequest {
|
|||
pub presence_penalty: Option<f32>,
|
||||
pub frequency_penalty: Option<f32>,
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
pub tools: Option<Vec<Value>>,
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
|
|
@ -90,6 +110,13 @@ pub struct ChatCompletionsResponse {
|
|||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Choice {
|
||||
|
|
@ -120,6 +147,59 @@ pub struct ChatCompletionStreamResponse {
|
|||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<StreamChoice>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
pub struct SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
lines: I,
|
||||
}
|
||||
|
||||
impl<I> SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
pub fn new(lines: I) -> Self {
|
||||
Self { lines }
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> Iterator for SseChatCompletionIter<I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<str>,
|
||||
{
|
||||
type Item = Result<ChatCompletionStreamResponse>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
for line in &mut self.lines {
|
||||
let line = line.as_ref();
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
let data = data.trim();
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
return Some(
|
||||
serde_json::from_str::<ChatCompletionStreamResponse>(data)
|
||||
.map_err(OpenAIError::from),
|
||||
);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter<str::Lines<'a>> {
|
||||
type Error = OpenAIError;
|
||||
|
||||
fn try_from(bytes: &'a [u8]) -> Result<Self> {
|
||||
let s = std::str::from_utf8(bytes)?;
|
||||
Ok(SseChatCompletionIter::new(s.lines()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -191,4 +271,83 @@ mod tests {
|
|||
panic!("Expected MultiPartContent");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_streaming() {
|
||||
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
|
||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
|
||||
data: [DONE]"#;
|
||||
|
||||
let iter = SseChatCompletionIter::new(json_data.lines());
|
||||
|
||||
println!("Testing SSE Streaming");
|
||||
for item in iter {
|
||||
match item {
|
||||
Ok(response) => {
|
||||
println!("Received response: {:?}", response);
|
||||
if response.choices.is_empty() {
|
||||
continue;
|
||||
}
|
||||
for choice in response.choices {
|
||||
if let Some(content) = choice.delta.content {
|
||||
println!("Content: {}", content);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Error parsing JSON: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_streaming_try_from_bytes() {
|
||||
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
|
||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
|
||||
data: [DONE]"#;
|
||||
|
||||
let iter = SseChatCompletionIter::try_from(json_data.as_bytes())
|
||||
.expect("Failed to create SSE iterator");
|
||||
|
||||
println!("Testing SSE Streaming");
|
||||
for item in iter {
|
||||
match item {
|
||||
Ok(response) => {
|
||||
println!("Received response: {:?}", response);
|
||||
if response.choices.is_empty() {
|
||||
continue;
|
||||
}
|
||||
for choice in response.choices {
|
||||
if let Some(content) = choice.delta.content {
|
||||
println!("Content: {}", content);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Error parsing JSON: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_chat_completions_request() {
|
||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
||||
{
|
||||
"model": "None",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in seattle"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
} "#;
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(CHAT_COMPLETIONS_REQUEST.as_bytes())
|
||||
.expect("Failed to parse ChatCompletionsRequest");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue