mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 12:52:56 +02:00
Add workflow logic for weather forecast demo (#24)
This commit is contained in:
parent
7ef68eccfb
commit
33f9dd22e6
32 changed files with 1902 additions and 459 deletions
|
|
@ -23,9 +23,53 @@ pub struct StoreVectorEmbeddingsRequest {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum CallContext {
|
||||
EmbeddingRequest(EmbeddingRequest),
|
||||
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
|
||||
CreateVectorCollection(String),
|
||||
}
|
||||
|
||||
// https://api.qdrant.tech/master/api-reference/search/points
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointsRequest {
|
||||
pub vector: Vec<f64>,
|
||||
pub limit: i32,
|
||||
pub with_payload: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointResult {
|
||||
pub id: String,
|
||||
pub version: i32,
|
||||
pub score: f64,
|
||||
pub payload: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointsResponse {
|
||||
pub result: Vec<SearchPointResult>,
|
||||
pub status: String,
|
||||
pub time: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NERRequest {
|
||||
pub input: String,
|
||||
pub labels: Vec<String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Entity {
|
||||
pub text: String,
|
||||
pub label: String,
|
||||
pub score: f64,
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NERResponse {
|
||||
pub data: Vec<Entity>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
pub mod open_ai {
|
||||
|
|
@ -41,6 +85,6 @@ pub mod open_ai {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ pub struct PromptConfig {
|
|||
pub timeout_ms: u64,
|
||||
pub embedding_provider: EmbeddingProviver,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub system_prompt: String,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_targets: Vec<PromptTarget>,
|
||||
}
|
||||
|
||||
|
|
@ -49,6 +49,29 @@ pub struct LlmProvider {
|
|||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct Endpoint {
|
||||
pub cluster: String,
|
||||
pub path: Option<String>,
|
||||
pub method: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct EntityDetail {
|
||||
pub name: String,
|
||||
pub required: Option<bool>,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum EntityType {
|
||||
Vec(Vec<String>),
|
||||
Struct(Vec<EntityDetail>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct PromptTarget {
|
||||
|
|
@ -56,7 +79,9 @@ pub struct PromptTarget {
|
|||
pub prompt_type: String,
|
||||
pub name: String,
|
||||
pub few_shot_examples: Vec<String>,
|
||||
pub endpoint: String,
|
||||
pub entities: Option<EntityType>,
|
||||
pub endpoint: Option<Endpoint>,
|
||||
pub system_prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -88,13 +113,23 @@ katanemo-prompt-config:
|
|||
name: weather-forecast
|
||||
few-shot-examples:
|
||||
- what is the weather in New York?
|
||||
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
|
||||
cache-response: true
|
||||
cache-response-settings:
|
||||
- cache-ttl-secs: 3600 # cache expiry in seconds
|
||||
- cache-max-size: 1000 # in number of items
|
||||
- cache-eviction-strategy: LRU
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- name: location
|
||||
required: true
|
||||
description: "The location for which the weather is requested"
|
||||
|
||||
- type: context-resolver
|
||||
name: weather-forecast-2
|
||||
few-shot-examples:
|
||||
- what is the weather in New York?
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- city
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -1 +1,7 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
||||
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
|
||||
pub const DEFAULT_NER_MODEL: &str = "urchade/gliner_large-v2.1";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
|
||||
pub const DEFAULT_NER_THRESHOLD: f64 = 0.6;
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
|
|
|
|||
289
envoyfilter/src/filter_context.rs
Normal file
289
envoyfilter/src/filter_context.rs
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
use common_types::{CallContext, EmbeddingRequest};
|
||||
use configuration::PromptTarget;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use stats::RecordingMetric;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
use crate::common_types;
|
||||
use crate::configuration;
|
||||
use crate::consts;
|
||||
use crate::stats;
|
||||
use crate::stream_context::StreamContext;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
active_http_calls: stats::Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FilterContext {
|
||||
metrics: WasmMetrics,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, common_types::CallContext>,
|
||||
config: Option<configuration::Configuration>,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: WasmMetrics::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(
|
||||
few_shot_example.to_string(),
|
||||
)),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
let json_data: String = to_string(&embeddings_input).unwrap();
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
let embedding_request = EmbeddingRequest {
|
||||
create_embedding_request: embeddings_input,
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext::EmbeddingRequest(embedding_request)
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_request_handler(
|
||||
&mut self,
|
||||
body_size: usize,
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: PromptTarget,
|
||||
) {
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let embedding_response: CreateEmbeddingResponse =
|
||||
serde_json::from_slice(&body).unwrap();
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
to_string(&prompt_target).unwrap(),
|
||||
);
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(input) => {
|
||||
id = Some(md5::compute(&input));
|
||||
payload.insert("input".to_string(), input);
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
};
|
||||
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_vector_store_points_handler(&self, body_size: usize) {
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
info!(
|
||||
"response body: len {:?}",
|
||||
String::from_utf8(body).unwrap().len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: run once per envoy instance, right now it runs once per worker
|
||||
fn init_vector_store(&mut self) {
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for init-vector-store: {:?}", e);
|
||||
}
|
||||
};
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::CreateVectorCollection("prompt_vector_store".to_string()),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
// self.metrics
|
||||
// .active_http_calls
|
||||
// .record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
|
||||
match callout_data {
|
||||
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
|
||||
}
|
||||
common_types::CallContext::StoreVectorEmbeddings(_) => {
|
||||
self.create_vector_store_points_handler(body_size)
|
||||
}
|
||||
common_types::CallContext::CreateVectorCollection(_) => {
|
||||
let mut http_status_code = "Nil".to_string();
|
||||
self.get_http_call_response_headers()
|
||||
.iter()
|
||||
.for_each(|(k, v)| {
|
||||
if k == ":status" {
|
||||
http_status_code.clone_from(v);
|
||||
}
|
||||
});
|
||||
info!("CreateVectorCollection response: {}", http_status_code);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
Some(Box::new(StreamContext {
|
||||
host_header: None,
|
||||
callouts: HashMap::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
// initialize vector store
|
||||
self.init_vector_store();
|
||||
self.process_prompt_targets();
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,22 +1,13 @@
|
|||
use common_types::{CallContext, EmbeddingRequest};
|
||||
use configuration::PromptTarget;
|
||||
use http::StatusCode;
|
||||
use log::error;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use filter_context::FilterContext;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use stats::{Gauge, RecordingMetric};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
mod common_types;
|
||||
mod configuration;
|
||||
mod consts;
|
||||
mod filter_context;
|
||||
mod stats;
|
||||
mod stream_context;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
|
|
@ -24,397 +15,3 @@ proxy_wasm::main! {{
|
|||
Box::new(FilterContext::new())
|
||||
});
|
||||
}}
|
||||
|
||||
struct StreamContext {
|
||||
host_header: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
fn save_host_header(&mut self) {
|
||||
// Save the host header to be used by filter logic later on.
|
||||
self.host_header = self.get_http_request_header(":host");
|
||||
}
|
||||
|
||||
fn delete_content_length_header(&mut self) {
|
||||
// Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
|
||||
// Server's generally throw away requests whose body length do not match the Content-Length header.
|
||||
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
|
||||
// manipulate the body in benign ways e.g., compression.
|
||||
self.set_http_request_header("content-length", None);
|
||||
}
|
||||
|
||||
fn modify_path_header(&mut self) {
|
||||
match self.get_http_request_header(":path") {
|
||||
// The gateway can start gathering information necessary for routing. For now change the path to an
|
||||
// OpenAI API path.
|
||||
Some(path) if path == "/llmrouting" => {
|
||||
self.set_http_request_header(":path", Some("/v1/chat/completions"));
|
||||
}
|
||||
// Otherwise let the filter continue.
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
impl HttpContext for StreamContext {
|
||||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
self.save_host_header();
|
||||
self.delete_content_length_header();
|
||||
self.modify_path_header();
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
// Let the client send the gateway all the data before sending to the LLM_provider.
|
||||
// TODO: consider a streaming API.
|
||||
if !end_of_stream {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
if body_size == 0 {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: common_types::open_ai::ChatCompletions =
|
||||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
// Modify JSON payload
|
||||
deserialized_body.model = String::from("gpt-3.5-turbo");
|
||||
|
||||
match serde_json::to_string(&deserialized_body) {
|
||||
Ok(json_string) => {
|
||||
self.set_http_request_body(0, body_size, &json_string.into_bytes());
|
||||
Action::Continue
|
||||
}
|
||||
Err(error) => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!("Failed to serialize body: {}", error);
|
||||
Action::Pause
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for StreamContext {}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
active_http_calls: Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct FilterContext {
|
||||
metrics: WasmMetrics,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, common_types::CallContext>,
|
||||
config: Option<configuration::Configuration>,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: WasmMetrics::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
for prompt_target in &self
|
||||
.config
|
||||
.as_ref()
|
||||
.expect("Gateway configuration cannot be non-existent")
|
||||
.prompt_config
|
||||
.prompt_targets
|
||||
{
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
info!("few_shot_example: {:?}", few_shot_example);
|
||||
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(
|
||||
few_shot_example.to_string(),
|
||||
)),
|
||||
model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching embedding server HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
let embedding_request = EmbeddingRequest {
|
||||
create_embedding_request: embeddings_input,
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext::EmbeddingRequest(embedding_request)
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_request_handler(
|
||||
&mut self,
|
||||
body_size: usize,
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: PromptTarget,
|
||||
) {
|
||||
info!("response received for CreateEmbeddingRequest");
|
||||
let body = match self.get_http_call_response_body(0, body_size) {
|
||||
Some(body) => body,
|
||||
None => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if body.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let (json_data, create_vector_store_points) =
|
||||
match build_qdrant_data(&body, create_embedding_request, &prompt_target) {
|
||||
Ok(tuple) => tuple,
|
||||
Err(error) => {
|
||||
panic!(
|
||||
"Error building qdrant data from embedding response {}",
|
||||
error
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching qdrant HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id={} in qdrant requests", token_id)
|
||||
}
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
|
||||
// TODO: @adilhafeez implement.
|
||||
fn create_vector_store_points_handler(&self, body_size: usize) {
|
||||
info!("response received for CreateVectorStorePoints");
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
info!("response body: {:?}", String::from_utf8(body).unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
info!("on_http_call_response: token_id = {}", token_id);
|
||||
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
match callout_data {
|
||||
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
|
||||
}
|
||||
common_types::CallContext::StoreVectorEmbeddings(_) => {
|
||||
self.create_vector_store_points_handler(body_size)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = match serde_yaml::from_slice(&config_bytes) {
|
||||
Ok(config) => config,
|
||||
Err(error) => {
|
||||
panic!("Failed to deserialize plugin configuration: {}", error);
|
||||
}
|
||||
};
|
||||
info!("on_configure: plugin configuration loaded");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
Some(Box::new(StreamContext { host_header: None }))
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
info!("on_vm_start: setting up tick timeout");
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
info!("on_tick: starting to process prompt targets");
|
||||
self.process_prompt_targets();
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
|
||||
fn build_qdrant_data(
|
||||
embedding_data: &[u8],
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: &PromptTarget,
|
||||
) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(embedding_data) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(error) => {
|
||||
panic!("Failed to deserialize embedding response: {}", error);
|
||||
}
|
||||
};
|
||||
info!(
|
||||
"embedding_response model: {}, vector len: {}",
|
||||
embedding_response.model,
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
serde_json::to_string(&prompt_target)?,
|
||||
);
|
||||
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(ref input) => {
|
||||
id = Some(md5::compute(input));
|
||||
payload.insert("input".to_string(), input.clone());
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
};
|
||||
let json_data = serde_json::to_string(&create_vector_store_points)?;
|
||||
info!(
|
||||
"create_vector_store_points: points length: {}",
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
Ok((json_data, create_vector_store_points))
|
||||
}
|
||||
|
|
|
|||
520
envoyfilter/src/stream_context.rs
Normal file
520
envoyfilter/src/stream_context.rs
Normal file
|
|
@ -0,0 +1,520 @@
|
|||
use http::StatusCode;
|
||||
use log::error;
|
||||
use log::info;
|
||||
use log::warn;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
use consts::{
|
||||
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
|
||||
use crate::common_types;
|
||||
use crate::common_types::open_ai::Message;
|
||||
use crate::common_types::SearchPointsResponse;
|
||||
use crate::configuration::EntityDetail;
|
||||
use crate::configuration::EntityType;
|
||||
use crate::configuration::PromptTarget;
|
||||
use crate::consts;
|
||||
|
||||
enum RequestType {
|
||||
GetEmbedding,
|
||||
SearchPoints,
|
||||
Ner,
|
||||
ContextResolver,
|
||||
}
|
||||
|
||||
pub struct CallContext {
|
||||
request_type: RequestType,
|
||||
user_message: String,
|
||||
prompt_target: Option<PromptTarget>,
|
||||
request_body: common_types::open_ai::ChatCompletions,
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
pub host_header: Option<String>,
|
||||
pub callouts: HashMap<u32, CallContext>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
fn save_host_header(&mut self) {
|
||||
// Save the host header to be used by filter logic later on.
|
||||
self.host_header = self.get_http_request_header(":host");
|
||||
}
|
||||
|
||||
fn delete_content_length_header(&mut self) {
|
||||
// Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
|
||||
// Server's generally throw away requests whose body length do not match the Content-Length header.
|
||||
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
|
||||
// manipulate the body in benign ways e.g., compression.
|
||||
self.set_http_request_header("content-length", None);
|
||||
}
|
||||
|
||||
fn modify_path_header(&mut self) {
|
||||
match self.get_http_request_header(":path") {
|
||||
// The gateway can start gathering information necessary for routing. For now change the path to an
|
||||
// OpenAI API path.
|
||||
Some(path) if path == "/llmrouting" => {
|
||||
self.set_http_request_header(":path", Some("/v1/chat/completions"));
|
||||
}
|
||||
// Otherwise let the filter continue.
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing embedding response: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let search_points_request = common_types::SearchPointsRequest {
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
limit: 10,
|
||||
with_payload: true,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&search_points_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(e) => {
|
||||
warn!("Error serializing search_points_request: {:?}", e);
|
||||
self.reset_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME);
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", &path),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
|
||||
}
|
||||
};
|
||||
|
||||
callout_context.request_type = RequestType::SearchPoints;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
|
||||
fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(search_points_response) => search_points_response,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing search_points_response: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let search_results = &search_points_response.result;
|
||||
|
||||
if search_results.is_empty() {
|
||||
info!("No prompt target matched");
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
|
||||
info!("similarity score: {}", search_results[0].score);
|
||||
|
||||
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD {
|
||||
info!(
|
||||
"prompt target below threshold: {}",
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap();
|
||||
let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes())
|
||||
{
|
||||
Ok(prompt_target) => prompt_target,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing prompt_target: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
info!("prompt_target name: {:?}", prompt_target.name);
|
||||
|
||||
// only extract entity names
|
||||
let entity_names = get_entity_details(&prompt_target)
|
||||
.iter()
|
||||
.map(|entity| entity.name.clone())
|
||||
.collect();
|
||||
let user_message = callout_context.user_message.clone();
|
||||
let ner_request = common_types::NERRequest {
|
||||
input: user_message,
|
||||
labels: entity_names,
|
||||
model: DEFAULT_NER_MODEL.to_string(),
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&ner_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(e) => {
|
||||
warn!("Error serializing ner_request: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"nerhost",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/ner"),
|
||||
(":authority", "nerhost"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
|
||||
}
|
||||
};
|
||||
callout_context.request_type = RequestType::Ner;
|
||||
callout_context.prompt_target = Some(prompt_target);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
|
||||
fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let ner_response: common_types::NERResponse = match serde_json::from_slice(&body) {
|
||||
Ok(ner_response) => ner_response,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing ner_response: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
info!("ner_response: {:?}", ner_response);
|
||||
|
||||
let mut request_params: HashMap<String, String> = HashMap::new();
|
||||
for entity in ner_response.data.iter() {
|
||||
if entity.score < DEFAULT_NER_THRESHOLD {
|
||||
warn!(
|
||||
"score of entity was too low entity name: {}, score: {}",
|
||||
entity.label, entity.score
|
||||
);
|
||||
continue;
|
||||
}
|
||||
request_params.insert(entity.label.clone(), entity.text.clone());
|
||||
}
|
||||
|
||||
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
|
||||
let entity_details = get_entity_details(prompt_target);
|
||||
for entity in entity_details {
|
||||
if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) {
|
||||
warn!(
|
||||
"required entity missing or score of entity was too low: {}",
|
||||
entity.name
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let req_param_str = match serde_json::to_string(&request_params) {
|
||||
Ok(req_param_str) => req_param_str,
|
||||
Err(e) => {
|
||||
warn!("Error serializing request_params: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let endpoint = callout_context
|
||||
.prompt_target
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.endpoint
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
|
||||
let http_path = match &endpoint.path {
|
||||
Some(path) => path.clone(),
|
||||
None => "/".to_string(),
|
||||
};
|
||||
|
||||
let http_method = match &endpoint.method {
|
||||
Some(method) => method.clone(),
|
||||
None => http::Method::POST.to_string(),
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
&endpoint.cluster.clone(),
|
||||
vec![
|
||||
(":method", http_method.as_str()),
|
||||
(":path", http_path.as_str()),
|
||||
(":authority", endpoint.cluster.as_str()),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(req_param_str.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for context-resolver: {:?}", e);
|
||||
}
|
||||
};
|
||||
callout_context.request_type = RequestType::ContextResolver;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
|
||||
fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
info!("response received for context-resolver");
|
||||
let body_string = String::from_utf8(body);
|
||||
let prompt_target = callout_context.prompt_target.unwrap();
|
||||
let mut request_body = callout_context.request_body;
|
||||
match prompt_target.system_prompt {
|
||||
None => {}
|
||||
Some(system_prompt) => {
|
||||
let system_prompt_message: Message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt),
|
||||
};
|
||||
request_body.messages.push(system_prompt_message);
|
||||
}
|
||||
}
|
||||
match body_string {
|
||||
Ok(body_string) => {
|
||||
info!("context-resolver response: {}", body_string);
|
||||
let context_resolver_response = Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(body_string),
|
||||
};
|
||||
request_body.messages.push(context_resolver_response);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error converting response to string: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let json_string = match serde_json::to_string(&request_body) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
warn!("Error serializing request_body: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
info!("sending request to openai: msg len: {}", json_string.len());
|
||||
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
impl HttpContext for StreamContext {
|
||||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
self.save_host_header();
|
||||
self.delete_content_length_header();
|
||||
self.modify_path_header();
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
// Let the client send the gateway all the data before sending to the LLM_provider.
|
||||
// TODO: consider a streaming API.
|
||||
if !end_of_stream {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
if body_size == 0 {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let deserialized_body: common_types::open_ai::ChatCompletions =
|
||||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let user_message = match deserialized_body
|
||||
.messages
|
||||
.last()
|
||||
.and_then(|last_message| last_message.content.as_ref())
|
||||
{
|
||||
Some(content) => content,
|
||||
None => {
|
||||
info!("No messages in the request body");
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
let get_embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
let call_context = CallContext {
|
||||
request_type: RequestType::GetEmbedding,
|
||||
user_message: user_message.clone(),
|
||||
prompt_target: None,
|
||||
request_body: deserialized_body,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
|
||||
Action::Pause
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for StreamContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_context = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
let resp = self.get_http_call_response_body(0, body_size);
|
||||
|
||||
if resp.is_none() {
|
||||
warn!("No response body");
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
|
||||
let body = match resp {
|
||||
Some(body) => body,
|
||||
None => {
|
||||
warn!("Empty response body");
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match callout_context.request_type {
|
||||
RequestType::GetEmbedding => {
|
||||
self.embeddings_handler(body, callout_context);
|
||||
}
|
||||
|
||||
RequestType::SearchPoints => {
|
||||
self.search_points_handler(body, callout_context);
|
||||
}
|
||||
RequestType::Ner => {
|
||||
self.ner_handler(body, callout_context);
|
||||
}
|
||||
RequestType::ContextResolver => {
|
||||
self.context_resolver_handler(body, callout_context);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_entity_details(prompt_target: &PromptTarget) -> Vec<EntityDetail> {
|
||||
match prompt_target.entities.as_ref() {
|
||||
Some(EntityType::Vec(entity_names)) => {
|
||||
let mut entity_details: Vec<EntityDetail> = Vec::new();
|
||||
for entity_name in entity_names {
|
||||
entity_details.push(EntityDetail {
|
||||
name: entity_name.clone(),
|
||||
required: Some(true),
|
||||
description: None,
|
||||
});
|
||||
}
|
||||
entity_details
|
||||
}
|
||||
Some(EntityType::Struct(entity_details)) => entity_details.clone(),
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue