mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add more changes
This commit is contained in:
parent
2d4d0b01ee
commit
f10e0fcece
14 changed files with 486 additions and 139 deletions
224
crates/Cargo.lock
generated
224
crates/Cargo.lock
generated
|
|
@ -68,6 +68,21 @@ version = "0.2.18"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ansi_term"
|
||||
version = "0.12.1"
|
||||
|
|
@ -196,6 +211,7 @@ dependencies = [
|
|||
"eventsource-stream",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"hermesllm",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
|
|
@ -276,7 +292,11 @@ version = "0.4.41"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
|
||||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -288,7 +308,7 @@ dependencies = [
|
|||
"ansi_term",
|
||||
"atty",
|
||||
"bitflags 1.3.2",
|
||||
"strsim",
|
||||
"strsim 0.8.0",
|
||||
"textwrap",
|
||||
"unicode-width",
|
||||
"vec_map",
|
||||
|
|
@ -521,6 +541,41 @@ dependencies = [
|
|||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim 0.11.1",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "debugid"
|
||||
version = "0.8.0"
|
||||
|
|
@ -530,6 +585,16 @@ dependencies = [
|
|||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deranged"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e"
|
||||
dependencies = [
|
||||
"powerfmt",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derivative"
|
||||
version = "2.2.0"
|
||||
|
|
@ -1013,6 +1078,7 @@ dependencies = [
|
|||
"common",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
||||
|
|
@ -1238,6 +1304,30 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.63"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8"
|
||||
dependencies = [
|
||||
"android_system_properties",
|
||||
"core-foundation-sys",
|
||||
"iana-time-zone-haiku",
|
||||
"js-sys",
|
||||
"log",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone-haiku"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icu_collections"
|
||||
version = "1.5.0"
|
||||
|
|
@ -1362,6 +1452,12 @@ version = "2.2.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
|
||||
|
||||
[[package]]
|
||||
name = "ident_case"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "1.0.3"
|
||||
|
|
@ -1391,6 +1487,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
|||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown 0.12.3",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1677,6 +1774,12 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-conv"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
|
|
@ -1935,6 +2038,12 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "powerfmt"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.20"
|
||||
|
|
@ -2542,6 +2651,36 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_with"
|
||||
version = "3.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"hex",
|
||||
"indexmap 1.9.3",
|
||||
"indexmap 2.6.0",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"serde_with_macros",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_with_macros"
|
||||
version = "3.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_yaml"
|
||||
version = "0.9.34+deprecated"
|
||||
|
|
@ -2676,6 +2815,12 @@ version = "0.8.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "structopt"
|
||||
version = "0.3.26"
|
||||
|
|
@ -2871,6 +3016,37 @@ dependencies = [
|
|||
"rustc-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40"
|
||||
dependencies = [
|
||||
"deranged",
|
||||
"itoa",
|
||||
"num-conv",
|
||||
"powerfmt",
|
||||
"serde",
|
||||
"time-core",
|
||||
"time-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time-core"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c"
|
||||
|
||||
[[package]]
|
||||
name = "time-macros"
|
||||
version = "0.2.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49"
|
||||
dependencies = [
|
||||
"num-conv",
|
||||
"time-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinystr"
|
||||
version = "0.7.6"
|
||||
|
|
@ -3775,6 +3951,41 @@ dependencies = [
|
|||
"wasmtime-environ",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.61.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980"
|
||||
dependencies = [
|
||||
"windows-implement",
|
||||
"windows-interface",
|
||||
"windows-link",
|
||||
"windows-result",
|
||||
"windows-strings 0.4.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-implement"
|
||||
version = "0.60.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-interface"
|
||||
version = "0.59.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-link"
|
||||
version = "0.1.1"
|
||||
|
|
@ -3788,7 +3999,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3"
|
||||
dependencies = [
|
||||
"windows-result",
|
||||
"windows-strings",
|
||||
"windows-strings 0.3.1",
|
||||
"windows-targets 0.53.0",
|
||||
]
|
||||
|
||||
|
|
@ -3810,6 +4021,15 @@ dependencies = [
|
|||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-strings"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.52.0"
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ eventsource-client = "0.15.0"
|
|||
eventsource-stream = "0.2.3"
|
||||
futures = "0.3.31"
|
||||
futures-util = "0.3.31"
|
||||
hermesllm = { version = "0.1.0", path = "../hermesllm" }
|
||||
http-body = "1.0.1"
|
||||
http-body-util = "0.1.3"
|
||||
hyper = { version = "1.6.0", features = ["full"] }
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::api::open_ai::ChatCompletionsRequest;
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use hermesllm::providers::openai::types::ChatCompletionsRequest;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
|
|
@ -32,13 +31,12 @@ pub async fn chat_completions(
|
|||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
let chat_completion_request: ChatCompletionsRequest =
|
||||
match serde_json::from_slice(&chat_request_bytes) {
|
||||
match ChatCompletionsRequest::try_from(chat_request_bytes.as_ref()) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
let v: Value = serde_json::from_slice(&chat_request_bytes).unwrap();
|
||||
warn!("arch-router request body string: {}", String::from_utf8_lossy(&chat_request_bytes));
|
||||
let err_msg = format!("Failed to parse request body: {}", err);
|
||||
warn!("{}", err_msg);
|
||||
warn!("arch-router request body: {}", v.to_string());
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use common::{
|
||||
api::open_ai::{ChatCompletionsResponse, ContentType, Message},
|
||||
configuration::{LlmProvider, LlmRoute},
|
||||
consts::ARCH_PROVIDER_HINT_HEADER,
|
||||
};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
|
@ -136,6 +136,11 @@ impl RouterService {
|
|||
}
|
||||
};
|
||||
|
||||
if chat_completion_response.choices.is_empty() {
|
||||
warn!("No choices in router response: {}", body);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(ContentType::Text(content)) =
|
||||
&chat_completion_response.choices[0].message.content
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use common::api::open_ai::{ChatCompletionsRequest, Message};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
use common::{
|
||||
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
|
||||
configuration::LlmRoute,
|
||||
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
|
||||
};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
|
|
@ -121,11 +121,13 @@ impl RouterModel for RouterModelV1 {
|
|||
.iter()
|
||||
.rev()
|
||||
.map(|message| {
|
||||
Message::new(
|
||||
message.role.clone(),
|
||||
Message {
|
||||
role: message.role.clone(),
|
||||
// we can unwrap here because we have already filtered out messages without content
|
||||
message.content.as_ref().unwrap().to_string(),
|
||||
)
|
||||
content: Some(ContentType::Text(
|
||||
message.content.as_ref().unwrap().to_string(),
|
||||
)),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
|
||||
|
|
@ -141,14 +143,9 @@ impl RouterModel for RouterModelV1 {
|
|||
messages: vec![Message {
|
||||
content: Some(ContentType::Text(messages_content)),
|
||||
role: USER_ROLE.to_string(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
||||
}],
|
||||
tools: None,
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
metadata: None,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,4 +7,5 @@ edition = "2021"
|
|||
common = { version = "0.1.0", path = "../common" }
|
||||
serde = "1.0.219"
|
||||
serde_json = "1.0.140"
|
||||
serde_with = "3.12.0"
|
||||
thiserror = "2.0.12"
|
||||
|
|
|
|||
|
|
@ -5,12 +5,13 @@ pub mod providers;
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::providers::openai::types::OpenAIRequest;
|
||||
use crate::providers::openai::types::ChatCompletionsRequest;
|
||||
|
||||
|
||||
|
||||
#[test]
|
||||
fn openai_builder() {
|
||||
let request = OpenAIRequest::builder()
|
||||
.model("gpt-3.5-turbo")
|
||||
let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![])
|
||||
.temperature(0.7)
|
||||
.top_p(0.9)
|
||||
.n(1)
|
||||
|
|
@ -22,14 +23,14 @@ mod tests {
|
|||
.build()
|
||||
.expect("Failed to build OpenAIRequest");
|
||||
|
||||
assert_eq!(request.base.model, "gpt-3.5-turbo");
|
||||
assert_eq!(request.base.temperature, Some(0.7));
|
||||
assert_eq!(request.base.top_p, Some(0.9));
|
||||
assert_eq!(request.base.n, Some(1));
|
||||
assert_eq!(request.base.max_tokens, Some(100));
|
||||
assert_eq!(request.base.stream, Some(false));
|
||||
assert_eq!(request.base.stop, Some(vec!["\n".to_string()]));
|
||||
assert_eq!(request.base.presence_penalty, Some(0.0));
|
||||
assert_eq!(request.base.frequency_penalty, Some(0.0));
|
||||
assert_eq!(request.model, "gpt-3.5-turbo");
|
||||
assert_eq!(request.temperature, Some(0.7));
|
||||
assert_eq!(request.top_p, Some(0.9));
|
||||
assert_eq!(request.n, Some(1));
|
||||
assert_eq!(request.max_tokens, Some(100));
|
||||
assert_eq!(request.stream, Some(false));
|
||||
assert_eq!(request.stop, Some(vec!["\n".to_string()]));
|
||||
assert_eq!(request.presence_penalty, Some(0.0));
|
||||
assert_eq!(request.frequency_penalty, Some(0.0));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,44 +0,0 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatRequestBase {
|
||||
pub model: String,
|
||||
pub messages: Option<Vec<Message>>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub n: Option<u32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub stream: Option<bool>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub presence_penalty: Option<f32>,
|
||||
pub frequency_penalty: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatResponseBase {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: Message,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
|
@ -1,17 +1,19 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use crate::providers::common_types::{ChatRequestBase, ChatResponseBase};
|
||||
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
pub use crate::providers::openai::types::{Choice, Message, Usage};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeepSeekRequest {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatRequestBase,
|
||||
pub base: ChatCompletionsRequest,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeepSeekResponse {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatResponseBase,
|
||||
pub base: ChatCompletionsResponse,
|
||||
}
|
||||
|
||||
// Re-export for convenience
|
||||
pub use crate::providers::common_types::{Message, Choice, Usage};
|
||||
|
|
|
|||
|
|
@ -1,16 +1,19 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use crate::providers::common_types::{ChatRequestBase, ChatResponseBase};
|
||||
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
pub use crate::providers::openai::types::{Choice, Message, Usage};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GroqRequest {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatRequestBase,
|
||||
pub base: ChatCompletionsRequest,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GroqResponse {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatResponseBase,
|
||||
pub base: ChatCompletionsResponse,
|
||||
}
|
||||
|
||||
pub use crate::providers::common_types::{Message, Choice, Usage};
|
||||
|
|
|
|||
|
|
@ -1,12 +1,3 @@
|
|||
pub mod openai;
|
||||
pub mod groq;
|
||||
pub mod deepseek;
|
||||
pub mod common_types;
|
||||
|
||||
/// Supported LLM providers.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Provider {
|
||||
Groq,
|
||||
OpenAI,
|
||||
DeepSeek,
|
||||
}
|
||||
pub mod groq;
|
||||
pub mod openai;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ pub mod types;
|
|||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::providers::openai::types::{OpenAIRequest, OpenAIResponse};
|
||||
use crate::providers::openai::types::{ChatCompletionsRequest, ChatCompletionsResponse};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OpenAIError {
|
||||
|
|
@ -12,14 +12,14 @@ pub enum OpenAIError {
|
|||
|
||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
||||
|
||||
impl TryFrom<&[u8]> for OpenAIRequest {
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for OpenAIResponse {
|
||||
impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
||||
type Error = OpenAIError;
|
||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,124 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use crate::providers::common_types::{ChatRequestBase, ChatResponseBase};
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIRequest {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatRequestBase,
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum MultiPartContentType {
|
||||
#[serde(rename = "text")]
|
||||
Text,
|
||||
#[serde(rename = "image_url")]
|
||||
ImageUrl,
|
||||
}
|
||||
#[derive(Debug, Default, Clone)]
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct MultiPartContent {
|
||||
pub text: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
pub content_type: MultiPartContentType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum ContentType {
|
||||
Text(String),
|
||||
MultiPart(Vec<MultiPartContent>),
|
||||
}
|
||||
|
||||
impl Display for ContentType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ContentType::Text(text) => write!(f, "{}", text),
|
||||
ContentType::MultiPart(multi_part) => {
|
||||
let text_parts: Vec<String> = multi_part
|
||||
.iter()
|
||||
.filter_map(|part| {
|
||||
if part.content_type == MultiPartContentType::Text {
|
||||
part.text.clone()
|
||||
} else if part.content_type == MultiPartContentType::ImageUrl {
|
||||
// skip image URLs or their data in text representation
|
||||
None
|
||||
} else {
|
||||
panic!("Unsupported content type: {:?}", part.content_type);
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let combined_text = text_parts.join("\n");
|
||||
write!(f, "{}", combined_text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: Option<ContentType>,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub n: Option<u32>,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub stream: Option<bool>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub presence_penalty: Option<f32>,
|
||||
pub frequency_penalty: Option<f32>,
|
||||
}
|
||||
|
||||
impl Default for ChatCompletionsRequest {
|
||||
fn default() -> Self {
|
||||
ChatCompletionsRequest {
|
||||
model: String::new(),
|
||||
messages: Vec::new(),
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
n: None,
|
||||
max_tokens: None,
|
||||
stream: None,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: Message,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OpenAIRequestBuilder {
|
||||
model: Option<String>,
|
||||
messages: Option<Vec<crate::providers::common_types::Message>>,
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
temperature: Option<f32>,
|
||||
top_p: Option<f32>,
|
||||
n: Option<u32>,
|
||||
|
|
@ -21,18 +130,19 @@ pub struct OpenAIRequestBuilder {
|
|||
}
|
||||
|
||||
impl OpenAIRequestBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = Some(model.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn messages(mut self, messages: Vec<crate::providers::common_types::Message>) -> Self {
|
||||
self.messages = Some(messages);
|
||||
self
|
||||
pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
messages,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
n: None,
|
||||
max_tokens: None,
|
||||
stream: None,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||
|
|
@ -75,10 +185,9 @@ impl OpenAIRequestBuilder {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<OpenAIRequest, &'static str> {
|
||||
let model = self.model.ok_or("model is required")?;
|
||||
let base = crate::providers::common_types::ChatRequestBase {
|
||||
model,
|
||||
pub fn build(self) -> Result<ChatCompletionsRequest, &'static str> {
|
||||
let request = ChatCompletionsRequest {
|
||||
model: self.model,
|
||||
messages: self.messages,
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
|
|
@ -89,20 +198,83 @@ impl OpenAIRequestBuilder {
|
|||
presence_penalty: self.presence_penalty,
|
||||
frequency_penalty: self.frequency_penalty,
|
||||
};
|
||||
Ok(OpenAIRequest { base })
|
||||
Ok(request)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIRequest {
|
||||
pub fn builder() -> OpenAIRequestBuilder {
|
||||
OpenAIRequestBuilder::new()
|
||||
impl ChatCompletionsRequest {
|
||||
pub fn builder(model: impl Into<String>, messages: Vec<Message>) -> OpenAIRequestBuilder {
|
||||
OpenAIRequestBuilder::new(model, messages)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIResponse {
|
||||
#[serde(flatten)]
|
||||
pub base: ChatResponseBase,
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
pub use crate::providers::common_types::{Message, Choice, Usage};
|
||||
#[test]
|
||||
fn test_content_type_display() {
|
||||
let text_content = ContentType::Text("Hello, world!".to_string());
|
||||
assert_eq!(text_content.to_string(), "Hello, world!");
|
||||
|
||||
let multi_part_content = ContentType::MultiPart(vec![
|
||||
MultiPartContent {
|
||||
text: Some("This is a text part.".to_string()),
|
||||
content_type: MultiPartContentType::Text,
|
||||
},
|
||||
MultiPartContent {
|
||||
text: Some("https://example.com/image.png".to_string()),
|
||||
content_type: MultiPartContentType::ImageUrl,
|
||||
},
|
||||
]);
|
||||
assert_eq!(multi_part_content.to_string(), "This is a text part.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_request_text_type_array() {
|
||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What city do you want to know the weather for?"
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hello world"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
|
||||
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
|
||||
if let Some(ContentType::MultiPart(multi_part_content)) =
|
||||
chat_completions_request.messages[0].content.as_ref()
|
||||
{
|
||||
assert_eq!(multi_part_content.len(), 2);
|
||||
assert_eq!(
|
||||
multi_part_content[0].content_type,
|
||||
MultiPartContentType::Text
|
||||
);
|
||||
assert_eq!(
|
||||
multi_part_content[0].text,
|
||||
Some("What city do you want to know the weather for?".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
multi_part_content[1].content_type,
|
||||
MultiPartContentType::Text
|
||||
);
|
||||
assert_eq!(multi_part_content[1].text, Some("hello world".to_string()));
|
||||
} else {
|
||||
panic!("Expected MultiPartContent");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue