Add workflow logic for weather forecast demo (#24)

This commit is contained in:
Adil Hafeez 2024-07-30 16:23:23 -07:00 committed by GitHub
parent 7ef68eccfb
commit 33f9dd22e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 1902 additions and 459 deletions

View file

@@ -23,9 +23,53 @@ pub struct StoreVectorEmbeddingsRequest {
}
/// Tags an outbound HTTP callout with the request that produced it, so the
/// matching `on_http_call_response` handler can route the response.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum CallContext {
    /// Callout to the embedding server for a `CreateEmbeddingRequest`.
    EmbeddingRequest(EmbeddingRequest),
    /// Callout that upserts vector points into the vector store.
    StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
    /// Callout that creates a vector collection; payload is the collection name.
    CreateVectorCollection(String),
}
// https://api.qdrant.tech/master/api-reference/search/points
/// Request body for Qdrant's points-search API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsRequest {
    /// Query embedding to search against the collection.
    pub vector: Vec<f64>,
    /// Maximum number of results to return.
    pub limit: i32,
    /// When true, Qdrant includes each point's payload in the response.
    pub with_payload: bool,
}
/// One scored match returned by a Qdrant points search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointResult {
    pub id: String,
    pub version: i32,
    /// Similarity score; callers compare this against a match threshold.
    pub score: f64,
    /// Point payload stored at index time (e.g. the serialized prompt target).
    pub payload: HashMap<String, String>,
}
/// Top-level envelope of Qdrant's search response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsResponse {
    pub result: Vec<SearchPointResult>,
    /// Status string as reported by Qdrant.
    pub status: String,
    /// Server-side processing time — presumably seconds; confirm against Qdrant docs.
    pub time: f64,
}
/// Request to the named-entity-recognition (NER) service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERRequest {
    /// Text to extract entities from (the user's message).
    pub input: String,
    /// Entity labels to look for, derived from the prompt target's entities.
    pub labels: Vec<String>,
    /// NER model identifier, e.g. `DEFAULT_NER_MODEL`.
    pub model: String,
}
/// A single entity extracted by the NER service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// The matched span of text.
    pub text: String,
    /// The label this span was classified as.
    pub label: String,
    /// Confidence score; callers filter this against a threshold.
    pub score: f64,
}
/// Response envelope from the NER service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERResponse {
    /// All extracted entities, including low-confidence ones.
    pub data: Vec<Entity>,
    /// Model that produced the extraction.
    pub model: String,
}
pub mod open_ai {
@@ -41,6 +85,6 @@ pub mod open_ai {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: String,
pub content: Option<String>,
}
}

View file

@@ -28,7 +28,7 @@ pub struct PromptConfig {
pub timeout_ms: u64,
pub embedding_provider: EmbeddingProviver,
pub llm_providers: Vec<LlmProvider>,
pub system_prompt: String,
pub system_prompt: Option<String>,
pub prompt_targets: Vec<PromptTarget>,
}
@@ -49,6 +49,29 @@ pub struct LlmProvider {
pub model: String,
}
/// Upstream endpoint a prompt target resolves to.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Endpoint {
    /// Envoy cluster name the callout is dispatched to.
    pub cluster: String,
    /// Request path; callers default to "/" when absent.
    pub path: Option<String>,
    /// HTTP method; callers default to POST when absent.
    pub method: Option<String>,
}
/// Full description of one entity a prompt target expects.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct EntityDetail {
    pub name: String,
    /// Whether the entity must be present; treated as false when absent.
    pub required: Option<bool>,
    pub description: Option<String>,
}
/// Entities may be configured either as a bare list of names or as a list of
/// detailed descriptors; `untagged` lets serde accept both YAML shapes.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EntityType {
    /// Just entity names; each is treated as required by `get_entity_details`.
    Vec(Vec<String>),
    /// Detailed per-entity configuration.
    Struct(Vec<EntityDetail>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct PromptTarget {
@@ -56,7 +79,9 @@ pub struct PromptTarget {
pub prompt_type: String,
pub name: String,
pub few_shot_examples: Vec<String>,
pub endpoint: String,
pub entities: Option<EntityType>,
pub endpoint: Option<Endpoint>,
pub system_prompt: Option<String>,
}
#[cfg(test)]
@@ -88,13 +113,23 @@ katanemo-prompt-config:
name: weather-forecast
few-shot-examples:
- what is the weather in New York?
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
cache-response: true
cache-response-settings:
- cache-ttl-secs: 3600 # cache expiry in seconds
- cache-max-size: 1000 # in number of items
- cache-eviction-strategy: LRU
endpoint:
cluster: weatherhost
path: /weather
entities:
- name: location
required: true
description: "The location for which the weather is requested"
- type: context-resolver
name: weather-forecast-2
few-shot-examples:
- what is the weather in New York?
endpoint:
cluster: weatherhost
path: /weather
entities:
- city
"#;
#[test]

View file

@@ -1 +1,7 @@
/// Default model used to embed few-shot examples and user prompts.
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
/// Qdrant collection that stores prompt-target embeddings.
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
/// Default model for named-entity recognition.
pub const DEFAULT_NER_MODEL: &str = "urchade/gliner_large-v2.1";
/// Minimum similarity score for a prompt target to be considered a match.
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
/// Minimum confidence for an extracted entity to be kept.
pub const DEFAULT_NER_THRESHOLD: f64 = 0.6;
/// OpenAI chat message roles.
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";

View file

@@ -0,0 +1,289 @@
use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget;
use log::info;
use md5::Digest;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::to_string;
use stats::RecordingMetric;
use std::collections::HashMap;
use std::time::Duration;
use consts::DEFAULT_EMBEDDING_MODEL;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use crate::common_types;
use crate::configuration;
use crate::consts;
use crate::stats;
use crate::stream_context::StreamContext;
/// Container for the Wasm filter's Envoy-backed metrics.
#[derive(Copy, Clone)]
struct WasmMetrics {
    // Gauge tracking the number of in-flight HTTP callouts.
    active_http_calls: stats::Gauge,
}

impl WasmMetrics {
    /// Builds the metric set, registering each gauge by its name.
    fn new() -> WasmMetrics {
        let gauge_name = String::from("active_http_calls");
        let active_http_calls = stats::Gauge::new(gauge_name);
        WasmMetrics { active_http_calls }
    }
}
/// Root context for the Wasm filter: owns the plugin configuration, metrics,
/// and the bookkeeping for outstanding HTTP callouts.
pub struct FilterContext {
    metrics: WasmMetrics,
    // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
    callouts: HashMap<u32, common_types::CallContext>,
    // Parsed plugin configuration; None until on_configure runs.
    config: Option<configuration::Configuration>,
}
impl FilterContext {
    /// Creates an empty filter context; configuration arrives later via
    /// `on_configure`.
    pub fn new() -> FilterContext {
        FilterContext {
            callouts: HashMap::new(),
            config: None,
            metrics: WasmMetrics::new(),
        }
    }

    /// For every few-shot example of every configured prompt target, dispatches
    /// a callout to the embedding server so the example can later be stored in
    /// the vector store.
    ///
    /// Panics if configuration is absent, serialization fails, a dispatch
    /// fails, or a token id collides with an outstanding callout.
    fn process_prompt_targets(&mut self) {
        for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
            for few_shot_example in &prompt_target.few_shot_examples {
                let embeddings_input = CreateEmbeddingRequest {
                    input: Box::new(CreateEmbeddingRequestInput::String(
                        few_shot_example.to_string(),
                    )),
                    model: String::from(DEFAULT_EMBEDDING_MODEL),
                    encoding_format: None,
                    dimensions: None,
                    user: None,
                };
                // TODO: Handle potential errors
                let json_data: String = to_string(&embeddings_input).unwrap();
                let token_id = match self.dispatch_http_call(
                    "embeddingserver",
                    vec![
                        (":method", "POST"),
                        (":path", "/embeddings"),
                        (":authority", "embeddingserver"),
                        ("content-type", "application/json"),
                    ],
                    Some(json_data.as_bytes()),
                    vec![],
                    Duration::from_secs(5),
                ) {
                    Ok(token_id) => token_id,
                    Err(e) => {
                        panic!("Error dispatching HTTP call: {:?}", e);
                    }
                };
                // Remember which request this token belongs to so the response
                // handler can route it.
                let embedding_request = EmbeddingRequest {
                    create_embedding_request: embeddings_input,
                    prompt_target: prompt_target.clone(),
                };
                if self
                    .callouts
                    .insert(token_id, {
                        CallContext::EmbeddingRequest(embedding_request)
                    })
                    .is_some()
                {
                    panic!("duplicate token_id")
                }
                self.metrics
                    .active_http_calls
                    .record(self.callouts.len().try_into().unwrap());
            }
        }
    }

    /// Handles the embedding server's response: builds vector-store points
    /// from the returned embedding and upserts them into Qdrant.
    fn embedding_request_handler(
        &mut self,
        body_size: usize,
        create_embedding_request: CreateEmbeddingRequest,
        prompt_target: PromptTarget,
    ) {
        if let Some(body) = self.get_http_call_response_body(0, body_size) {
            if !body.is_empty() {
                let embedding_response: CreateEmbeddingResponse =
                    serde_json::from_slice(&body).unwrap();
                // The prompt target is stored alongside the vector so a future
                // search hit can recover it from the point payload.
                let mut payload: HashMap<String, String> = HashMap::new();
                payload.insert(
                    "prompt-target".to_string(),
                    to_string(&prompt_target).unwrap(),
                );
                // Point id is the md5 of the input text, so re-processing the
                // same example overwrites rather than duplicates.
                let id: Option<Digest>;
                match *create_embedding_request.input {
                    CreateEmbeddingRequestInput::String(input) => {
                        id = Some(md5::compute(&input));
                        payload.insert("input".to_string(), input);
                    }
                    CreateEmbeddingRequestInput::Array(_) => todo!(),
                }
                let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
                    points: vec![common_types::VectorPoint {
                        id: format!("{:x}", id.unwrap()),
                        payload,
                        vector: embedding_response.data[0].embedding.clone(),
                    }],
                };
                let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
                let token_id = match self.dispatch_http_call(
                    "qdrant",
                    vec![
                        (":method", "PUT"),
                        (":path", "/collections/prompt_vector_store/points"),
                        (":authority", "qdrant"),
                        ("content-type", "application/json"),
                    ],
                    Some(json_data.as_bytes()),
                    vec![],
                    Duration::from_secs(5),
                ) {
                    Ok(token_id) => token_id,
                    Err(e) => {
                        panic!("Error dispatching HTTP call: {:?}", e);
                    }
                };
                if self
                    .callouts
                    .insert(
                        token_id,
                        CallContext::StoreVectorEmbeddings(create_vector_store_points),
                    )
                    .is_some()
                {
                    panic!("duplicate token_id")
                }
                self.metrics
                    .active_http_calls
                    .record(self.callouts.len().try_into().unwrap());
            }
        }
    }

    /// Logs the vector-store upsert response; no further action is taken.
    fn create_vector_store_points_handler(&self, body_size: usize) {
        if let Some(body) = self.get_http_call_response_body(0, body_size) {
            if !body.is_empty() {
                info!(
                    "response body: len {:?}",
                    String::from_utf8(body).unwrap().len()
                );
            }
        }
    }

    //TODO: run once per envoy instance, right now it runs once per worker
    /// Creates the Qdrant collection used as the prompt vector store
    /// (1024-dim vectors, cosine distance).
    fn init_vector_store(&mut self) {
        let token_id = match self.dispatch_http_call(
            "qdrant",
            vec![
                (":method", "PUT"),
                (":path", "/collections/prompt_vector_store"),
                (":authority", "qdrant"),
                ("content-type", "application/json"),
            ],
            Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"),
            vec![],
            Duration::from_secs(5),
        ) {
            Ok(token_id) => token_id,
            Err(e) => {
                panic!("Error dispatching HTTP call for init-vector-store: {:?}", e);
            }
        };
        if self
            .callouts
            .insert(
                token_id,
                CallContext::CreateVectorCollection("prompt_vector_store".to_string()),
            )
            .is_some()
        {
            panic!("duplicate token_id")
        }
        // NOTE(review): unlike the other dispatch sites, this one does not
        // record the active_http_calls gauge — confirm whether intentional.
        // self.metrics
        //     .active_http_calls
        //     .record(self.callouts.len().try_into().unwrap());
    }
}
impl Context for FilterContext {
    /// Routes each HTTP callout response to the handler matching the
    /// `CallContext` recorded when the callout was dispatched.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        // A token we never recorded indicates a bug — fail fast.
        let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
        self.metrics
            .active_http_calls
            .record(self.callouts.len().try_into().unwrap());
        match callout_data {
            common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
                create_embedding_request,
                prompt_target,
            }) => {
                self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
            }
            common_types::CallContext::StoreVectorEmbeddings(_) => {
                self.create_vector_store_points_handler(body_size)
            }
            common_types::CallContext::CreateVectorCollection(_) => {
                // Only the :status header is of interest for collection creation.
                let mut http_status_code = "Nil".to_string();
                self.get_http_call_response_headers()
                    .iter()
                    .for_each(|(k, v)| {
                        if k == ":status" {
                            http_status_code.clone_from(v);
                        }
                    });
                info!("CreateVectorCollection response: {}", http_status_code);
            }
        }
    }
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
    /// Parses the plugin's YAML configuration into `self.config`.
    /// Panics (unwrap) on malformed configuration — fail fast at startup.
    fn on_configure(&mut self, _: usize) -> bool {
        if let Some(config_bytes) = self.get_plugin_configuration() {
            self.config = serde_yaml::from_slice(&config_bytes).unwrap();
        }
        true
    }

    /// One `StreamContext` per HTTP stream handles per-request logic.
    fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
        Some(Box::new(StreamContext {
            host_header: None,
            callouts: HashMap::new(),
        }))
    }

    fn get_type(&self) -> Option<ContextType> {
        Some(ContextType::HttpContext)
    }

    fn on_vm_start(&mut self, _: usize) -> bool {
        // Schedule a tick so one-time initialization runs shortly after startup.
        self.set_tick_period(Duration::from_secs(1));
        true
    }

    fn on_tick(&mut self) {
        // initialize vector store
        self.init_vector_store();
        self.process_prompt_targets();
        // A zero period disables further ticks, making this effectively one-shot.
        self.set_tick_period(Duration::from_secs(0));
    }
}

View file

@@ -1,22 +1,13 @@
use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget;
use http::StatusCode;
use log::error;
use log::info;
use md5::Digest;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use stats::{Gauge, RecordingMetric};
use std::collections::HashMap;
use std::time::Duration;
mod common_types;
mod configuration;
mod consts;
mod filter_context;
mod stats;
mod stream_context;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
@@ -24,397 +15,3 @@ proxy_wasm::main! {{
Box::new(FilterContext::new())
});
}}
struct StreamContext {
host_header: Option<String>,
}
impl StreamContext {
fn save_host_header(&mut self) {
// Save the host header to be used by filter logic later on.
self.host_header = self.get_http_request_header(":host");
}
fn delete_content_length_header(&mut self) {
// Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
// Server's generally throw away requests whose body length do not match the Content-Length header.
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
// manipulate the body in benign ways e.g., compression.
self.set_http_request_header("content-length", None);
}
fn modify_path_header(&mut self) {
match self.get_http_request_header(":path") {
// The gateway can start gathering information necessary for routing. For now change the path to an
// OpenAI API path.
Some(path) if path == "/llmrouting" => {
self.set_http_request_header(":path", Some("/v1/chat/completions"));
}
// Otherwise let the filter continue.
_ => (),
}
}
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
self.save_host_header();
self.delete_content_length_header();
self.modify_path_header();
Action::Continue
}
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
// Let the client send the gateway all the data before sending to the LLM_provider.
// TODO: consider a streaming API.
if !end_of_stream {
return Action::Pause;
}
if body_size == 0 {
return Action::Continue;
}
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: common_types::open_ai::ChatCompletions =
match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => {
self.send_http_response(
StatusCode::BAD_REQUEST.as_u16().into(),
vec![],
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
);
return Action::Pause;
}
},
None => {
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
None,
);
error!(
"Failed to obtain body bytes even though body_size is {}",
body_size
);
return Action::Pause;
}
};
// Modify JSON payload
deserialized_body.model = String::from("gpt-3.5-turbo");
match serde_json::to_string(&deserialized_body) {
Ok(json_string) => {
self.set_http_request_body(0, body_size, &json_string.into_bytes());
Action::Continue
}
Err(error) => {
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
None,
);
error!("Failed to serialize body: {}", error);
Action::Pause
}
}
}
}
impl Context for StreamContext {}
#[derive(Copy, Clone)]
struct WasmMetrics {
active_http_calls: Gauge,
}
impl WasmMetrics {
fn new() -> WasmMetrics {
WasmMetrics {
active_http_calls: Gauge::new(String::from("active_http_calls")),
}
}
}
struct FilterContext {
metrics: WasmMetrics,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, common_types::CallContext>,
config: Option<configuration::Configuration>,
}
impl FilterContext {
fn new() -> FilterContext {
FilterContext {
callouts: HashMap::new(),
config: None,
metrics: WasmMetrics::new(),
}
}
fn process_prompt_targets(&mut self) {
for prompt_target in &self
.config
.as_ref()
.expect("Gateway configuration cannot be non-existent")
.prompt_config
.prompt_targets
{
for few_shot_example in &prompt_target.few_shot_examples {
info!("few_shot_example: {:?}", few_shot_example);
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(
few_shot_example.to_string(),
)),
model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let json_data: String = match serde_json::to_string(&embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing embeddings input: {}", error);
}
};
let token_id = match self.dispatch_http_call(
"embeddingserver",
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching embedding server HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
let embedding_request = EmbeddingRequest {
create_embedding_request: embeddings_input,
prompt_target: prompt_target.clone(),
};
if self
.callouts
.insert(token_id, {
CallContext::EmbeddingRequest(embedding_request)
})
.is_some()
{
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
}
fn embedding_request_handler(
&mut self,
body_size: usize,
create_embedding_request: CreateEmbeddingRequest,
prompt_target: PromptTarget,
) {
info!("response received for CreateEmbeddingRequest");
let body = match self.get_http_call_response_body(0, body_size) {
Some(body) => body,
None => {
return;
}
};
if body.is_empty() {
return;
}
let (json_data, create_vector_store_points) =
match build_qdrant_data(&body, create_embedding_request, &prompt_target) {
Ok(tuple) => tuple,
Err(error) => {
panic!(
"Error building qdrant data from embedding response {}",
error
);
}
};
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store/points"),
(":authority", "qdrant"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching qdrant HTTP call: {:?}", e);
}
};
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id={} in qdrant requests", token_id)
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
// TODO: @adilhafeez implement.
fn create_vector_store_points_handler(&self, body_size: usize) {
info!("response received for CreateVectorStorePoints");
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
info!("response body: {:?}", String::from_utf8(body).unwrap());
}
}
}
}
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
token_id: u32,
_num_headers: usize,
body_size: usize,
_num_trailers: usize,
) {
info!("on_http_call_response: token_id = {}", token_id);
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
match callout_data {
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
create_embedding_request,
prompt_target,
}) => {
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
}
common_types::CallContext::StoreVectorEmbeddings(_) => {
self.create_vector_store_points_handler(body_size)
}
}
}
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = match serde_yaml::from_slice(&config_bytes) {
Ok(config) => config,
Err(error) => {
panic!("Failed to deserialize plugin configuration: {}", error);
}
};
info!("on_configure: plugin configuration loaded");
}
true
}
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
Some(Box::new(StreamContext { host_header: None }))
}
fn get_type(&self) -> Option<ContextType> {
Some(ContextType::HttpContext)
}
fn on_vm_start(&mut self, _: usize) -> bool {
info!("on_vm_start: setting up tick timeout");
self.set_tick_period(Duration::from_secs(1));
true
}
fn on_tick(&mut self) {
info!("on_tick: starting to process prompt targets");
self.process_prompt_targets();
self.set_tick_period(Duration::from_secs(0));
}
}
/// Builds the Qdrant upsert request (and its JSON serialization) from an
/// embedding-server response.
///
/// The vector point's id is the md5 of the embedded input text so that
/// re-processing the same example overwrites the existing point, and the
/// serialized prompt target travels in the point payload for later retrieval.
///
/// # Errors
/// Returns any `serde_json` error from deserializing the embedding response or
/// serializing the outgoing request. (Previously the deserialize failure
/// panicked even though the signature promises a `Result` — it is now
/// propagated with `?` so callers can decide policy.)
fn build_qdrant_data(
    embedding_data: &[u8],
    create_embedding_request: CreateEmbeddingRequest,
    prompt_target: &PromptTarget,
) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> {
    let embedding_response: CreateEmbeddingResponse = serde_json::from_slice(embedding_data)?;
    info!(
        "embedding_response model: {}, vector len: {}",
        embedding_response.model,
        embedding_response.data[0].embedding.len()
    );
    let mut payload: HashMap<String, String> = HashMap::new();
    payload.insert(
        "prompt-target".to_string(),
        serde_json::to_string(&prompt_target)?,
    );
    let id: Option<Digest>;
    match *create_embedding_request.input {
        CreateEmbeddingRequestInput::String(ref input) => {
            id = Some(md5::compute(input));
            payload.insert("input".to_string(), input.clone());
        }
        CreateEmbeddingRequestInput::Array(_) => todo!(),
    }
    let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
        points: vec![common_types::VectorPoint {
            id: format!("{:x}", id.unwrap()),
            payload,
            vector: embedding_response.data[0].embedding.clone(),
        }],
    };
    let json_data = serde_json::to_string(&create_vector_store_points)?;
    info!(
        "create_vector_store_points: points length: {}",
        embedding_response.data[0].embedding.len()
    );
    Ok((json_data, create_vector_store_points))
}

View file

@@ -0,0 +1,520 @@
use http::StatusCode;
use log::error;
use log::info;
use log::warn;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use std::collections::HashMap;
use std::time::Duration;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use consts::{
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
};
use crate::common_types;
use crate::common_types::open_ai::Message;
use crate::common_types::SearchPointsResponse;
use crate::configuration::EntityDetail;
use crate::configuration::EntityType;
use crate::configuration::PromptTarget;
use crate::consts;
/// Stage of the multi-hop callout pipeline a response belongs to.
enum RequestType {
    /// Waiting on the embedding server for the user prompt's embedding.
    GetEmbedding,
    /// Waiting on Qdrant for the nearest prompt targets.
    SearchPoints,
    /// Waiting on the NER service for entity extraction.
    Ner,
    /// Waiting on the prompt target's upstream endpoint.
    ContextResolver,
}
/// Per-callout state threaded through the pipeline stages.
pub struct CallContext {
    // Which stage's response this callout will deliver.
    request_type: RequestType,
    // The last user message extracted from the chat request.
    user_message: String,
    // Matched prompt target; set after the vector-search stage.
    prompt_target: Option<PromptTarget>,
    // The original chat-completions body, augmented before forwarding upstream.
    request_body: common_types::open_ai::ChatCompletions,
}
/// Per-HTTP-stream filter state.
pub struct StreamContext {
    pub host_header: Option<String>,
    // token_id -> in-flight callout state, used to match responses to stages.
    pub callouts: HashMap<u32, CallContext>,
}
impl StreamContext {
    fn save_host_header(&mut self) {
        // Save the host header to be used by filter logic later on.
        self.host_header = self.get_http_request_header(":host");
    }

    fn delete_content_length_header(&mut self) {
        // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
        // Servers generally throw away requests whose body length does not match the Content-Length header.
        // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
        // manipulate the body in benign ways e.g., compression.
        self.set_http_request_header("content-length", None);
    }

    fn modify_path_header(&mut self) {
        match self.get_http_request_header(":path") {
            // The gateway can start gathering information necessary for routing. For now change the path to an
            // OpenAI API path.
            Some(path) if path == "/llmrouting" => {
                self.set_http_request_header(":path", Some("/v1/chat/completions"));
            }
            // Otherwise let the filter continue.
            _ => (),
        }
    }

    /// Stage 1 -> 2: with the user message's embedding in hand, search the
    /// vector store for the nearest prompt targets.
    fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
        let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
            Ok(embedding_response) => embedding_response,
            Err(e) => {
                // Degrade gracefully: forward the request unmodified.
                warn!("Error deserializing embedding response: {:?}", e);
                self.resume_http_request();
                return;
            }
        };
        let search_points_request = common_types::SearchPointsRequest {
            vector: embedding_response.data[0].embedding.clone(),
            limit: 10,
            with_payload: true,
        };
        let json_data: String = match serde_json::to_string(&search_points_request) {
            Ok(json_data) => json_data,
            Err(e) => {
                warn!("Error serializing search_points_request: {:?}", e);
                // NOTE(review): every other error path resumes the request;
                // this one resets it — confirm the asymmetry is intended.
                self.reset_http_request();
                return;
            }
        };
        let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME);
        let token_id = match self.dispatch_http_call(
            "qdrant",
            vec![
                (":method", "POST"),
                (":path", &path),
                (":authority", "qdrant"),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
            ],
            Some(json_data.as_bytes()),
            vec![],
            Duration::from_secs(5),
        ) {
            Ok(token_id) => token_id,
            Err(e) => {
                panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
            }
        };
        // Advance the pipeline stage and re-register the callout.
        callout_context.request_type = RequestType::SearchPoints;
        if self.callouts.insert(token_id, callout_context).is_some() {
            panic!("duplicate token_id")
        }
    }

    /// Stage 2 -> 3: pick the best-matching prompt target (if any clears the
    /// similarity threshold) and ask the NER service to extract its entities
    /// from the user message.
    fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
        let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) {
            Ok(search_points_response) => search_points_response,
            Err(e) => {
                warn!("Error deserializing search_points_response: {:?}", e);
                self.resume_http_request();
                return;
            }
        };
        let search_results = &search_points_response.result;
        if search_results.is_empty() {
            info!("No prompt target matched");
            self.resume_http_request();
            return;
        }
        // Only the first (assumed best) hit is considered.
        info!("similarity score: {}", search_results[0].score);
        if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD {
            info!(
                "prompt target below threshold: {}",
                DEFAULT_PROMPT_TARGET_THRESHOLD
            );
            self.resume_http_request();
            return;
        }
        // The serialized prompt target was stored in the point payload at index time.
        let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap();
        let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes())
        {
            Ok(prompt_target) => prompt_target,
            Err(e) => {
                warn!("Error deserializing prompt_target: {:?}", e);
                self.resume_http_request();
                return;
            }
        };
        info!("prompt_target name: {:?}", prompt_target.name);
        // only extract entity names
        let entity_names = get_entity_details(&prompt_target)
            .iter()
            .map(|entity| entity.name.clone())
            .collect();
        let user_message = callout_context.user_message.clone();
        let ner_request = common_types::NERRequest {
            input: user_message,
            labels: entity_names,
            model: DEFAULT_NER_MODEL.to_string(),
        };
        let json_data: String = match serde_json::to_string(&ner_request) {
            Ok(json_data) => json_data,
            Err(e) => {
                warn!("Error serializing ner_request: {:?}", e);
                self.resume_http_request();
                return;
            }
        };
        let token_id = match self.dispatch_http_call(
            "nerhost",
            vec![
                (":method", "POST"),
                (":path", "/ner"),
                (":authority", "nerhost"),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
            ],
            Some(json_data.as_bytes()),
            vec![],
            Duration::from_secs(5),
        ) {
            Ok(token_id) => token_id,
            Err(e) => {
                // NOTE(review): the message says "get-embeddings" but this is
                // the NER dispatch — fix the text when touching this code.
                panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
            }
        };
        callout_context.request_type = RequestType::Ner;
        callout_context.prompt_target = Some(prompt_target);
        if self.callouts.insert(token_id, callout_context).is_some() {
            panic!("duplicate token_id")
        }
    }

    /// Stage 3 -> 4: keep confidently-extracted entities, verify all required
    /// entities are present, then call the prompt target's endpoint with the
    /// extracted parameters as a JSON body.
    fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
        let ner_response: common_types::NERResponse = match serde_json::from_slice(&body) {
            Ok(ner_response) => ner_response,
            Err(e) => {
                warn!("Error deserializing ner_response: {:?}", e);
                self.resume_http_request();
                return;
            }
        };
        info!("ner_response: {:?}", ner_response);
        let mut request_params: HashMap<String, String> = HashMap::new();
        for entity in ner_response.data.iter() {
            // Drop low-confidence extractions.
            if entity.score < DEFAULT_NER_THRESHOLD {
                warn!(
                    "score of entity was too low entity name: {}, score: {}",
                    entity.label, entity.score
                );
                continue;
            }
            request_params.insert(entity.label.clone(), entity.text.clone());
        }
        let prompt_target = callout_context.prompt_target.as_ref().unwrap();
        let entity_details = get_entity_details(prompt_target);
        for entity in entity_details {
            // A required entity that is missing (or was filtered out above)
            // aborts context resolution; the request is forwarded unmodified.
            if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) {
                warn!(
                    "required entity missing or score of entity was too low: {}",
                    entity.name
                );
                self.resume_http_request();
                return;
            }
        }
        let req_param_str = match serde_json::to_string(&request_params) {
            Ok(req_param_str) => req_param_str,
            Err(e) => {
                warn!("Error serializing request_params: {:?}", e);
                self.resume_http_request();
                return;
            }
        };
        let endpoint = callout_context
            .prompt_target
            .as_ref()
            .unwrap()
            .endpoint
            .as_ref()
            .unwrap();
        // Path defaults to "/" and method to POST when not configured.
        let http_path = match &endpoint.path {
            Some(path) => path.clone(),
            None => "/".to_string(),
        };
        let http_method = match &endpoint.method {
            Some(method) => method.clone(),
            None => http::Method::POST.to_string(),
        };
        let token_id = match self.dispatch_http_call(
            // NOTE(review): the clone is unnecessary — &endpoint.cluster would do.
            &endpoint.cluster.clone(),
            vec![
                (":method", http_method.as_str()),
                (":path", http_path.as_str()),
                (":authority", endpoint.cluster.as_str()),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
            ],
            Some(req_param_str.as_bytes()),
            vec![],
            Duration::from_secs(5),
        ) {
            Ok(token_id) => token_id,
            Err(e) => {
                panic!("Error dispatching HTTP call for context-resolver: {:?}", e);
            }
        };
        callout_context.request_type = RequestType::ContextResolver;
        if self.callouts.insert(token_id, callout_context).is_some() {
            panic!("duplicate token_id")
        }
    }

    /// Stage 4: splice the prompt target's system prompt (if any) and the
    /// resolved context into the chat body, then release the request upstream.
    fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
        info!("response received for context-resolver");
        let body_string = String::from_utf8(body);
        let prompt_target = callout_context.prompt_target.unwrap();
        let mut request_body = callout_context.request_body;
        match prompt_target.system_prompt {
            None => {}
            Some(system_prompt) => {
                let system_prompt_message: Message = Message {
                    role: SYSTEM_ROLE.to_string(),
                    content: Some(system_prompt),
                };
                request_body.messages.push(system_prompt_message);
            }
        }
        match body_string {
            Ok(body_string) => {
                info!("context-resolver response: {}", body_string);
                // The resolved context is appended as a user-role message.
                let context_resolver_response = Message {
                    role: USER_ROLE.to_string(),
                    content: Some(body_string),
                };
                request_body.messages.push(context_resolver_response);
            }
            Err(e) => {
                warn!("Error converting response to string: {:?}", e);
                self.resume_http_request();
                return;
            }
        }
        let json_string = match serde_json::to_string(&request_body) {
            Ok(json_string) => json_string,
            Err(e) => {
                warn!("Error serializing request_body: {:?}", e);
                self.resume_http_request();
                return;
            }
        };
        info!("sending request to openai: msg len: {}", json_string.len());
        // Replace the buffered body with the augmented one and release the request.
        self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
        self.resume_http_request();
    }
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for StreamContext {
    // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
    // the lifecycle of the http request and response.
    fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
        self.save_host_header();
        self.delete_content_length_header();
        self.modify_path_header();
        Action::Continue
    }

    /// Buffers the full request body, extracts the last user message, and
    /// kicks off the embedding callout; the request stays paused until the
    /// callout pipeline resumes or rewrites it.
    fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
        // Let the client send the gateway all the data before sending to the LLM_provider.
        // TODO: consider a streaming API.
        if !end_of_stream {
            return Action::Pause;
        }
        if body_size == 0 {
            return Action::Continue;
        }
        // Deserialize body into spec.
        // Currently OpenAI API.
        let deserialized_body: common_types::open_ai::ChatCompletions =
            match self.get_http_request_body(0, body_size) {
                Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
                    Ok(deserialized) => deserialized,
                    Err(msg) => {
                        self.send_http_response(
                            StatusCode::BAD_REQUEST.as_u16().into(),
                            vec![],
                            Some(format!("Failed to deserialize: {}", msg).as_bytes()),
                        );
                        return Action::Pause;
                    }
                },
                None => {
                    self.send_http_response(
                        StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
                        vec![],
                        None,
                    );
                    error!(
                        "Failed to obtain body bytes even though body_size is {}",
                        body_size
                    );
                    return Action::Pause;
                }
            };
        // The last message in the conversation is taken as the user's prompt;
        // messages without content cause the request to pass through untouched.
        let user_message = match deserialized_body
            .messages
            .last()
            .and_then(|last_message| last_message.content.as_ref())
        {
            Some(content) => content,
            None => {
                info!("No messages in the request body");
                return Action::Continue;
            }
        };
        let get_embeddings_input = CreateEmbeddingRequest {
            input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
            model: String::from(DEFAULT_EMBEDDING_MODEL),
            encoding_format: None,
            dimensions: None,
            user: None,
        };
        let json_data: String = match serde_json::to_string(&get_embeddings_input) {
            Ok(json_data) => json_data,
            Err(error) => {
                panic!("Error serializing embeddings input: {}", error);
            }
        };
        let token_id = match self.dispatch_http_call(
            "embeddingserver",
            vec![
                (":method", "POST"),
                (":path", "/embeddings"),
                (":authority", "embeddingserver"),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
            ],
            Some(json_data.as_bytes()),
            vec![],
            Duration::from_secs(5),
        ) {
            Ok(token_id) => token_id,
            Err(e) => {
                panic!(
                    "Error dispatching embedding server HTTP call for get-embeddings: {:?}",
                    e
                );
            }
        };
        // Record the pipeline's starting state for this callout.
        let call_context = CallContext {
            request_type: RequestType::GetEmbedding,
            user_message: user_message.clone(),
            prompt_target: None,
            request_body: deserialized_body,
        };
        if self.callouts.insert(token_id, call_context).is_some() {
            panic!(
                "duplicate token_id={} in embedding server requests",
                token_id
            )
        }
        // Pause: the request is resumed (possibly rewritten) once the callout
        // pipeline completes.
        Action::Pause
    }
}
impl Context for StreamContext {
    /// Dispatches each callout response to the pipeline stage recorded in its
    /// `CallContext`.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        // A token we never recorded indicates a bug — fail fast.
        let callout_context = self.callouts.remove(&token_id).expect("invalid token_id");
        // If the callout produced no body there is nothing to process; release
        // the original request unmodified. (The previous version checked the
        // body for None twice — the second match arm was unreachable dead code.)
        let body = match self.get_http_call_response_body(0, body_size) {
            Some(body) => body,
            None => {
                warn!("No response body");
                self.resume_http_request();
                return;
            }
        };
        match callout_context.request_type {
            RequestType::GetEmbedding => {
                self.embeddings_handler(body, callout_context);
            }
            RequestType::SearchPoints => {
                self.search_points_handler(body, callout_context);
            }
            RequestType::Ner => {
                self.ner_handler(body, callout_context);
            }
            RequestType::ContextResolver => {
                self.context_resolver_handler(body, callout_context);
            }
        }
    }
}
/// Normalizes a prompt target's configured entities into a uniform list of
/// `EntityDetail`s. Bare entity names are promoted to required entities with
/// no description; a missing entities section yields an empty list.
fn get_entity_details(prompt_target: &PromptTarget) -> Vec<EntityDetail> {
    match prompt_target.entities.as_ref() {
        None => Vec::new(),
        Some(EntityType::Struct(details)) => details.clone(),
        Some(EntityType::Vec(names)) => names
            .iter()
            .map(|name| EntityDetail {
                name: name.clone(),
                // A bare name carries no metadata, so it is treated as required.
                required: Some(true),
                description: None,
            })
            .collect(),
    }
}