debug streaming

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
José Ulises Niño Rivera 2024-10-13 15:36:53 -06:00
parent 639839fbb1
commit bbd6058ca0
5 changed files with 122 additions and 230 deletions


@@ -12,7 +12,7 @@ services:
- ./envoy.template.yaml:/config/envoy.template.yaml
- ./target/wasm32-wasi/release/intelligent_prompt_gateway.wasm:/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
- ./arch_config_schema.yaml:/config/arch_config_schema.yaml
- ./tools/config_generator.py:/config/config_generator.py
- ./tools/cli/config_generator.py:/config/config_generator.py
- ./arch_logs:/var/log/
env_file:
- stage.env


@@ -174,87 +174,6 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
- name: arch_listener_llm
address:
socket_address:
address: 0.0.0.0
port_value: 12000
filter_chains:
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
generate_request_id: true
tracing:
provider:
name: envoy.tracers.opentelemetry
typed_config:
"@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
grpc_service:
envoy_grpc:
cluster_name: opentelemetry_collector
timeout: 0.250s
service_name: arch
random_sampling:
value: {{ arch_tracing.random_sampling }}
{% endif %}
stat_prefix: arch_listener_http
codec_type: AUTO
scheme_header_transformation:
scheme_to_overwrite: https
access_log:
- name: envoy.access_loggers.file
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
path: "/var/log/access_llm.log"
route_config:
name: local_routes
virtual_hosts:
- name: local_service
domains:
- "*"
routes:
{% for provider in arch_llm_providers %}
- match:
prefix: "/"
headers:
- name: "x-arch-llm-provider"
string_match:
exact: {{ provider.name }}
route:
auto_host_rewrite: true
cluster: {{ provider.provider }}
timeout: 60s
{% endfor %}
- match:
prefix: "/"
direct_response:
status: 400
body:
inline_string: "x-arch-llm-provider header not set, cannot perform routing\n"
http_filters:
- name: envoy.filters.http.wasm
typed_config:
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
value:
config:
name: "http_config"
root_id: llm_gateway
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
{{ arch_llm_config | indent(32) }}
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
local:
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
- name: openai
connect_timeout: 5s
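
For orientation: the listener deleted above routed purely on the x-arch-llm-provider request header, matching it exactly against a provider name and answering 400 via direct_response when it was missing. A minimal proxy-wasm sketch (a hypothetical filter, not the plugin's actual code) of the contract a caller upstream of those routes had to satisfy:

    use proxy_wasm::traits::{Context, HttpContext};
    use proxy_wasm::types::Action;

    struct ProviderPin;

    impl Context for ProviderPin {}

    impl HttpContext for ProviderPin {
        fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
            // The removed routes matched this header exactly against a provider
            // name; without it the listener answered 400 via direct_response.
            self.set_http_request_header("x-arch-llm-provider", Some("openai"));
            Action::Continue
        }
    }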


@@ -112,6 +112,7 @@ pub struct StreamContext {
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
mode: GatewayMode,
read_response_bytes: usize,
}
impl StreamContext {
@@ -150,6 +151,7 @@ impl StreamContext {
overrides,
request_id: None,
mode,
read_response_bytes: 0,
}
}
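
The new read_response_bytes field is a cursor into the buffered response body: each body callback reads starting at the cursor and then advances it by the bytes it handled, which is what the get_http_response_body change below relies on. A standalone sketch of the pattern (BodyBuffer is invented here to stand in for the host's buffer accessor):

    /// Illustration only: stands in for the host-side buffered-body API.
    trait BodyBuffer {
        fn read(&self, start: usize, max_len: usize) -> Option<Vec<u8>>;
    }

    struct Cursor {
        read_response_bytes: usize,
    }

    impl Cursor {
        fn on_body_chunk(&mut self, buf: &impl BodyBuffer, body_size: usize) -> Option<Vec<u8>> {
            // Read past what earlier callbacks already consumed, then advance.
            let chunk = buf.read(self.read_response_bytes, body_size)?;
            self.read_response_bytes += body_size;
            Some(chunk)
        }
    }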
fn llm_provider(&self) -> &LlmProvider {
@@ -1101,6 +1103,87 @@ impl StreamContext {
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
self.resume_http_request();
}
fn chat_completions_streaming_response_handler(&mut self, data: ChatCompletionChunkResponse) {
if let Some(content) = data.choices.first().unwrap().delta.content.as_ref() {
let model = &data.model;
let token_count = tokenizer::token_count(model, content).unwrap_or(0);
self.response_tokens += token_count;
}
}
fn chat_completions_unary_response_handler(
&mut self,
data: ChatCompletionsResponse,
body: &[u8],
body_size: usize,
) {
if data.usage.is_some() {
self.response_tokens += data.usage.as_ref().unwrap().completion_tokens;
}
if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
if self.arch_state.is_none() {
self.arch_state = Some(Vec::new());
}
// compute sha hash from message history
let mut hasher = Sha256::new();
let prompts: Vec<String> = self
.chat_completions_request
.as_ref()
.unwrap()
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.map(|msg| msg.content.clone().unwrap())
.collect();
let prompts_merged = prompts.join("#.#");
hasher.update(prompts_merged.clone());
let hash_key = hasher.finalize();
// convert hash to hex string
let hash_key_str = format!("{:x}", hash_key);
debug!(
"hash key: {}, prompts: {} {:?}",
hash_key_str, prompts_merged, self.mode
);
// create new tool call state
let tool_call_state = ToolCallState {
key: hash_key_str,
message: self.user_prompt.clone(),
tool_call: tool_calls[0].function.clone(),
tool_response: self.tool_call_response.clone().unwrap(),
};
// push tool call state to arch state
self.arch_state
.as_mut()
.unwrap()
.push(ArchState::ToolCall(vec![tool_call_state]));
let mut data: Value = serde_json::from_slice(&body).unwrap();
// use serde_json::Value to manipulate the JSON object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
debug!("arch_state: {} {:?}", arch_state_str, self.mode);
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
let data_serialized = serde_json::to_string(&data).unwrap();
debug!("arch => user: {} {:?}", data_serialized, self.mode);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
};
}
}
}
}
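
The streaming handler above consumes one already-parsed chunk per SSE frame; the split-and-parse step itself happens in the response-body hook shown further down. A self-contained sketch of that step, with simplified stand-ins for the crate's ChatCompletionChunkResponse:

    use serde::Deserialize;

    #[derive(Deserialize)]
    struct Delta { content: Option<String> }

    #[derive(Deserialize)]
    struct Choice { delta: Delta }

    #[derive(Deserialize)]
    struct Chunk { model: String, choices: Vec<Choice> }

    fn parse_sse_chunk(frame: &str) -> Option<Chunk> {
        // An OpenAI-style frame looks like `data: {"model":...}\n\n`; the stream
        // ends with a literal `data: [DONE]` frame, which is not valid JSON and
        // therefore falls out as None here.
        let (_, payload) = frame.split_once("data: ")?;
        serde_json::from_str(payload.trim()).ok()
    }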
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
@@ -1328,155 +1411,47 @@ impl HttpContext for StreamContext {
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
"recv [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
);
if !self.is_chat_completions_request {
if let Some(body_str) = self
.get_http_response_body(0, body_size)
.and_then(|bytes| String::from_utf8(bytes).ok())
{
debug!("recv [S={}] body_str={}", self.context_id, body_str);
}
if body_size == 0 {
return Action::Continue;
}
if !end_of_stream {
return Action::Pause;
}
let body = self
.get_http_response_body(0, body_size)
.get_http_response_body(self.read_response_bytes, body_size)
.expect("cant get response body");
if self.streaming_response {
let body_str = String::from_utf8(body).expect("body is not utf-8");
debug!("streaming response");
let chat_completions_data = match body_str.split_once("data: ") {
Some((_, chat_completions_data)) => chat_completions_data,
None => {
self.send_server_error(
ServerError::LogicError(String::from("parsing error in streaming data")),
None,
);
return Action::Pause;
}
};
let chat_completions_chunk_response: ChatCompletionChunkResponse =
match serde_json::from_str(chat_completions_data) {
Ok(de) => de,
Err(_) => {
if chat_completions_data != "[NONE]" {
self.send_server_error(
ServerError::LogicError(String::from(
"error in streaming response",
)),
None,
);
return Action::Continue;
}
return Action::Continue;
}
};
if let Some(content) = chat_completions_chunk_response
.choices
.first()
.unwrap()
.delta
.content
.as_ref()
{
let model = &chat_completions_chunk_response.model;
let token_count = tokenizer::token_count(model, content).unwrap_or(0);
self.response_tokens += token_count;
}
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
debug!("invalid response: {}", String::from_utf8_lossy(&body));
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
}
};
if chat_completions_response.usage.is_some() {
self.response_tokens += chat_completions_response
.usage
.as_ref()
.unwrap()
.completion_tokens;
}
if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
if self.arch_state.is_none() {
self.arch_state = Some(Vec::new());
}
// compute sha hash from message history
let mut hasher = Sha256::new();
let prompts: Vec<String> = self
.chat_completions_request
.as_ref()
.unwrap()
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.map(|msg| msg.content.clone().unwrap())
.collect();
let prompts_merged = prompts.join("#.#");
hasher.update(prompts_merged.clone());
let hash_key = hasher.finalize();
// convert hash to hex string
let hash_key_str = format!("{:x}", hash_key);
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
// create new tool call state
let tool_call_state = ToolCallState {
key: hash_key_str,
message: self.user_prompt.clone(),
tool_call: tool_calls[0].function.clone(),
tool_response: self.tool_call_response.clone().unwrap(),
};
// push tool call state to arch state
self.arch_state
.as_mut()
.unwrap()
.push(ArchState::ToolCall(vec![tool_call_state]));
let mut data: Value = serde_json::from_slice(&body).unwrap();
// use serde_json::Value to manipulate the JSON object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
debug!("arch_state: {}", arch_state_str);
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
let data_serialized = serde_json::to_string(&data).unwrap();
debug!("arch => user: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
};
}
}
}
self.read_response_bytes += body_size;
let body_str = String::from_utf8(body).expect("body is not utf-8");
debug!(
"recv [S={}] total_tokens={} end_stream={}",
self.context_id, self.response_tokens, end_of_stream
"recv [S={}] bytes={}({}) end_stream={}",
self.context_id,
body_size - self.read_response_bytes,
body_str,
end_of_stream,
);
match serde_json::from_str(&body_str) {
Ok(de) => {
self.chat_completions_unary_response_handler(de, body_str.as_bytes(), body_size);
}
Err(_) => {
debug!(
"Couldn't deserialize as ChatCompletionsResponse {:?}",
self.mode
)
}
};
match body_str.split_once("data: ") {
Some((_, chat_completions_data)) => match serde_json::from_str(chat_completions_data) {
Ok(de) => self.chat_completions_streaming_response_handler(de),
Err(_) => debug!("couldn't deserialize streaming data {:?}", self.mode),
},
None => debug!("couldn't split {:?}", self.mode),
};
debug!(
"recv [S={}] total_tokens={} end_stream={} {:?}",
self.context_id, self.response_tokens, end_of_stream, self.mode
);
// TODO: rate limit based on response tokens.
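
Two pieces of the unary handler are worth isolating. The state key is a SHA-256 digest over the user messages joined with "#.#", and the response body is edited through serde_json::Value so provider fields the plugin doesn't model survive the rewrite. A standalone sketch of both steps (sha2 crate assumed; "x-arch-state" is an assumed stand-in for ARCH_STATE_HEADER):

    use serde_json::Value;
    use sha2::{Digest, Sha256};

    // State key: SHA-256 over user messages joined with "#.#", rendered as hex.
    fn state_key(user_prompts: &[String]) -> String {
        format!("{:x}", Sha256::digest(user_prompts.join("#.#").as_bytes()))
    }

    // Metadata injection: deserialize into Value (not a typed struct) so unknown
    // provider fields round-trip, then attach the serialized arch state.
    fn inject_arch_state(body: &[u8], arch_state_json: String) -> serde_json::Result<String> {
        let mut data: Value = serde_json::from_slice(body)?;
        if let Value::Object(ref mut map) = data {
            let metadata = map
                .entry("metadata")
                .or_insert(Value::Object(serde_json::Map::new()));
            if let Some(obj) = metadata.as_object_mut() {
                // "x-arch-state" is an assumed value for ARCH_STATE_HEADER.
                obj.insert("x-arch-state".to_string(), Value::String(arch_state_json));
            }
        }
        serde_json::to_string(&data)
    }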


@@ -4,13 +4,11 @@ import os
from openai import OpenAI
import gradio as gr
api_key = os.getenv("OPENAI_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv(
"CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1"
)
client = OpenAI(api_key=api_key, base_url=CHAT_COMPLETION_ENDPOINT)
client = OpenAI(api_key="--", base_url=CHAT_COMPLETION_ENDPOINT)
def predict(message, history):
history_openai_format = []
@@ -20,7 +18,7 @@ def predict(message, history):
history_openai_format.append({"role": "user", "content": message})
response = client.chat.completions.create(
model="gpt-3.5-turbo",
model="arch",
messages=history_openai_format,
temperature=1.0,
stream=True,
@@ -33,4 +31,4 @@ def predict(message, history):
yield partial_message
gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8081)
gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8080)


@@ -36,7 +36,7 @@ def start_server():
sys.exit(1)
print(
"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)"
"Starting Archgw Model Server - Loading some awesomeness, this may take a little time."
)
process = subprocess.Popen(
["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],