mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
Add ability to stream a response (#50)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
118bff7c7e
commit
9f3c845610
5 changed files with 251 additions and 104 deletions
|
|
@ -7,4 +7,5 @@ pub const USER_ROLE: &str = "user";
|
||||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||||
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
|
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
|
||||||
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||||
|
pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||||
|
|
|
||||||
|
|
@ -232,14 +232,12 @@ impl RootContext for FilterContext {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
|
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||||
Some(Box::new(StreamContext {
|
Some(Box::new(StreamContext::new(
|
||||||
host_header: None,
|
context_id,
|
||||||
ratelimit_selector: None,
|
Rc::clone(&self.metrics),
|
||||||
callouts: HashMap::new(),
|
Rc::clone(&self.prompt_targets),
|
||||||
metrics: Rc::clone(&self.metrics),
|
)))
|
||||||
prompt_targets: Rc::clone(&self.prompt_targets),
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_type(&self) -> Option<ContextType> {
|
fn get_type(&self) -> Option<ContextType> {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::consts::{
|
use crate::consts::{
|
||||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
|
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
|
||||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
|
||||||
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||||
};
|
};
|
||||||
use crate::filter_context::{embeddings_store, WasmMetrics};
|
use crate::filter_context::{embeddings_store, WasmMetrics};
|
||||||
|
|
@ -10,13 +10,16 @@ use crate::stats::IncrementingMetric;
|
||||||
use crate::tokenizer;
|
use crate::tokenizer;
|
||||||
use acap::cos;
|
use acap::cos;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use log::{debug, error, info, warn};
|
use log::{debug, info, warn};
|
||||||
use open_message_format_embeddings::models::{
|
use open_message_format_embeddings::models::{
|
||||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||||
};
|
};
|
||||||
use proxy_wasm::traits::*;
|
use proxy_wasm::traits::*;
|
||||||
use proxy_wasm::types::*;
|
use proxy_wasm::types::*;
|
||||||
use public_types::common_types::open_ai::{ChatCompletions, Message};
|
use public_types::common_types::open_ai::{
|
||||||
|
ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
|
||||||
|
StreamOptions,
|
||||||
|
};
|
||||||
use public_types::common_types::{
|
use public_types::common_types::{
|
||||||
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
|
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
|
||||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||||
|
|
@ -39,19 +42,40 @@ pub struct CallContext {
|
||||||
response_handler_type: ResponseHandlerType,
|
response_handler_type: ResponseHandlerType,
|
||||||
user_message: Option<String>,
|
user_message: Option<String>,
|
||||||
prompt_target: Option<PromptTarget>,
|
prompt_target: Option<PromptTarget>,
|
||||||
request_body: ChatCompletions,
|
request_body: ChatCompletionsRequest,
|
||||||
similarity_scores: Option<Vec<(String, f64)>>,
|
similarity_scores: Option<Vec<(String, f64)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct StreamContext {
|
pub struct StreamContext {
|
||||||
pub host_header: Option<String>,
|
pub context_id: u32,
|
||||||
pub ratelimit_selector: Option<Header>,
|
|
||||||
pub callouts: HashMap<u32, CallContext>,
|
|
||||||
pub metrics: Rc<WasmMetrics>,
|
pub metrics: Rc<WasmMetrics>,
|
||||||
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||||
|
callouts: HashMap<u32, CallContext>,
|
||||||
|
host_header: Option<String>,
|
||||||
|
ratelimit_selector: Option<Header>,
|
||||||
|
streaming_response: bool,
|
||||||
|
response_tokens: usize,
|
||||||
|
chat_completions_request: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamContext {
|
impl StreamContext {
|
||||||
|
pub fn new(
|
||||||
|
context_id: u32,
|
||||||
|
metrics: Rc<WasmMetrics>,
|
||||||
|
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||||
|
) -> Self {
|
||||||
|
StreamContext {
|
||||||
|
context_id,
|
||||||
|
metrics,
|
||||||
|
prompt_targets,
|
||||||
|
callouts: HashMap::new(),
|
||||||
|
host_header: None,
|
||||||
|
ratelimit_selector: None,
|
||||||
|
streaming_response: false,
|
||||||
|
response_tokens: 0,
|
||||||
|
chat_completions_request: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
fn save_host_header(&mut self) {
|
fn save_host_header(&mut self) {
|
||||||
// Save the host header to be used by filter logic later on.
|
// Save the host header to be used by filter logic later on.
|
||||||
self.host_header = self.get_http_request_header(":host");
|
self.host_header = self.get_http_request_header(":host");
|
||||||
|
|
@ -70,7 +94,8 @@ impl StreamContext {
|
||||||
// The gateway can start gathering information necessary for routing. For now change the path to an
|
// The gateway can start gathering information necessary for routing. For now change the path to an
|
||||||
// OpenAI API path.
|
// OpenAI API path.
|
||||||
Some(path) if path == "/llmrouting" => {
|
Some(path) if path == "/llmrouting" => {
|
||||||
self.set_http_request_header(":path", Some("/v1/chat/completions"));
|
self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
|
||||||
|
self.chat_completions_request = true;
|
||||||
}
|
}
|
||||||
// Otherwise let the filter continue.
|
// Otherwise let the filter continue.
|
||||||
_ => (),
|
_ => (),
|
||||||
|
|
@ -86,21 +111,26 @@ impl StreamContext {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_server_error(&self, error: String) {
|
fn send_server_error(&self, error: String, override_status_code: Option<StatusCode>) {
|
||||||
debug!("server error occurred: {}", error);
|
debug!("server error occurred: {}", error);
|
||||||
self.send_http_response(
|
self.send_http_response(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
override_status_code
|
||||||
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
|
.as_u16()
|
||||||
|
.into(),
|
||||||
vec![],
|
vec![],
|
||||||
Some(error.as_bytes()),
|
Some(error.as_bytes()),
|
||||||
)
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||||
Ok(embedding_response) => embedding_response,
|
Ok(embedding_response) => embedding_response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
self.send_server_error(format!("Error deserializing embedding response: {:?}", e));
|
return self.send_server_error(
|
||||||
return;
|
format!("Error deserializing embedding response: {:?}", e),
|
||||||
|
None,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -115,19 +145,15 @@ impl StreamContext {
|
||||||
let prompt_target_embeddings = match embeddings_store().read() {
|
let prompt_target_embeddings = match embeddings_store().read() {
|
||||||
Ok(embeddings) => embeddings,
|
Ok(embeddings) => embeddings,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let error_message = format!("Error reading embeddings store: {:?}", e);
|
return self
|
||||||
warn!("{}", error_message);
|
.send_server_error(format!("Error reading embeddings store: {:?}", e), None);
|
||||||
self.send_server_error(error_message);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let prompt_targets = match self.prompt_targets.read() {
|
let prompt_targets = match self.prompt_targets.read() {
|
||||||
Ok(prompt_targets) => prompt_targets,
|
Ok(prompt_targets) => prompt_targets,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let error_message = format!("Error reading prompt targets: {:?}", e);
|
self.send_server_error(format!("Error reading prompt targets: {:?}", e), None);
|
||||||
warn!("{}", error_message);
|
|
||||||
self.send_server_error(error_message);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -220,12 +246,13 @@ impl StreamContext {
|
||||||
match serde_json::from_slice(&body) {
|
match serde_json::from_slice(&body) {
|
||||||
Ok(zeroshot_response) => zeroshot_response,
|
Ok(zeroshot_response) => zeroshot_response,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!(
|
self.send_server_error(
|
||||||
"Error deserializing zeroshot intent detection response: {:?}",
|
format!(
|
||||||
e
|
"Error deserializing zeroshot intent detection response: {:?}",
|
||||||
|
e
|
||||||
|
),
|
||||||
|
None,
|
||||||
);
|
);
|
||||||
info!("body: {:?}", String::from_utf8(body).unwrap());
|
|
||||||
self.resume_http_request();
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -319,10 +346,12 @@ impl StreamContext {
|
||||||
parameters: tools_parameters,
|
parameters: tools_parameters,
|
||||||
};
|
};
|
||||||
|
|
||||||
let chat_completions = ChatCompletions {
|
let chat_completions = ChatCompletionsRequest {
|
||||||
model: GPT_35_TURBO.to_string(),
|
model: GPT_35_TURBO.to_string(),
|
||||||
messages: callout_context.request_body.messages.clone(),
|
messages: callout_context.request_body.messages.clone(),
|
||||||
tools: Some(vec![tools_defintion]),
|
tools: Some(vec![tools_defintion]),
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||||
|
|
@ -331,11 +360,10 @@ impl StreamContext {
|
||||||
msg_body
|
msg_body
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
self.send_server_error(format!(
|
return self.send_server_error(
|
||||||
"Error serializing request_params: {:?}",
|
format!("Error serializing request_params: {:?}", e),
|
||||||
e
|
None,
|
||||||
));
|
);
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -424,12 +452,10 @@ impl StreamContext {
|
||||||
.arguments
|
.arguments
|
||||||
.contains_key(&param.name)
|
.contains_key(&param.name)
|
||||||
{
|
{
|
||||||
warn!("boltfc did not extract required parameter: {}", param.name);
|
self.send_server_error(
|
||||||
return self.send_http_response(
|
format!("missing required parameter: {}", param.name),
|
||||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
Some(StatusCode::BAD_REQUEST),
|
||||||
vec![],
|
)
|
||||||
Some("missing required parameter".as_bytes()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
@ -510,17 +536,19 @@ impl StreamContext {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let request_message: ChatCompletions = ChatCompletions {
|
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||||
model: GPT_35_TURBO.to_string(),
|
model: GPT_35_TURBO.to_string(),
|
||||||
messages,
|
messages,
|
||||||
tools: None,
|
tools: None,
|
||||||
|
stream: callout_context.request_body.stream,
|
||||||
|
stream_options: callout_context.request_body.stream_options,
|
||||||
};
|
};
|
||||||
|
|
||||||
let json_string = match serde_json::to_string(&request_message) {
|
let json_string = match serde_json::to_string(&chat_completions_request) {
|
||||||
Ok(json_string) => json_string,
|
Ok(json_string) => json_string,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
self.send_server_error(format!("Error serializing request_body: {:?}", e));
|
return self
|
||||||
return;
|
.send_server_error(format!("Error serializing request_body: {:?}", e), None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
debug!(
|
debug!(
|
||||||
|
|
@ -528,22 +556,21 @@ impl StreamContext {
|
||||||
json_string
|
json_string
|
||||||
);
|
);
|
||||||
|
|
||||||
let request_body = callout_context.request_body;
|
|
||||||
|
|
||||||
// Tokenize and Ratelimit.
|
// Tokenize and Ratelimit.
|
||||||
if let Some(selector) = self.ratelimit_selector.take() {
|
if let Some(selector) = self.ratelimit_selector.take() {
|
||||||
if let Ok(token_count) = tokenizer::token_count(&request_body.model, &json_string) {
|
if let Ok(token_count) =
|
||||||
|
tokenizer::token_count(&chat_completions_request.model, &json_string)
|
||||||
|
{
|
||||||
match ratelimit::ratelimits(None).read().unwrap().check_limit(
|
match ratelimit::ratelimits(None).read().unwrap().check_limit(
|
||||||
request_body.model,
|
chat_completions_request.model,
|
||||||
selector,
|
selector,
|
||||||
NonZero::new(token_count as u32).unwrap(),
|
NonZero::new(token_count as u32).unwrap(),
|
||||||
) {
|
) {
|
||||||
Ok(_) => (),
|
Ok(_) => (),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
self.send_http_response(
|
self.send_server_error(
|
||||||
StatusCode::TOO_MANY_REQUESTS.as_u16().into(),
|
format!("Exceeded Ratelimit: {}", err),
|
||||||
vec![],
|
Some(StatusCode::TOO_MANY_REQUESTS),
|
||||||
Some(format!("Exceeded Ratelimit: {}", err).as_bytes()),
|
|
||||||
);
|
);
|
||||||
self.metrics.ratelimited_rq.increment(1);
|
self.metrics.ratelimited_rq.increment(1);
|
||||||
return;
|
return;
|
||||||
|
|
@ -583,31 +610,36 @@ impl HttpContext for StreamContext {
|
||||||
|
|
||||||
// Deserialize body into spec.
|
// Deserialize body into spec.
|
||||||
// Currently OpenAI API.
|
// Currently OpenAI API.
|
||||||
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
|
let mut deserialized_body: ChatCompletionsRequest =
|
||||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
match self.get_http_request_body(0, body_size) {
|
||||||
Ok(deserialized) => deserialized,
|
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||||
Err(msg) => {
|
Ok(deserialized) => deserialized,
|
||||||
self.send_http_response(
|
Err(msg) => {
|
||||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
self.send_server_error(
|
||||||
vec![],
|
format!("Failed to deserialize: {}", msg),
|
||||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
Some(StatusCode::BAD_REQUEST),
|
||||||
|
);
|
||||||
|
return Action::Pause;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
self.send_server_error(
|
||||||
|
format!(
|
||||||
|
"Failed to obtain body bytes even though body_size is {}",
|
||||||
|
body_size
|
||||||
|
),
|
||||||
|
None,
|
||||||
);
|
);
|
||||||
return Action::Pause;
|
return Action::Pause;
|
||||||
}
|
}
|
||||||
},
|
};
|
||||||
None => {
|
|
||||||
self.send_http_response(
|
self.streaming_response = deserialized_body.stream;
|
||||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
|
||||||
vec![],
|
deserialized_body.stream_options = Some(StreamOptions {
|
||||||
None,
|
include_usage: true,
|
||||||
);
|
});
|
||||||
error!(
|
}
|
||||||
"Failed to obtain body bytes even though body_size is {}",
|
|
||||||
body_size
|
|
||||||
);
|
|
||||||
return Action::Pause;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let user_message = match deserialized_body
|
let user_message = match deserialized_body
|
||||||
.messages
|
.messages
|
||||||
|
|
@ -682,6 +714,92 @@ impl HttpContext for StreamContext {
|
||||||
|
|
||||||
Action::Pause
|
Action::Pause
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||||
|
if !self.chat_completions_request {
|
||||||
|
return Action::Continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"recv [S={}] bytes={} end_stream={}",
|
||||||
|
self.context_id, body_size, end_of_stream
|
||||||
|
);
|
||||||
|
|
||||||
|
if !end_of_stream && !self.streaming_response {
|
||||||
|
return Action::Pause;
|
||||||
|
}
|
||||||
|
|
||||||
|
let body = self
|
||||||
|
.get_http_response_body(0, body_size)
|
||||||
|
.expect("cant get response body");
|
||||||
|
|
||||||
|
let body_str = String::from_utf8(body).expect("body is not utf-8");
|
||||||
|
|
||||||
|
if self.streaming_response {
|
||||||
|
debug!("streaming response");
|
||||||
|
let chat_completions_data = match body_str.split_once("data: ") {
|
||||||
|
Some((_, chat_completions_data)) => chat_completions_data,
|
||||||
|
None => {
|
||||||
|
self.send_server_error(String::from("parsing error in streaming data"), None);
|
||||||
|
return Action::Pause;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let chat_completions_chunk_response: ChatCompletionChunkResponse =
|
||||||
|
match serde_json::from_str(chat_completions_data) {
|
||||||
|
Ok(de) => de,
|
||||||
|
Err(_) => {
|
||||||
|
if chat_completions_data != "[NONE]" {
|
||||||
|
self.send_server_error(
|
||||||
|
String::from("error in streaming response"),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
return Action::Continue;
|
||||||
|
}
|
||||||
|
return Action::Continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(content) = chat_completions_chunk_response
|
||||||
|
.choices
|
||||||
|
.first()
|
||||||
|
.unwrap()
|
||||||
|
.delta
|
||||||
|
.content
|
||||||
|
.as_ref()
|
||||||
|
{
|
||||||
|
let model = &chat_completions_chunk_response.model;
|
||||||
|
let token_count = tokenizer::token_count(model, content).unwrap_or(0);
|
||||||
|
self.response_tokens += token_count;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
debug!("non streaming response");
|
||||||
|
let chat_completions_response: ChatCompletionsResponse =
|
||||||
|
match serde_json::from_str(&body_str) {
|
||||||
|
Ok(de) => de,
|
||||||
|
Err(e) => {
|
||||||
|
self.send_server_error(
|
||||||
|
format!(
|
||||||
|
"error in non-streaming response: {}\n response was={}",
|
||||||
|
e, body_str
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
return Action::Pause;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.response_tokens += chat_completions_response.usage.completions_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"recv [S={}] total_tokens={} end_stream={}",
|
||||||
|
self.context_id, self.response_tokens, end_of_stream
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO:: ratelimit based on response tokens.
|
||||||
|
Action::Continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Context for StreamContext {
|
impl Context for StreamContext {
|
||||||
|
|
@ -711,9 +829,10 @@ impl Context for StreamContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let error_message = "No response body in inline HTTP request";
|
self.send_server_error(
|
||||||
warn!("{}", error_message);
|
String::from("No response body in inline HTTP request"),
|
||||||
self.send_server_error(error_message.to_owned());
|
None,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,14 +36,10 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||||
.call_proxy_on_request_headers(http_context, 0, false)
|
.call_proxy_on_request_headers(http_context, 0, false)
|
||||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
||||||
.returning(Some("api.openai.com"))
|
.returning(Some("api.openai.com"))
|
||||||
.expect_add_header_map_value(
|
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||||
Some(MapType::HttpRequestHeaders),
|
|
||||||
Some("content-length"),
|
|
||||||
Some(""),
|
|
||||||
)
|
|
||||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||||
.returning(Some("/llmrouting"))
|
.returning(Some("/llmrouting"))
|
||||||
.expect_add_header_map_value(
|
.expect_replace_header_map_value(
|
||||||
Some(MapType::HttpRequestHeaders),
|
Some(MapType::HttpRequestHeaders),
|
||||||
Some(":path"),
|
Some(":path"),
|
||||||
Some("/v1/chat/completions"),
|
Some("/v1/chat/completions"),
|
||||||
|
|
@ -196,7 +192,7 @@ prompt_targets:
|
||||||
- name: city
|
- name: city
|
||||||
|
|
||||||
ratelimits:
|
ratelimits:
|
||||||
- provider: gpt-4
|
- provider: gpt-3.5-turbo
|
||||||
selector:
|
selector:
|
||||||
key: selector-key
|
key: selector-key
|
||||||
value: selector-value
|
value: selector-value
|
||||||
|
|
@ -245,14 +241,10 @@ fn successful_request_to_open_ai_chat_completions() {
|
||||||
.call_proxy_on_request_headers(http_context, 0, false)
|
.call_proxy_on_request_headers(http_context, 0, false)
|
||||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
||||||
.returning(Some("api.openai.com"))
|
.returning(Some("api.openai.com"))
|
||||||
.expect_add_header_map_value(
|
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||||
Some(MapType::HttpRequestHeaders),
|
|
||||||
Some("content-length"),
|
|
||||||
Some(""),
|
|
||||||
)
|
|
||||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||||
.returning(Some("/llmrouting"))
|
.returning(Some("/llmrouting"))
|
||||||
.expect_add_header_map_value(
|
.expect_replace_header_map_value(
|
||||||
Some(MapType::HttpRequestHeaders),
|
Some(MapType::HttpRequestHeaders),
|
||||||
Some(":path"),
|
Some(":path"),
|
||||||
Some("/v1/chat/completions"),
|
Some("/v1/chat/completions"),
|
||||||
|
|
@ -289,9 +281,9 @@ fn successful_request_to_open_ai_chat_completions() {
|
||||||
)
|
)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
.returning(Some(chat_completions_request_body))
|
.returning(Some(chat_completions_request_body))
|
||||||
// TODO: assert that the model field was added.
|
|
||||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
|
.expect_http_call(Some("model_server"), None, None, None, None)
|
||||||
|
.returning(Some(4))
|
||||||
.expect_metric_increment("active_http_calls", 1)
|
.expect_metric_increment("active_http_calls", 1)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -335,14 +327,10 @@ fn bad_request_to_open_ai_chat_completions() {
|
||||||
.call_proxy_on_request_headers(http_context, 0, false)
|
.call_proxy_on_request_headers(http_context, 0, false)
|
||||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
|
||||||
.returning(Some("api.openai.com"))
|
.returning(Some("api.openai.com"))
|
||||||
.expect_add_header_map_value(
|
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
||||||
Some(MapType::HttpRequestHeaders),
|
|
||||||
Some("content-length"),
|
|
||||||
Some(""),
|
|
||||||
)
|
|
||||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||||
.returning(Some("/llmrouting"))
|
.returning(Some("/llmrouting"))
|
||||||
.expect_add_header_map_value(
|
.expect_replace_header_map_value(
|
||||||
Some(MapType::HttpRequestHeaders),
|
Some(MapType::HttpRequestHeaders),
|
||||||
Some(":path"),
|
Some(":path"),
|
||||||
Some("/v1/chat/completions"),
|
Some("/v1/chat/completions"),
|
||||||
|
|
@ -377,6 +365,7 @@ fn bad_request_to_open_ai_chat_completions() {
|
||||||
)
|
)
|
||||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||||
.returning(Some(incomplete_chat_completions_request_body))
|
.returning(Some(incomplete_chat_completions_request_body))
|
||||||
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_send_local_response(
|
.expect_send_local_response(
|
||||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
||||||
None,
|
None,
|
||||||
|
|
@ -485,6 +474,10 @@ fn request_ratelimited() {
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
.expect_metric_increment("ratelimited_rq", 1)
|
.expect_metric_increment("ratelimited_rq", 1)
|
||||||
|
.expect_log(
|
||||||
|
Some(LogLevel::Debug),
|
||||||
|
Some("server error occurred: Exceeded Ratelimit: Not allowed"),
|
||||||
|
)
|
||||||
.execute_and_expect(ReturnType::None)
|
.execute_and_expect(ReturnType::None)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -94,12 +94,21 @@ pub mod open_ai {
|
||||||
use super::ToolsDefinition;
|
use super::ToolsDefinition;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ChatCompletions {
|
pub struct ChatCompletionsRequest {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub messages: Vec<Message>,
|
pub messages: Vec<Message>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tools: Option<Vec<ToolsDefinition>>,
|
pub tools: Option<Vec<ToolsDefinition>>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream_options: Option<StreamOptions>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct StreamOptions {
|
||||||
|
pub include_usage: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -109,6 +118,33 @@ pub mod open_ai {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub model: Option<String>,
|
pub model: Option<String>,
|
||||||
}
|
}
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionsResponse {
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub completions_tokens: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatCompletionChunkResponse {
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<Choice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Choice {
|
||||||
|
pub delta: Delta,
|
||||||
|
// TODO: could this be an enum?
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Delta {
|
||||||
|
pub content: Option<String>,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue