Add ability to stream a response (#50)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-09-17 16:12:41 -07:00 committed by GitHub
parent 118bff7c7e
commit 9f3c845610
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 251 additions and 104 deletions

View file

@ -7,4 +7,5 @@ pub const USER_ROLE: &str = "user";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo"; pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b"; pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const MODEL_SERVER_NAME: &str = "model_server"; pub const MODEL_SERVER_NAME: &str = "model_server";

View file

@ -232,14 +232,12 @@ impl RootContext for FilterContext {
true true
} }
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> { fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
Some(Box::new(StreamContext { Some(Box::new(StreamContext::new(
host_header: None, context_id,
ratelimit_selector: None, Rc::clone(&self.metrics),
callouts: HashMap::new(), Rc::clone(&self.prompt_targets),
metrics: Rc::clone(&self.metrics), )))
prompt_targets: Rc::clone(&self.prompt_targets),
}))
} }
fn get_type(&self) -> Option<ContextType> { fn get_type(&self) -> Option<ContextType> {

View file

@ -1,6 +1,6 @@
use crate::consts::{ use crate::consts::{
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL, BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
}; };
use crate::filter_context::{embeddings_store, WasmMetrics}; use crate::filter_context::{embeddings_store, WasmMetrics};
@ -10,13 +10,16 @@ use crate::stats::IncrementingMetric;
use crate::tokenizer; use crate::tokenizer;
use acap::cos; use acap::cos;
use http::StatusCode; use http::StatusCode;
use log::{debug, error, info, warn}; use log::{debug, info, warn};
use open_message_format_embeddings::models::{ use open_message_format_embeddings::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
}; };
use proxy_wasm::traits::*; use proxy_wasm::traits::*;
use proxy_wasm::types::*; use proxy_wasm::types::*;
use public_types::common_types::open_ai::{ChatCompletions, Message}; use public_types::common_types::open_ai::{
ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
StreamOptions,
};
use public_types::common_types::{ use public_types::common_types::{
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition, BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
ZeroShotClassificationRequest, ZeroShotClassificationResponse, ZeroShotClassificationRequest, ZeroShotClassificationResponse,
@ -39,19 +42,40 @@ pub struct CallContext {
response_handler_type: ResponseHandlerType, response_handler_type: ResponseHandlerType,
user_message: Option<String>, user_message: Option<String>,
prompt_target: Option<PromptTarget>, prompt_target: Option<PromptTarget>,
request_body: ChatCompletions, request_body: ChatCompletionsRequest,
similarity_scores: Option<Vec<(String, f64)>>, similarity_scores: Option<Vec<(String, f64)>>,
} }
pub struct StreamContext { pub struct StreamContext {
pub host_header: Option<String>, pub context_id: u32,
pub ratelimit_selector: Option<Header>,
pub callouts: HashMap<u32, CallContext>,
pub metrics: Rc<WasmMetrics>, pub metrics: Rc<WasmMetrics>,
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>, pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
callouts: HashMap<u32, CallContext>,
host_header: Option<String>,
ratelimit_selector: Option<Header>,
streaming_response: bool,
response_tokens: usize,
chat_completions_request: bool,
} }
impl StreamContext { impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
) -> Self {
StreamContext {
context_id,
metrics,
prompt_targets,
callouts: HashMap::new(),
host_header: None,
ratelimit_selector: None,
streaming_response: false,
response_tokens: 0,
chat_completions_request: false,
}
}
fn save_host_header(&mut self) { fn save_host_header(&mut self) {
// Save the host header to be used by filter logic later on. // Save the host header to be used by filter logic later on.
self.host_header = self.get_http_request_header(":host"); self.host_header = self.get_http_request_header(":host");
@ -70,7 +94,8 @@ impl StreamContext {
// The gateway can start gathering information necessary for routing. For now change the path to an // The gateway can start gathering information necessary for routing. For now change the path to an
// OpenAI API path. // OpenAI API path.
Some(path) if path == "/llmrouting" => { Some(path) if path == "/llmrouting" => {
self.set_http_request_header(":path", Some("/v1/chat/completions")); self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
self.chat_completions_request = true;
} }
// Otherwise let the filter continue. // Otherwise let the filter continue.
_ => (), _ => (),
@ -86,21 +111,26 @@ impl StreamContext {
}); });
} }
fn send_server_error(&self, error: String) { fn send_server_error(&self, error: String, override_status_code: Option<StatusCode>) {
debug!("server error occurred: {}", error); debug!("server error occurred: {}", error);
self.send_http_response( self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), override_status_code
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
.as_u16()
.into(),
vec![], vec![],
Some(error.as_bytes()), Some(error.as_bytes()),
) );
} }
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) { fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response, Ok(embedding_response) => embedding_response,
Err(e) => { Err(e) => {
self.send_server_error(format!("Error deserializing embedding response: {:?}", e)); return self.send_server_error(
return; format!("Error deserializing embedding response: {:?}", e),
None,
);
} }
}; };
@ -115,19 +145,15 @@ impl StreamContext {
let prompt_target_embeddings = match embeddings_store().read() { let prompt_target_embeddings = match embeddings_store().read() {
Ok(embeddings) => embeddings, Ok(embeddings) => embeddings,
Err(e) => { Err(e) => {
let error_message = format!("Error reading embeddings store: {:?}", e); return self
warn!("{}", error_message); .send_server_error(format!("Error reading embeddings store: {:?}", e), None);
self.send_server_error(error_message);
return;
} }
}; };
let prompt_targets = match self.prompt_targets.read() { let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets, Ok(prompt_targets) => prompt_targets,
Err(e) => { Err(e) => {
let error_message = format!("Error reading prompt targets: {:?}", e); self.send_server_error(format!("Error reading prompt targets: {:?}", e), None);
warn!("{}", error_message);
self.send_server_error(error_message);
return; return;
} }
}; };
@ -220,12 +246,13 @@ impl StreamContext {
match serde_json::from_slice(&body) { match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response, Ok(zeroshot_response) => zeroshot_response,
Err(e) => { Err(e) => {
warn!( self.send_server_error(
"Error deserializing zeroshot intent detection response: {:?}", format!(
e "Error deserializing zeroshot intent detection response: {:?}",
e
),
None,
); );
info!("body: {:?}", String::from_utf8(body).unwrap());
self.resume_http_request();
return; return;
} }
}; };
@ -319,10 +346,12 @@ impl StreamContext {
parameters: tools_parameters, parameters: tools_parameters,
}; };
let chat_completions = ChatCompletions { let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(), model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(), messages: callout_context.request_body.messages.clone(),
tools: Some(vec![tools_defintion]), tools: Some(vec![tools_defintion]),
stream: false,
stream_options: None,
}; };
let msg_body = match serde_json::to_string(&chat_completions) { let msg_body = match serde_json::to_string(&chat_completions) {
@ -331,11 +360,10 @@ impl StreamContext {
msg_body msg_body
} }
Err(e) => { Err(e) => {
self.send_server_error(format!( return self.send_server_error(
"Error serializing request_params: {:?}", format!("Error serializing request_params: {:?}", e),
e None,
)); );
return;
} }
}; };
@ -424,12 +452,10 @@ impl StreamContext {
.arguments .arguments
.contains_key(&param.name) .contains_key(&param.name)
{ {
warn!("boltfc did not extract required parameter: {}", param.name); self.send_server_error(
return self.send_http_response( format!("missing required parameter: {}", param.name),
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), Some(StatusCode::BAD_REQUEST),
vec![], )
Some("missing required parameter".as_bytes()),
);
} }
} }
}); });
@ -510,17 +536,19 @@ impl StreamContext {
} }
}); });
let request_message: ChatCompletions = ChatCompletions { let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(), model: GPT_35_TURBO.to_string(),
messages, messages,
tools: None, tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
}; };
let json_string = match serde_json::to_string(&request_message) { let json_string = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string, Ok(json_string) => json_string,
Err(e) => { Err(e) => {
self.send_server_error(format!("Error serializing request_body: {:?}", e)); return self
return; .send_server_error(format!("Error serializing request_body: {:?}", e), None);
} }
}; };
debug!( debug!(
@ -528,22 +556,21 @@ impl StreamContext {
json_string json_string
); );
let request_body = callout_context.request_body;
// Tokenize and Ratelimit. // Tokenize and Ratelimit.
if let Some(selector) = self.ratelimit_selector.take() { if let Some(selector) = self.ratelimit_selector.take() {
if let Ok(token_count) = tokenizer::token_count(&request_body.model, &json_string) { if let Ok(token_count) =
tokenizer::token_count(&chat_completions_request.model, &json_string)
{
match ratelimit::ratelimits(None).read().unwrap().check_limit( match ratelimit::ratelimits(None).read().unwrap().check_limit(
request_body.model, chat_completions_request.model,
selector, selector,
NonZero::new(token_count as u32).unwrap(), NonZero::new(token_count as u32).unwrap(),
) { ) {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
self.send_http_response( self.send_server_error(
StatusCode::TOO_MANY_REQUESTS.as_u16().into(), format!("Exceeded Ratelimit: {}", err),
vec![], Some(StatusCode::TOO_MANY_REQUESTS),
Some(format!("Exceeded Ratelimit: {}", err).as_bytes()),
); );
self.metrics.ratelimited_rq.increment(1); self.metrics.ratelimited_rq.increment(1);
return; return;
@ -583,31 +610,36 @@ impl HttpContext for StreamContext {
// Deserialize body into spec. // Deserialize body into spec.
// Currently OpenAI API. // Currently OpenAI API.
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) { let mut deserialized_body: ChatCompletionsRequest =
Some(body_bytes) => match serde_json::from_slice(&body_bytes) { match self.get_http_request_body(0, body_size) {
Ok(deserialized) => deserialized, Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Err(msg) => { Ok(deserialized) => deserialized,
self.send_http_response( Err(msg) => {
StatusCode::BAD_REQUEST.as_u16().into(), self.send_server_error(
vec![], format!("Failed to deserialize: {}", msg),
Some(format!("Failed to deserialize: {}", msg).as_bytes()), Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
},
None => {
self.send_server_error(
format!(
"Failed to obtain body bytes even though body_size is {}",
body_size
),
None,
); );
return Action::Pause; return Action::Pause;
} }
}, };
None => {
self.send_http_response( self.streaming_response = deserialized_body.stream;
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), if deserialized_body.stream && deserialized_body.stream_options.is_none() {
vec![], deserialized_body.stream_options = Some(StreamOptions {
None, include_usage: true,
); });
error!( }
"Failed to obtain body bytes even though body_size is {}",
body_size
);
return Action::Pause;
}
};
let user_message = match deserialized_body let user_message = match deserialized_body
.messages .messages
@ -682,6 +714,92 @@ impl HttpContext for StreamContext {
Action::Pause Action::Pause
} }
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
if !self.chat_completions_request {
return Action::Continue;
}
debug!(
"recv [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
);
if !end_of_stream && !self.streaming_response {
return Action::Pause;
}
let body = self
.get_http_response_body(0, body_size)
.expect("cant get response body");
let body_str = String::from_utf8(body).expect("body is not utf-8");
if self.streaming_response {
debug!("streaming response");
let chat_completions_data = match body_str.split_once("data: ") {
Some((_, chat_completions_data)) => chat_completions_data,
None => {
self.send_server_error(String::from("parsing error in streaming data"), None);
return Action::Pause;
}
};
let chat_completions_chunk_response: ChatCompletionChunkResponse =
match serde_json::from_str(chat_completions_data) {
Ok(de) => de,
Err(_) => {
if chat_completions_data != "[NONE]" {
self.send_server_error(
String::from("error in streaming response"),
None,
);
return Action::Continue;
}
return Action::Continue;
}
};
if let Some(content) = chat_completions_chunk_response
.choices
.first()
.unwrap()
.delta
.content
.as_ref()
{
let model = &chat_completions_chunk_response.model;
let token_count = tokenizer::token_count(model, content).unwrap_or(0);
self.response_tokens += token_count;
}
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_str(&body_str) {
Ok(de) => de,
Err(e) => {
self.send_server_error(
format!(
"error in non-streaming response: {}\n response was={}",
e, body_str
),
None,
);
return Action::Pause;
}
};
self.response_tokens += chat_completions_response.usage.completions_tokens;
}
debug!(
"recv [S={}] total_tokens={} end_stream={}",
self.context_id, self.response_tokens, end_of_stream
);
// TODO:: ratelimit based on response tokens.
Action::Continue
}
} }
impl Context for StreamContext { impl Context for StreamContext {
@ -711,9 +829,10 @@ impl Context for StreamContext {
} }
} }
} else { } else {
let error_message = "No response body in inline HTTP request"; self.send_server_error(
warn!("{}", error_message); String::from("No response body in inline HTTP request"),
self.send_server_error(error_message.to_owned()); None,
);
} }
} }
} }

View file

@ -36,14 +36,10 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.call_proxy_on_request_headers(http_context, 0, false) .call_proxy_on_request_headers(http_context, 0, false)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
.returning(Some("api.openai.com")) .returning(Some("api.openai.com"))
.expect_add_header_map_value( .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
Some(MapType::HttpRequestHeaders),
Some("content-length"),
Some(""),
)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/llmrouting")) .returning(Some("/llmrouting"))
.expect_add_header_map_value( .expect_replace_header_map_value(
Some(MapType::HttpRequestHeaders), Some(MapType::HttpRequestHeaders),
Some(":path"), Some(":path"),
Some("/v1/chat/completions"), Some("/v1/chat/completions"),
@ -196,7 +192,7 @@ prompt_targets:
- name: city - name: city
ratelimits: ratelimits:
- provider: gpt-4 - provider: gpt-3.5-turbo
selector: selector:
key: selector-key key: selector-key
value: selector-value value: selector-value
@ -245,14 +241,10 @@ fn successful_request_to_open_ai_chat_completions() {
.call_proxy_on_request_headers(http_context, 0, false) .call_proxy_on_request_headers(http_context, 0, false)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
.returning(Some("api.openai.com")) .returning(Some("api.openai.com"))
.expect_add_header_map_value( .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
Some(MapType::HttpRequestHeaders),
Some("content-length"),
Some(""),
)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/llmrouting")) .returning(Some("/llmrouting"))
.expect_add_header_map_value( .expect_replace_header_map_value(
Some(MapType::HttpRequestHeaders), Some(MapType::HttpRequestHeaders),
Some(":path"), Some(":path"),
Some("/v1/chat/completions"), Some("/v1/chat/completions"),
@ -289,9 +281,9 @@ fn successful_request_to_open_ai_chat_completions() {
) )
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body)) .returning(Some(chat_completions_request_body))
// TODO: assert that the model field was added.
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("model_server"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1) .expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause)) .execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap(); .unwrap();
@ -335,14 +327,10 @@ fn bad_request_to_open_ai_chat_completions() {
.call_proxy_on_request_headers(http_context, 0, false) .call_proxy_on_request_headers(http_context, 0, false)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
.returning(Some("api.openai.com")) .returning(Some("api.openai.com"))
.expect_add_header_map_value( .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
Some(MapType::HttpRequestHeaders),
Some("content-length"),
Some(""),
)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/llmrouting")) .returning(Some("/llmrouting"))
.expect_add_header_map_value( .expect_replace_header_map_value(
Some(MapType::HttpRequestHeaders), Some(MapType::HttpRequestHeaders),
Some(":path"), Some(":path"),
Some("/v1/chat/completions"), Some("/v1/chat/completions"),
@ -377,6 +365,7 @@ fn bad_request_to_open_ai_chat_completions() {
) )
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body)) .returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response( .expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()), Some(StatusCode::BAD_REQUEST.as_u16().into()),
None, None,
@ -485,6 +474,10 @@ fn request_ratelimited() {
None, None,
) )
.expect_metric_increment("ratelimited_rq", 1) .expect_metric_increment("ratelimited_rq", 1)
.expect_log(
Some(LogLevel::Debug),
Some("server error occurred: Exceeded Ratelimit: Not allowed"),
)
.execute_and_expect(ReturnType::None) .execute_and_expect(ReturnType::None)
.unwrap(); .unwrap();
} }

View file

@ -94,12 +94,21 @@ pub mod open_ai {
use super::ToolsDefinition; use super::ToolsDefinition;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletions { pub struct ChatCompletionsRequest {
#[serde(default)] #[serde(default)]
pub model: String, pub model: String,
pub messages: Vec<Message>, pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolsDefinition>>, pub tools: Option<Vec<ToolsDefinition>>,
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -109,6 +118,33 @@ pub mod open_ai {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>, pub model: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
pub usage: Usage,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub completions_tokens: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunkResponse {
pub model: String,
pub choices: Vec<Choice>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub delta: Delta,
// TODO: could this be an enum?
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
pub content: Option<String>,
}
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]