diff --git a/arch/docker-compose.dev.yaml b/arch/docker-compose.dev.yaml
index 36c364bb..7457bfc5 100644
--- a/arch/docker-compose.dev.yaml
+++ b/arch/docker-compose.dev.yaml
@@ -12,7 +12,7 @@ services:
       - ./envoy.template.yaml:/config/envoy.template.yaml
       - ./target/wasm32-wasi/release/intelligent_prompt_gateway.wasm:/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
       - ./arch_config_schema.yaml:/config/arch_config_schema.yaml
-      - ./tools/config_generator.py:/config/config_generator.py
+      - ./tools/cli/config_generator.py:/config/config_generator.py
       - ./arch_logs:/var/log/
     env_file:
       - stage.env
diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml
index 070aa919..86db4906 100644
--- a/arch/envoy.template.yaml
+++ b/arch/envoy.template.yaml
@@ -174,87 +174,6 @@ static_resources:
             typed_config:
               "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
-  - name: arch_listener_llm
-    address:
-      socket_address:
-        address: 0.0.0.0
-        port_value: 12000
-    filter_chains:
-    - filters:
-      - name: envoy.filters.network.http_connection_manager
-        typed_config:
-          "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
-          {% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
-          generate_request_id: true
-          tracing:
-            provider:
-              name: envoy.tracers.opentelemetry
-              typed_config:
-                "@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
-                grpc_service:
-                  envoy_grpc:
-                    cluster_name: opentelemetry_collector
-                  timeout: 0.250s
-                service_name: arch
-            random_sampling:
-              value: {{ arch_tracing.random_sampling }}
-          {% endif %}
-          stat_prefix: arch_listener_http
-          codec_type: AUTO
-          scheme_header_transformation:
-            scheme_to_overwrite: https
-          access_log:
-          - name: envoy.access_loggers.file
-            typed_config:
-              "@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
-              path: "/var/log/access_llm.log"
-          route_config:
-            name: local_routes
-            virtual_hosts:
-            - name: local_service
-              domains:
-              - "*"
-              routes:
-              {% for provider in arch_llm_providers %}
-              - match:
-                  prefix: "/"
-                  headers:
-                  - name: "x-arch-llm-provider"
-                    string_match:
-                      exact: {{ provider.name }}
-                route:
-                  auto_host_rewrite: true
-                  cluster: {{ provider.provider }}
-                  timeout: 60s
-              {% endfor %}
-              - match:
-                  prefix: "/"
-                direct_response:
-                  status: 400
-                  body:
-                    inline_string: "x-arch-llm-provider header not set, cannot perform routing\n"
-          http_filters:
-          - name: envoy.filters.http.wasm
-            typed_config:
-              "@type": type.googleapis.com/udpa.type.v1.TypedStruct
-              type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
-              value:
-                config:
-                  name: "http_config"
-                  root_id: llm_gateway
-                  configuration:
-                    "@type": "type.googleapis.com/google.protobuf.StringValue"
-                    value: |
-                      {{ arch_llm_config | indent(32) }}
-                  vm_config:
-                    runtime: "envoy.wasm.runtime.v8"
-                    code:
-                      local:
-                        filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
-      - name: envoy.filters.http.router
-        typed_config:
-          "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
-
   clusters:
   - name: openai
     connect_timeout: 5s
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index bc9e62fa..f36c0de6 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -112,6 +112,7 @@ pub struct StreamContext {
    llm_provider: Option<Rc<LlmProvider>>,
    request_id: Option<String>,
    mode: GatewayMode,
+   read_response_bytes: usize,
 }
 
 impl StreamContext {
@@ -150,6 +151,7 @@ impl StreamContext {
            overrides,
            request_id: None,
            mode,
+           read_response_bytes: 0,
        }
    }
    fn llm_provider(&self) -> &LlmProvider {
@@ -1101,6 +1103,87 @@ impl StreamContext {
        self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
        self.resume_http_request();
    }
+
+   fn chat_completions_streaming_response_handler(&mut self, data: ChatCompletionChunkResponse) {
+       if let Some(content) = data.choices.first().unwrap().delta.content.as_ref() {
+           let model = &data.model;
+           let token_count = tokenizer::token_count(model, content).unwrap_or(0);
+           self.response_tokens += token_count;
+       }
+   }
+
+   fn chat_completions_unary_response_handler(
+       &mut self,
+       data: ChatCompletionsResponse,
+       body: &[u8],
+       body_size: usize,
+   ) {
+       if data.usage.is_some() {
+           self.response_tokens += data.usage.as_ref().unwrap().completion_tokens;
+       }
+
+       if let Some(tool_calls) = self.tool_calls.as_ref() {
+           if !tool_calls.is_empty() {
+               if self.arch_state.is_none() {
+                   self.arch_state = Some(Vec::new());
+               }
+
+               // compute sha hash from message history
+               let mut hasher = Sha256::new();
+               let prompts: Vec<String> = self
+                   .chat_completions_request
+                   .as_ref()
+                   .unwrap()
+                   .messages
+                   .iter()
+                   .filter(|msg| msg.role == USER_ROLE)
+                   .map(|msg| msg.content.clone().unwrap())
+                   .collect();
+               let prompts_merged = prompts.join("#.#");
+               hasher.update(prompts_merged.clone());
+               let hash_key = hasher.finalize();
+               // convert hash to hex string
+               let hash_key_str = format!("{:x}", hash_key);
+               debug!(
+                   "hash key: {}, prompts: {} {:?}",
+                   hash_key_str, prompts_merged, self.mode
+               );
+
+               // create new tool call state
+               let tool_call_state = ToolCallState {
+                   key: hash_key_str,
+                   message: self.user_prompt.clone(),
+                   tool_call: tool_calls[0].function.clone(),
+                   tool_response: self.tool_call_response.clone().unwrap(),
+               };
+
+               // push tool call state to arch state
+               self.arch_state
+                   .as_mut()
+                   .unwrap()
+                   .push(ArchState::ToolCall(vec![tool_call_state]));
+
+               let mut data: Value = serde_json::from_slice(&body).unwrap();
+               // use serde::Value to manipulate the json object and ensure that we don't lose any data
+               if let Value::Object(ref mut map) = data {
+                   // serialize arch state and add to metadata
+                   let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
+                   debug!("arch_state: {} {:?}", arch_state_str, self.mode);
+                   let metadata = map
+                       .entry("metadata")
+                       .or_insert(Value::Object(serde_json::Map::new()));
+                   metadata.as_object_mut().unwrap().insert(
+                       ARCH_STATE_HEADER.to_string(),
+                       serde_json::Value::String(arch_state_str),
+                   );
+
+                   let data_serialized = serde_json::to_string(&data).unwrap();
+                   debug!("arch => user: {} {:?}", data_serialized, self.mode);
+                   self.set_http_response_body(0, body_size, data_serialized.as_bytes());
+               };
+           }
+       }
+   }
 }
 
 // HttpContext is the trait that allows the Rust code to interact with HTTP objects.
@@ -1328,155 +1411,47 @@ impl HttpContext for StreamContext {
    }
 
    fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
-       debug!(
-           "recv [S={}] bytes={} end_stream={}",
-           self.context_id, body_size, end_of_stream
-       );
-
-       if !self.is_chat_completions_request {
-           if let Some(body_str) = self
-               .get_http_response_body(0, body_size)
-               .and_then(|bytes| String::from_utf8(bytes).ok())
-           {
-               debug!("recv [S={}] body_str={}", self.context_id, body_str);
-           }
+       if body_size == 0 {
            return Action::Continue;
        }
 
-       if !end_of_stream {
-           return Action::Pause;
-       }
-
        let body = self
-           .get_http_response_body(0, body_size)
+           .get_http_response_body(self.read_response_bytes, body_size)
            .expect("cant get response body");
-
-       if self.streaming_response {
-           let body_str = String::from_utf8(body).expect("body is not utf-8");
-           debug!("streaming response");
-           let chat_completions_data = match body_str.split_once("data: ") {
-               Some((_, chat_completions_data)) => chat_completions_data,
-               None => {
-                   self.send_server_error(
-                       ServerError::LogicError(String::from("parsing error in streaming data")),
-                       None,
-                   );
-                   return Action::Pause;
-               }
-           };
-
-           let chat_completions_chunk_response: ChatCompletionChunkResponse =
-               match serde_json::from_str(chat_completions_data) {
-                   Ok(de) => de,
-                   Err(_) => {
-                       if chat_completions_data != "[NONE]" {
-                           self.send_server_error(
-                               ServerError::LogicError(String::from(
-                                   "error in streaming response",
-                               )),
-                               None,
-                           );
-                           return Action::Continue;
-                       }
-                       return Action::Continue;
-                   }
-               };
-
-           if let Some(content) = chat_completions_chunk_response
-               .choices
-               .first()
-               .unwrap()
-               .delta
-               .content
-               .as_ref()
-           {
-               let model = &chat_completions_chunk_response.model;
-               let token_count = tokenizer::token_count(model, content).unwrap_or(0);
-               self.response_tokens += token_count;
-           }
-       } else {
-           debug!("non streaming response");
-           let chat_completions_response: ChatCompletionsResponse =
-               match serde_json::from_slice(&body) {
-                   Ok(de) => de,
-                   Err(e) => {
-                       debug!("invalid response: {}", String::from_utf8_lossy(&body));
-                       self.send_server_error(ServerError::Deserialization(e), None);
-                       return Action::Pause;
-                   }
-               };
-
-           if chat_completions_response.usage.is_some() {
-               self.response_tokens += chat_completions_response
-                   .usage
-                   .as_ref()
-                   .unwrap()
-                   .completion_tokens;
-           }
-
-           if let Some(tool_calls) = self.tool_calls.as_ref() {
-               if !tool_calls.is_empty() {
-                   if self.arch_state.is_none() {
-                       self.arch_state = Some(Vec::new());
-                   }
-
-                   // compute sha hash from message history
-                   let mut hasher = Sha256::new();
-                   let prompts: Vec<String> = self
-                       .chat_completions_request
-                       .as_ref()
-                       .unwrap()
-                       .messages
-                       .iter()
-                       .filter(|msg| msg.role == USER_ROLE)
-                       .map(|msg| msg.content.clone().unwrap())
-                       .collect();
-                   let prompts_merged = prompts.join("#.#");
-                   hasher.update(prompts_merged.clone());
-                   let hash_key = hasher.finalize();
-                   // conver hash to hex string
-                   let hash_key_str = format!("{:x}", hash_key);
-                   debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
-
-                   // create new tool call state
-                   let tool_call_state = ToolCallState {
-                       key: hash_key_str,
-                       message: self.user_prompt.clone(),
-                       tool_call: tool_calls[0].function.clone(),
-                       tool_response: self.tool_call_response.clone().unwrap(),
-                   };
-
-                   // push tool call state to arch state
-                   self.arch_state
-                       .as_mut()
-                       .unwrap()
-                       .push(ArchState::ToolCall(vec![tool_call_state]));
-
-                   let mut data: Value = serde_json::from_slice(&body).unwrap();
-                   // use serde::Value to manipulate the json object and ensure that we don't lose any data
-                   if let Value::Object(ref mut map) = data {
-                       // serialize arch state and add to metadata
-                       let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
-                       debug!("arch_state: {}", arch_state_str);
-                       let metadata = map
-                           .entry("metadata")
-                           .or_insert(Value::Object(serde_json::Map::new()));
-                       metadata.as_object_mut().unwrap().insert(
-                           ARCH_STATE_HEADER.to_string(),
-                           serde_json::Value::String(arch_state_str),
-                       );
-
-                       let data_serialized = serde_json::to_string(&data).unwrap();
-                       debug!("arch => user: {}", data_serialized);
-                       self.set_http_response_body(0, body_size, data_serialized.as_bytes());
-                   };
-               }
-           }
-       }
+       self.read_response_bytes += body_size;
+       let body_str = String::from_utf8(body).expect("body is not utf-8");
 
        debug!(
-           "recv [S={}] total_tokens={} end_stream={}",
-           self.context_id, self.response_tokens, end_of_stream
+           "recv [S={}] bytes={}({}) end_stream={}",
+           self.context_id,
+           body_size - self.read_response_bytes,
+           body_str,
+           end_of_stream,
+       );
+
+       match serde_json::from_str(&body_str) {
+           Ok(de) => {
+               self.chat_completions_unary_response_handler(de, body_str.as_bytes(), body_size);
+           }
+           Err(_) => {
+               debug!(
+                   "Couldn't deserialize as ChatCompletionsResponse {:?}",
+                   self.mode
+               )
+           }
+       };
+
+       match body_str.split_once("data: ") {
+           Some((_, chat_completions_data)) => match serde_json::from_str(chat_completions_data) {
+               Ok(de) => self.chat_completions_streaming_response_handler(de),
+               Err(_) => debug!("couldn't deserialize streaming data {:?}", self.mode),
+           },
+           None => debug!("couldn't split {:?}", self.mode),
+       };
+
+       debug!(
+           "recv [S={}] total_tokens={} end_stream={} {:?}",
+           self.context_id, self.response_tokens, end_of_stream, self.mode
        );
 
        // TODO:: ratelimit based on response tokens.
diff --git a/chatbot_ui/app/run_stream.py b/chatbot_ui/app/run_stream.py
index 8be5a16b..89f6fe7d 100644
--- a/chatbot_ui/app/run_stream.py
+++ b/chatbot_ui/app/run_stream.py
@@ -4,13 +4,11 @@
 import os
 from openai import OpenAI
 import gradio as gr
 
-api_key = os.getenv("OPENAI_API_KEY")
 CHAT_COMPLETION_ENDPOINT = os.getenv(
     "CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1"
 )
-client = OpenAI(api_key=api_key, base_url=CHAT_COMPLETION_ENDPOINT)
-
+client = OpenAI(api_key="--", base_url=CHAT_COMPLETION_ENDPOINT)
 
 def predict(message, history):
     history_openai_format = []
@@ -20,7 +18,7 @@
     history_openai_format.append({"role": "user", "content": message})
 
     response = client.chat.completions.create(
-        model="gpt-3.5-turbo",
+        model="arch",
         messages=history_openai_format,
         temperature=1.0,
         stream=True,
@@ -33,4 +31,4 @@
            yield partial_message
 
 
-gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8081)
+gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8080)
diff --git a/model_server/app/__init__.py b/model_server/app/__init__.py
index c3c8e9f6..e2e7268b 100644
--- a/model_server/app/__init__.py
+++ b/model_server/app/__init__.py
@@ -36,7 +36,7 @@ def start_server():
        sys.exit(1)
 
    print(
-       "Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)"
+       "Starting Archgw Model Server - Loading some awesomeness, this may take a little time."
    )
    process = subprocess.Popen(
        ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],