diff --git a/api_llm_gateway.rest b/api_llm_gateway.rest new file mode 100644 index 00000000..b40c229b --- /dev/null +++ b/api_llm_gateway.rest @@ -0,0 +1,76 @@ +@llm_endpoint = http://localhost:12000 +@openai_endpoint = https://api.openai.com +@access_key = {{$dotenv OPENAI_API_KEY}} + +### openai request +POST {{openai_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json +Authorization: Bearer {{access_key}} + +{ + "messages": [ + { + "role": "user", + "content": "hello" + } + ], + "model": "gpt-4o-mini" +} + +### openai request (streaming) +POST {{openai_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json +Authorization: Bearer {{access_key}} + +{ + "messages": [ + { + "role": "user", + "content": "hello" + } + ], + "model": "gpt-4o-mini", + "stream": true +} + + +### llm gateway request +POST {{llm_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "hello" + } + ] +} + +### llm gateway request (streaming) +POST {{llm_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "hello" + } + ], + "stream": true +} + +### llm gateway request (provider hint) +POST {{llm_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json +x-arch-llm-provider-hint: gpt-3.5-turbo-0125 + +{ + "messages": [ + { + "role": "user", + "content": "hello" + } + ] +} diff --git a/api_model_server.rest b/api_model_server.rest new file mode 100644 index 00000000..9102786a --- /dev/null +++ b/api_model_server.rest @@ -0,0 +1,44 @@ +@model_server_endpoint = http://localhost:51000 +@archfc_endpoint = https://api.fc.archgw.com + +### talk to model_server for completion +POST {{model_server_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle for next 10 days" + } + ], + "tools": [ + { + "id": "weather-112", + "tool_type": "function", + "function": { + "name": "weather_forecast", + "arguments": {"city": "str", "days": "int"} + } + } + ] +} + + +### talk to arch_fc directly for completion +POST {{archfc_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "model": "Arch-Function", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"id\": \"weather-112\", \"tool_type\": \"function\", \"function\": {\"name\": \"weather_forecast\", \"arguments\": {\"city\": \"str\", \"days\": \"int\"}}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n" + }, + { "role": "user", "content": "how is the weather in seattle?" }, + { "role": "assistant", "content": "Of course! " } + ], + "continue_final_message": true, + "add_generation_prompt": false +} diff --git a/api_prompt_gateway.rest b/api_prompt_gateway.rest new file mode 100644 index 00000000..b79b4230 --- /dev/null +++ b/api_prompt_gateway.rest @@ -0,0 +1,87 @@ +@prompt_endpoint = http://localhost:10000 + +### prompt gateway request +POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle for next 10 days" + } + ] +} + +### prompt gateway request (streaming) +POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle for next 10 days" + } + ], + "stream": true +} + + +### prompt gateway request param gathering +POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle" + } + ] +} + +### prompt gateway request param gathering and function calling +POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle" + }, + { + "role": "assistant", + "content": "It seems I'm missing some information. Could you provide the following details days ?", + "model": "Arch-Function-1.5b" + }, + { + "role": "user", + "content": "for next 10 days" + } + ] +} + +### prompt gateway request param gathering and function calling (streaming) +POST {{prompt_endpoint}}/v1/chat/completions HTTP/1.1 +Content-Type: application/json + +{ + "messages": [ + { + "role": "user", + "content": "how is the weather in seattle" + }, + { + "role": "assistant", + "content": "It seems I'm missing some information. Could you provide the following details days ?", + "model": "Arch-Function-1.5b" + }, + { + "role": "user", + "content": "for next 10 days" + } + ], + "stream": true +} diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs index de120369..0b44fac9 100644 --- a/crates/prompt_gateway/src/filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -11,7 +11,8 @@ use common::http::CallArgs; use common::http::Client; use common::stats::Gauge; use common::stats::IncrementingMetric; -use log::debug; +use http::StatusCode; +use log::{debug, info, trace, warn}; use proxy_wasm::traits::*; use proxy_wasm::types::*; use std::cell::RefCell; @@ -53,6 +54,7 @@ pub struct FilterContext { prompt_guards: Rc, embeddings_store: Option>, temp_embeddings_store: EmbeddingsStore, + active_embedding_calls_count: u32, } impl FilterContext { @@ -66,22 +68,26 @@ impl FilterContext { prompt_guards: Rc::new(PromptGuards::default()), embeddings_store: Some(Rc::new(HashMap::new())), temp_embeddings_store: HashMap::new(), + active_embedding_calls_count: 0, } } - fn process_prompt_targets(&self) { - for values in self.prompt_targets.iter() { - let prompt_target = values.1; - self.schedule_embeddings_call( - &prompt_target.name, - &prompt_target.description, - EmbeddingType::Description, - ); - } + fn process_prompt_targets(&mut self) { + let prompt_target_description: Vec<(String, String)> = self + .prompt_targets + .iter() + .map(|(k, v)| (k.clone(), v.description.clone())) + .collect(); + + prompt_target_description + .iter() + .for_each(|(name, description)| { + self.schedule_embeddings_call(name, description, EmbeddingType::Description); + }); } fn schedule_embeddings_call( - &self, + &mut self, prompt_target_name: &str, input: &str, embedding_type: EmbeddingType, @@ -116,6 +122,7 @@ impl FilterContext { embedding_type, }; + self.active_embedding_calls_count += 1; if let Err(error) = self.http_call(call_args, call_context) { panic!("{error}") } @@ -123,9 +130,9 @@ impl FilterContext { fn embedding_response_handler( &mut self, - body_size: usize, embedding_type: EmbeddingType, prompt_target_name: String, + body: Vec, ) { let prompt_target = self .prompt_targets @@ -137,9 +144,6 @@ impl FilterContext { ) }); - let body = self - .get_http_call_response_body(0, body_size) - .expect("No body in response"); if !body.is_empty() { let mut embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { @@ -208,7 +212,7 @@ impl Context for FilterContext { body_size: usize, _num_trailers: usize, ) { - debug!( + trace!( "filter_context: on_http_call_response called with token_id: {:?}", token_id ); @@ -218,13 +222,26 @@ impl Context for FilterContext { .remove(&token_id) .expect("invalid token_id"); + self.active_embedding_calls_count -= 1; self.metrics.active_http_calls.increment(-1); + let body_bytes = self.get_http_call_response_body(0, body_size).unwrap(); - self.embedding_response_handler( - body_size, - callout_data.embedding_type, - callout_data.prompt_target_name, - ) + if let Some(status_code) = self.get_http_call_response_header(":status") { + if status_code == StatusCode::OK.as_str() { + self.embedding_response_handler( + callout_data.embedding_type, + callout_data.prompt_target_name, + body_bytes, + ); + } else { + warn!( + "Received non-200 status code: {} for callout with token_id: {}: body_str: {}", + status_code, + token_id, + String::from_utf8(body_bytes).unwrap() + ); + } + } } } @@ -262,10 +279,7 @@ impl RootContext for FilterContext { context_id ); - let embedding_store = match self.embeddings_store.as_ref() { - None => return None, - Some(store) => Some(Rc::clone(store)), - }; + let embedding_store = self.embeddings_store.as_ref().map(Rc::clone); Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), @@ -287,8 +301,20 @@ impl RootContext for FilterContext { } fn on_tick(&mut self) { - debug!("starting up arch filter in mode: prompt gateway mode"); - self.process_prompt_targets(); - self.set_tick_period(Duration::from_secs(0)); + if self.embeddings_store.is_some() + && self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() + { + info!("embeddings store initialized"); + self.set_tick_period(Duration::from_secs(0)); + } else { + if self.active_embedding_calls_count == 0 { + info!("retrieving embeddings from embedding server"); + self.process_prompt_targets(); + } else { + info!("waiting for embeddings store to be initialized"); + } + + self.set_tick_period(Duration::from_secs(5)); + } } } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 0ea27cfb..ed5a23c8 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -37,10 +37,10 @@ impl HttpContext for StreamContext { let request_path = self.get_http_request_header(":path").unwrap_or_default(); if request_path == HEALTHZ_PATH { - if self.embeddings_store.is_none() { - self.send_http_response(503, vec![], None); - } else { + if self.is_embedding_store_initialized() { self.send_http_response(200, vec![], None); + } else { + self.send_http_response(503, vec![], None); } return Action::Continue; } diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index c068072b..c7865a75 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -61,7 +61,7 @@ pub struct StreamCallContext { pub struct StreamContext { system_prompt: Rc>, - prompt_targets: Rc>, + pub prompt_targets: Rc>, pub embeddings_store: Option>, overrides: Rc>, pub metrics: Rc, @@ -111,10 +111,21 @@ impl StreamContext { traceparent: None, } } + fn embeddings_store(&self) -> &EmbeddingsStore { - self.embeddings_store - .as_ref() - .expect("embeddings store is not set") + self.embeddings_store.as_ref().unwrap() + } + + pub fn is_embedding_store_initialized(&self) -> bool { + if self.embeddings_store.as_ref().is_none() { + return false; + } + + if self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() { + return true; + } + + false } pub fn send_server_error(&self, error: ServerError, override_status_code: Option) { @@ -232,7 +243,7 @@ impl StreamContext { "embeddings not found for prompt target name: {}", prompt_name ); - return (prompt_name.clone(), f64::NAN); + return (prompt_name.clone(), 0.0); } }; @@ -243,7 +254,7 @@ impl StreamContext { "description embeddings not found for prompt target name: {}", prompt_name ); - return (prompt_name.clone(), f64::NAN); + return (prompt_name.clone(), 0.0); } }; let similarity_score_description = @@ -698,7 +709,7 @@ impl StreamContext { if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() { // This means that Arch FC did not have enough information to resolve the function call // Arch FC probably responded with a message asking for more information. - // Let's send the response back to the user to initalize lightweight dialog for parameter collection + // Let's send the response back to the user to initialize lightweight dialog for parameter collection //TODO: add resolver name to the response so the client can send the response back to the correct resolver diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 1bf581c5..46f2dfd8 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -161,6 +161,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&embeddings_response_buffer)) .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call( @@ -244,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { module .call_proxy_on_tick(filter_context) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), @@ -262,7 +263,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { ) .returning(Some(101)) .expect_metric_increment("active_http_calls", 1) - .expect_set_tick_period_millis(Some(0)) + .expect_set_tick_period_millis(Some(5000)) .execute_and_expect(ReturnType::None) .unwrap(); @@ -289,7 +290,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { 0, ) .expect_log( - Some(LogLevel::Debug), + Some(LogLevel::Trace), Some( format!( "filter_context: on_http_call_response called with token_id: {:?}", @@ -332,7 +333,7 @@ llm_providers: overrides: # confidence threshold for prompt target intent matching - prompt_target_intent_matching_threshold: 0.6 + prompt_target_intent_matching_threshold: 0.0 system_prompt: | You are a helpful assistant.