add more changes

This commit is contained in:
Adil Hafeez 2025-06-03 15:00:57 -07:00
parent 2d4d0b01ee
commit f10e0fcece
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
14 changed files with 486 additions and 139 deletions

224
crates/Cargo.lock generated
View file

@ -68,6 +68,21 @@ version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]]
name = "ansi_term"
version = "0.12.1"
@ -196,6 +211,7 @@ dependencies = [
"eventsource-stream",
"futures",
"futures-util",
"hermesllm",
"http-body 1.0.1",
"http-body-util",
"hyper 1.6.0",
@ -276,7 +292,11 @@ version = "0.4.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
dependencies = [
"android-tzdata",
"iana-time-zone",
"num-traits",
"serde",
"windows-link",
]
[[package]]
@ -288,7 +308,7 @@ dependencies = [
"ansi_term",
"atty",
"bitflags 1.3.2",
"strsim",
"strsim 0.8.0",
"textwrap",
"unicode-width",
"vec_map",
@ -521,6 +541,41 @@ dependencies = [
"typenum",
]
[[package]]
name = "darling"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim 0.11.1",
"syn 2.0.87",
]
[[package]]
name = "darling_macro"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [
"darling_core",
"quote",
"syn 2.0.87",
]
[[package]]
name = "debugid"
version = "0.8.0"
@ -530,6 +585,16 @@ dependencies = [
"uuid",
]
[[package]]
name = "deranged"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e"
dependencies = [
"powerfmt",
"serde",
]
[[package]]
name = "derivative"
version = "2.2.0"
@ -1013,6 +1078,7 @@ dependencies = [
"common",
"serde",
"serde_json",
"serde_with",
"thiserror 2.0.12",
]
@ -1238,6 +1304,30 @@ dependencies = [
"tracing",
]
[[package]]
name = "iana-time-zone"
version = "0.1.63"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"log",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "icu_collections"
version = "1.5.0"
@ -1362,6 +1452,12 @@ version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "1.0.3"
@ -1391,6 +1487,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown 0.12.3",
"serde",
]
[[package]]
@ -1677,6 +1774,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-traits"
version = "0.2.19"
@ -1935,6 +2038,12 @@ dependencies = [
"serde",
]
[[package]]
name = "powerfmt"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
version = "0.2.20"
@ -2542,6 +2651,36 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_with"
version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa"
dependencies = [
"base64 0.22.1",
"chrono",
"hex",
"indexmap 1.9.3",
"indexmap 2.6.0",
"serde",
"serde_derive",
"serde_json",
"serde_with_macros",
"time",
]
[[package]]
name = "serde_with_macros"
version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.87",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
@ -2676,6 +2815,12 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "structopt"
version = "0.3.26"
@ -2871,6 +3016,37 @@ dependencies = [
"rustc-hash",
]
[[package]]
name = "time"
version = "0.3.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40"
dependencies = [
"deranged",
"itoa",
"num-conv",
"powerfmt",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c"
[[package]]
name = "time-macros"
version = "0.2.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49"
dependencies = [
"num-conv",
"time-core",
]
[[package]]
name = "tinystr"
version = "0.7.6"
@ -3775,6 +3951,41 @@ dependencies = [
"wasmtime-environ",
]
[[package]]
name = "windows-core"
version = "0.61.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980"
dependencies = [
"windows-implement",
"windows-interface",
"windows-link",
"windows-result",
"windows-strings 0.4.2",
]
[[package]]
name = "windows-implement"
version = "0.60.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.87",
]
[[package]]
name = "windows-interface"
version = "0.59.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.87",
]
[[package]]
name = "windows-link"
version = "0.1.1"
@ -3788,7 +3999,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3"
dependencies = [
"windows-result",
"windows-strings",
"windows-strings 0.3.1",
"windows-targets 0.53.0",
]
@ -3810,6 +4021,15 @@ dependencies = [
"windows-link",
]
[[package]]
name = "windows-strings"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-sys"
version = "0.52.0"

View file

@ -10,6 +10,7 @@ eventsource-client = "0.15.0"
eventsource-stream = "0.2.3"
futures = "0.3.31"
futures-util = "0.3.31"
hermesllm = { version = "0.1.0", path = "../hermesllm" }
http-body = "1.0.1"
http-body-util = "0.1.3"
hyper = { version = "1.6.0", features = ["full"] }

View file

@ -1,14 +1,13 @@
use std::sync::Arc;
use bytes::Bytes;
use common::api::open_ai::ChatCompletionsRequest;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::providers::openai::types::ChatCompletionsRequest;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use serde_json::Value;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
@ -32,13 +31,12 @@ pub async fn chat_completions(
let chat_request_bytes = request.collect().await?.to_bytes();
let chat_completion_request: ChatCompletionsRequest =
match serde_json::from_slice(&chat_request_bytes) {
match ChatCompletionsRequest::try_from(chat_request_bytes.as_ref()) {
Ok(request) => request,
Err(err) => {
let v: Value = serde_json::from_slice(&chat_request_bytes).unwrap();
warn!("arch-router request body string: {}", String::from_utf8_lossy(&chat_request_bytes));
let err_msg = format!("Failed to parse request body: {}", err);
warn!("{}", err_msg);
warn!("arch-router request body: {}", v.to_string());
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);

View file

@ -1,10 +1,10 @@
use std::sync::Arc;
use common::{
api::open_ai::{ChatCompletionsResponse, ContentType, Message},
configuration::{LlmProvider, LlmRoute},
consts::ARCH_PROVIDER_HINT_HEADER,
};
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
use hyper::header;
use thiserror::Error;
use tracing::{debug, info, warn};
@ -136,6 +136,11 @@ impl RouterService {
}
};
if chat_completion_response.choices.is_empty() {
warn!("No choices in router response: {}", body);
return Ok(None);
}
if let Some(ContentType::Text(content)) =
&chat_completion_response.choices[0].message.content
{

View file

@ -1,4 +1,4 @@
use common::api::open_ai::{ChatCompletionsRequest, Message};
use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
use thiserror::Error;
#[derive(Debug, Error)]

View file

@ -1,8 +1,8 @@
use common::{
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
configuration::LlmRoute,
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
};
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
@ -121,11 +121,13 @@ impl RouterModel for RouterModelV1 {
.iter()
.rev()
.map(|message| {
Message::new(
message.role.clone(),
Message {
role: message.role.clone(),
// we can unwrap here because we have already filtered out messages without content
message.content.as_ref().unwrap().to_string(),
)
content: Some(ContentType::Text(
message.content.as_ref().unwrap().to_string(),
)),
}
})
.collect::<Vec<Message>>();
@ -141,14 +143,9 @@ impl RouterModel for RouterModelV1 {
messages: vec![Message {
content: Some(ContentType::Text(messages_content)),
role: USER_ROLE.to_string(),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: None,
stream: false,
stream_options: None,
metadata: None,
..Default::default()
}
}

View file

@ -7,4 +7,5 @@ edition = "2021"
common = { version = "0.1.0", path = "../common" }
serde = "1.0.219"
serde_json = "1.0.140"
serde_with = "3.12.0"
thiserror = "2.0.12"

View file

@ -5,12 +5,13 @@ pub mod providers;
#[cfg(test)]
mod tests {
use crate::providers::openai::types::OpenAIRequest;
use crate::providers::openai::types::ChatCompletionsRequest;
#[test]
fn openai_builder() {
let request = OpenAIRequest::builder()
.model("gpt-3.5-turbo")
let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![])
.temperature(0.7)
.top_p(0.9)
.n(1)
@ -22,14 +23,14 @@ mod tests {
.build()
.expect("Failed to build OpenAIRequest");
assert_eq!(request.base.model, "gpt-3.5-turbo");
assert_eq!(request.base.temperature, Some(0.7));
assert_eq!(request.base.top_p, Some(0.9));
assert_eq!(request.base.n, Some(1));
assert_eq!(request.base.max_tokens, Some(100));
assert_eq!(request.base.stream, Some(false));
assert_eq!(request.base.stop, Some(vec!["\n".to_string()]));
assert_eq!(request.base.presence_penalty, Some(0.0));
assert_eq!(request.base.frequency_penalty, Some(0.0));
assert_eq!(request.model, "gpt-3.5-turbo");
assert_eq!(request.temperature, Some(0.7));
assert_eq!(request.top_p, Some(0.9));
assert_eq!(request.n, Some(1));
assert_eq!(request.max_tokens, Some(100));
assert_eq!(request.stream, Some(false));
assert_eq!(request.stop, Some(vec!["\n".to_string()]));
assert_eq!(request.presence_penalty, Some(0.0));
assert_eq!(request.frequency_penalty, Some(0.0));
}
}

View file

@ -1,44 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequestBase {
pub model: String,
pub messages: Option<Vec<Message>>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub n: Option<u32>,
pub max_tokens: Option<u32>,
pub stream: Option<bool>,
pub stop: Option<Vec<String>>,
pub presence_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatResponseBase {
pub id: String,
pub object: String,
pub created: u64,
pub choices: Vec<Choice>,
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
pub index: u32,
pub message: Message,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}

View file

@ -1,17 +1,19 @@
use serde::{Deserialize, Serialize};
use crate::providers::common_types::{ChatRequestBase, ChatResponseBase};
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
pub use crate::providers::openai::types::{Choice, Message, Usage};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekRequest {
#[serde(flatten)]
pub base: ChatRequestBase,
pub base: ChatCompletionsRequest,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekResponse {
#[serde(flatten)]
pub base: ChatResponseBase,
pub base: ChatCompletionsResponse,
}
// Re-export for convenience
pub use crate::providers::common_types::{Message, Choice, Usage};

View file

@ -1,16 +1,19 @@
use serde::{Deserialize, Serialize};
use crate::providers::common_types::{ChatRequestBase, ChatResponseBase};
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
pub use crate::providers::openai::types::{Choice, Message, Usage};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroqRequest {
#[serde(flatten)]
pub base: ChatRequestBase,
pub base: ChatCompletionsRequest,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroqResponse {
#[serde(flatten)]
pub base: ChatResponseBase,
pub base: ChatCompletionsResponse,
}
pub use crate::providers::common_types::{Message, Choice, Usage};

View file

@ -1,12 +1,3 @@
pub mod openai;
pub mod groq;
pub mod deepseek;
pub mod common_types;
/// Supported LLM providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Provider {
Groq,
OpenAI,
DeepSeek,
}
pub mod groq;
pub mod openai;

View file

@ -2,7 +2,7 @@ pub mod types;
use thiserror::Error;
use crate::providers::openai::types::{OpenAIRequest, OpenAIResponse};
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
#[derive(Debug, Error)]
pub enum OpenAIError {
@ -12,14 +12,14 @@ pub enum OpenAIError {
type Result<T> = std::result::Result<T, OpenAIError>;
impl TryFrom<&[u8]> for OpenAIRequest {
/// Builds a [`ChatCompletionsRequest`] from a raw JSON byte slice.
impl TryFrom<&[u8]> for ChatCompletionsRequest {
    type Error = OpenAIError;

    /// Deserializes `bytes` as JSON; a `serde_json` failure is converted into an
    /// [`OpenAIError`] through its `From` implementation.
    fn try_from(bytes: &[u8]) -> Result<Self> {
        match serde_json::from_slice::<Self>(bytes) {
            Ok(request) => Ok(request),
            Err(err) => Err(OpenAIError::from(err)),
        }
    }
}
impl TryFrom<&[u8]> for OpenAIResponse {
impl TryFrom<&[u8]> for ChatCompletionsResponse {
type Error = OpenAIError;
fn try_from(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(OpenAIError::from)

View file

@ -1,15 +1,124 @@
use serde::{Deserialize, Serialize};
use crate::providers::common_types::{ChatRequestBase, ChatResponseBase};
use std::fmt::Display;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIRequest {
#[serde(flatten)]
pub base: ChatRequestBase,
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
/// Discriminator stored in the `type` key of a [`MultiPartContent`] entry.
///
/// Variants are (de)serialized as snake_case strings: `"text"` and `"image_url"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum MultiPartContentType {
    /// A part carrying text.
    Text,
    /// A part referring to an image URL.
    ImageUrl,
}
#[derive(Debug, Default, Clone)]
/// One element of a multi-part message content array.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MultiPartContent {
/// Text payload of the part, when one is present.
pub text: Option<String>,
/// Kind of the part; read from and written to the JSON key `type`.
#[serde(rename = "type")]
pub content_type: MultiPartContentType,
}
/// Content of a message: either a single string or a list of typed parts.
///
/// The enum is `untagged`, so the JSON shape itself (string vs. array) selects the variant;
/// variants are tried in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ContentType {
/// Content given as one plain string.
Text(String),
/// Content given as an array of [`MultiPartContent`] entries.
MultiPart(Vec<MultiPartContent>),
}
/// Renders the textual portion of a [`ContentType`].
///
/// * `Text` is written verbatim.
/// * `MultiPart` writes the `text` of every part whose type is `Text`, joined with `"\n"`.
///   Image parts, and text parts whose `text` is `None`, contribute nothing.
impl Display for ContentType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ContentType::Text(text) => write!(f, "{}", text),
            ContentType::MultiPart(parts) => {
                // Exhaustive match: adding a new part type becomes a compile error here
                // instead of a runtime panic inside `Display`.
                let text_parts: Vec<&str> = parts
                    .iter()
                    .filter_map(|part| match part.content_type {
                        MultiPartContentType::Text => part.text.as_deref(),
                        // skip image URLs or their data in text representation
                        MultiPartContentType::ImageUrl => None,
                    })
                    .collect();
                write!(f, "{}", text_parts.join("\n"))
            }
        }
    }
}
/// A single chat message: a role plus optional content.
///
/// `skip_serializing_none` leaves `content` out of the serialized output when it is `None`.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Role of the message author, kept as a free-form string.
pub role: String,
/// Message body; either plain text or multi-part content.
pub content: Option<ContentType>,
}
/// Request body for a chat completions call.
///
/// `model` and `messages` are required when deserializing. Every `Option` field that is
/// `None` is left out of the serialized JSON by `skip_serializing_none`.
///
/// The derived `Default` yields an empty model name, no messages and `None` for every
/// optional field, which allows struct-update syntax (`..Default::default()`).
#[skip_serializing_none]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChatCompletionsRequest {
    /// Name of the model the request is addressed to.
    pub model: String,
    /// Messages sent with the request, in order.
    pub messages: Vec<Message>,
    // Optional parameters; each one is omitted from the serialized JSON when `None`.
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub n: Option<u32>,
    pub max_tokens: Option<u32>,
    pub stream: Option<bool>,
    pub stop: Option<Vec<String>>,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
}
/// Response body of a chat completions call.
///
/// `usage` is optional; `skip_serializing_none` omits it from serialized output when `None`.
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionsResponse {
pub id: String,
pub object: String,
// NOTE(review): presumably a Unix timestamp — confirm against the provider API.
pub created: u64,
/// Returned choices; may be empty, so check before indexing.
pub choices: Vec<Choice>,
pub usage: Option<Usage>,
}
/// One entry of `ChatCompletionsResponse::choices`.
///
/// `finish_reason` is optional and is omitted from serialized output when `None`.
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Choice {
pub index: u32,
/// The message carried by this choice.
pub message: Message,
pub finish_reason: Option<String>,
}
/// Token counters attached to a `ChatCompletionsResponse` through its `usage` field.
// NOTE(review): this struct has no `Option` fields, so `skip_serializing_none` looks like a
// no-op here — confirm whether it is kept only for consistency.
#[skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Debug, Clone)]
pub struct OpenAIRequestBuilder {
model: Option<String>,
messages: Option<Vec<crate::providers::common_types::Message>>,
model: String,
messages: Vec<Message>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<u32>,
@ -21,18 +130,19 @@ pub struct OpenAIRequestBuilder {
}
impl OpenAIRequestBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn messages(mut self, messages: Vec<crate::providers::common_types::Message>) -> Self {
self.messages = Some(messages);
self
pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
Self {
model: model.into(),
messages,
temperature: None,
top_p: None,
n: None,
max_tokens: None,
stream: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
}
}
pub fn temperature(mut self, temperature: f32) -> Self {
@ -75,10 +185,9 @@ impl OpenAIRequestBuilder {
self
}
pub fn build(self) -> Result<OpenAIRequest, &'static str> {
let model = self.model.ok_or("model is required")?;
let base = crate::providers::common_types::ChatRequestBase {
model,
pub fn build(self) -> Result<ChatCompletionsRequest, &'static str> {
let request = ChatCompletionsRequest {
model: self.model,
messages: self.messages,
temperature: self.temperature,
top_p: self.top_p,
@ -89,20 +198,83 @@ impl OpenAIRequestBuilder {
presence_penalty: self.presence_penalty,
frequency_penalty: self.frequency_penalty,
};
Ok(OpenAIRequest { base })
Ok(request)
}
}
impl OpenAIRequest {
pub fn builder() -> OpenAIRequestBuilder {
OpenAIRequestBuilder::new()
impl ChatCompletionsRequest {
/// Starts an [`OpenAIRequestBuilder`] seeded with the given `model` and `messages`.
///
/// This is a convenience wrapper around `OpenAIRequestBuilder::new`.
pub fn builder(model: impl Into<String>, messages: Vec<Message>) -> OpenAIRequestBuilder {
OpenAIRequestBuilder::new(model, messages)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIResponse {
#[serde(flatten)]
pub base: ChatResponseBase,
}
#[cfg(test)]
mod tests {
use super::*;
pub use crate::providers::common_types::{Message, Choice, Usage};
#[test]
fn test_content_type_display() {
    // Plain text is rendered verbatim.
    let plain = ContentType::Text(String::from("Hello, world!"));
    assert_eq!(plain.to_string(), "Hello, world!");

    // In multi-part content only the text parts survive; the image part is dropped.
    let text_part = MultiPartContent {
        text: Some(String::from("This is a text part.")),
        content_type: MultiPartContentType::Text,
    };
    let image_part = MultiPartContent {
        text: Some(String::from("https://example.com/image.png")),
        content_type: MultiPartContentType::ImageUrl,
    };
    let mixed = ContentType::MultiPart(vec![text_part, image_part]);
    assert_eq!(mixed.to_string(), "This is a text part.");
}
#[test]
fn test_chat_completions_request_text_type_array() {
    // A user message whose `content` is an array of two text parts.
    const CHAT_COMPLETIONS_REQUEST: &str = r#"
{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What city do you want to know the weather for?"
},
{
"type": "text",
"text": "hello world"
}
]
}
]
}
"#;

    let parsed: ChatCompletionsRequest = serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
    assert_eq!(parsed.model, "gpt-3.5-turbo");

    // The array form must deserialize into the multi-part variant.
    let parts = match parsed.messages[0].content.as_ref() {
        Some(ContentType::MultiPart(parts)) => parts,
        _ => panic!("Expected MultiPartContent"),
    };

    assert_eq!(parts.len(), 2);

    let first = &parts[0];
    assert_eq!(first.content_type, MultiPartContentType::Text);
    assert_eq!(
        first.text,
        Some("What city do you want to know the weather for?".to_string())
    );

    let second = &parts[1];
    assert_eq!(second.content_type, MultiPartContentType::Text);
    assert_eq!(second.text, Some("hello world".to_string()));
}
}