mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
omni llm
This commit is contained in:
parent
9f59943041
commit
21a5bdd561
8 changed files with 239 additions and 47 deletions
|
|
@ -56,7 +56,7 @@ def docker_start_archgw_detached(
|
|||
volume_mappings = [
|
||||
f"{logs_path_abs}:/var/log:rw",
|
||||
f"{arch_config_file}:/app/arch_config.yaml:ro",
|
||||
# "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
|
||||
"/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
|
||||
]
|
||||
volume_mappings_args = [
|
||||
item for volume in volume_mappings for item in ("-v", volume)
|
||||
|
|
|
|||
56
crates/Cargo.lock
generated
56
crates/Cargo.lock
generated
|
|
@ -64,9 +64,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.90"
|
||||
version = "1.0.97"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95"
|
||||
checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f"
|
||||
|
||||
[[package]]
|
||||
name = "arbitrary"
|
||||
|
|
@ -82,7 +82,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -487,7 +487,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -919,7 +919,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1061,6 +1061,7 @@ dependencies = [
|
|||
"http",
|
||||
"log",
|
||||
"md5",
|
||||
"omnillm",
|
||||
"proxy-wasm",
|
||||
"proxy-wasm-test-framework",
|
||||
"rand",
|
||||
|
|
@ -1160,6 +1161,15 @@ dependencies = [
|
|||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "omnillm"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.20.2"
|
||||
|
|
@ -1529,29 +1539,29 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.210"
|
||||
version = "1.0.219"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a"
|
||||
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.210"
|
||||
version = "1.0.219"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
|
||||
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.130"
|
||||
version = "1.0.140"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "610f75ff4a8e3cb29b85da56eabdd1bff5b06739059a4b8e2967fef32e5d9944"
|
||||
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
|
|
@ -1603,7 +1613,7 @@ checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1711,9 +1721,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.79"
|
||||
version = "2.0.87"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
|
||||
checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
|
@ -1728,7 +1738,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1772,7 +1782,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2071,7 +2081,7 @@ dependencies = [
|
|||
"anyhow",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
"wasmtime-component-util",
|
||||
"wasmtime-wit-bindgen",
|
||||
"wit-parser",
|
||||
|
|
@ -2201,7 +2211,7 @@ checksum = "a2bde986038b819bc43a21fef0610aeb47aabfe3ea09ca3533a7b81023b84ec6"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2450,7 +2460,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
|
|
@ -2472,7 +2482,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2492,7 +2502,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
|
|
@ -2515,7 +2525,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6"
|
|||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
"syn 2.0.87",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["llm_gateway", "prompt_gateway", "common"]
|
||||
members = ["llm_gateway", "prompt_gateway", "common", "omnillm"]
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ impl Serialize for FunctionParameters {
|
|||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
// select all requried parameters
|
||||
// select all required parameters
|
||||
let required: Vec<&String> = self
|
||||
.properties
|
||||
.iter()
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ rand = "0.8.5"
|
|||
thiserror = "1.0.64"
|
||||
derivative = "2.2.0"
|
||||
sha2 = "0.10.8"
|
||||
omnillm = { version = "0.1.0", path = "../omnillm" }
|
||||
|
||||
[dev-dependencies]
|
||||
proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" }
|
||||
|
|
|
|||
|
|
@ -16,12 +16,14 @@ use common::tracing::{Event, Span, TraceData, Traceparent};
|
|||
use common::{ratelimit, routing, tokenizer};
|
||||
use http::StatusCode;
|
||||
use log::{debug, trace, warn};
|
||||
use omnillm::{ChatRequest, LlmProvider as LlmProviderTrait, LlmProviders as OmniLlmProviders};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
|
|
@ -40,9 +42,10 @@ pub struct StreamContext {
|
|||
ttft_time: Option<u128>,
|
||||
traceparent: Option<String>,
|
||||
request_body_sent_time: Option<u128>,
|
||||
user_message: Option<Message>,
|
||||
user_message: Option<omnillm::Message>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
omni_llm_providers: Rc<OmniLlmProviders>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -71,6 +74,7 @@ impl StreamContext {
|
|||
user_message: None,
|
||||
traces_queue,
|
||||
request_body_sent_time: None,
|
||||
omni_llm_providers: Rc::new(OmniLlmProviders::new()),
|
||||
}
|
||||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
|
|
@ -271,18 +275,17 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: ChatCompletionsRequest =
|
||||
match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!("body str: {}", String::from_utf8_lossy(&body_bytes));
|
||||
self.send_server_error(
|
||||
ServerError::Deserialization(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
let mut deserialized_body: ChatRequest = match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!("body str: {}", String::from_utf8_lossy(&body_bytes));
|
||||
self.send_server_error(
|
||||
ServerError::Deserialization(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
// remove metadata from the request body
|
||||
//TODO: move this to prompt gateway
|
||||
|
|
@ -295,7 +298,7 @@ impl HttpContext for StreamContext {
|
|||
self.user_message = deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "user")
|
||||
.filter(|m| m.role == omnillm::Role::User)
|
||||
.last()
|
||||
.cloned();
|
||||
|
||||
|
|
@ -336,16 +339,32 @@ impl HttpContext for StreamContext {
|
|||
model_name,
|
||||
);
|
||||
|
||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||
let omn_provider_type =
|
||||
omnillm::Provider::from_str(&self.llm_provider().provider_interface.to_string())
|
||||
.unwrap();
|
||||
|
||||
trace!("request body: {}", chat_completion_request_str);
|
||||
let chat_request_bytes = self
|
||||
.omni_llm_providers
|
||||
.as_ref()
|
||||
.providers
|
||||
.get(&omn_provider_type)
|
||||
.unwrap()
|
||||
.translate_request(&deserialized_body)
|
||||
.unwrap();
|
||||
|
||||
if deserialized_body.stream {
|
||||
trace!(
|
||||
"request body str: {}",
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
if deserialized_body.stream.unwrap_or_default() {
|
||||
self.streaming_response = true;
|
||||
}
|
||||
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
|
||||
deserialized_body.stream_options = Some(StreamOptions {
|
||||
include_usage: true,
|
||||
if deserialized_body.stream.unwrap_or_default()
|
||||
&& deserialized_body.stream_options.is_none()
|
||||
{
|
||||
deserialized_body.stream_options = Some(omnillm::StreamOptions {
|
||||
include_usage: Some(true),
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -367,7 +386,7 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
|
||||
self.set_http_request_body(0, body_size, &chat_request_bytes);
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
|
|
|||
9
crates/omnillm/Cargo.toml
Normal file
9
crates/omnillm/Cargo.toml
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
[package]
|
||||
name = "omnillm"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.97"
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
153
crates/omnillm/src/lib.rs
Normal file
153
crates/omnillm/src/lib.rs
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
use std::{collections::HashMap, str::FromStr};
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Hash, PartialEq, Eq)]
|
||||
pub enum Provider {
|
||||
OpenAI,
|
||||
Mistral,
|
||||
}
|
||||
|
||||
impl FromStr for Provider {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"openai" => Ok(Provider::OpenAI),
|
||||
"mistral" => Ok(Provider::Mistral),
|
||||
_ => Err(anyhow::anyhow!(format!("Invalid provider: {}", s))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub include_usage: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, PartialEq, Clone)]
|
||||
pub enum Role {
|
||||
#[serde(rename = "system")]
|
||||
System,
|
||||
#[serde(rename = "user")]
|
||||
User,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: i64,
|
||||
pub completion_tokens: i64,
|
||||
pub total: i64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ChatResponse {
|
||||
pub messages: Vec<Message>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
pub trait LlmProvider {
|
||||
fn translate_request(&self, request: &ChatRequest) -> Result<Vec<u8>>;
|
||||
fn translate_response(&self, response: &Vec<u8>) -> Result<ChatResponse>;
|
||||
}
|
||||
|
||||
pub struct LlmProviders {
|
||||
pub providers: HashMap<Provider, Box<dyn LlmProvider>>,
|
||||
}
|
||||
|
||||
impl LlmProviders {
|
||||
pub fn new() -> LlmProviders {
|
||||
LlmProviders {
|
||||
providers: HashMap::from([
|
||||
(Provider::OpenAI, Box::new(OpenAI) as Box<dyn LlmProvider>),
|
||||
(Provider::Mistral, Box::new(Mistral) as Box<dyn LlmProvider>),
|
||||
]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpenAI;
|
||||
impl LlmProvider for OpenAI {
|
||||
fn translate_request(&self, request: &ChatRequest) -> Result<Vec<u8>> {
|
||||
serde_json::to_string(request)
|
||||
.map(|s| s.into_bytes())
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn translate_response(&self, response: &Vec<u8>) -> Result<ChatResponse> {
|
||||
serde_json::from_slice(response).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mistral;
|
||||
impl LlmProvider for Mistral {
|
||||
fn translate_request(&self, request: &ChatRequest) -> Result<Vec<u8>> {
|
||||
serde_json::to_string(request)
|
||||
.map(|s| s.into_bytes())
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn translate_response(&self, response: &Vec<u8>) -> Result<ChatResponse> {
|
||||
serde_json::from_slice(response).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
#[test]
|
||||
fn test_translate_request_openai() {
|
||||
use super::{ChatRequest, LlmProvider, Message, OpenAI, Role};
|
||||
let request = ChatRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: Some("You are a helpful assistant.".to_string()),
|
||||
model: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: Some("I need help with my computer.".to_string()),
|
||||
model: None,
|
||||
},
|
||||
],
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
stream: None,
|
||||
stream_options: None,
|
||||
};
|
||||
let openai = OpenAI;
|
||||
let result = openai.translate_request(&request).unwrap();
|
||||
let expected = r#"{"model":"gpt-3.5-turbo","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"I need help with my computer."}]}"#;
|
||||
assert_eq!(String::from_utf8(result).unwrap(), expected);
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue