From b0c1e97dc5b420c7325d518892a91f5b1b8dcc66 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 3 Jun 2025 15:57:30 -0700 Subject: [PATCH] use req/resp from hermesllm in llm gateway --- arch/envoy.template.yaml | 2 +- arch/tools/cli/docker_cli.py | 3 +++ crates/Cargo.lock | 1 + .../hermesllm/src/providers/openai/types.rs | 23 ++++++++++++++++--- crates/llm_gateway/Cargo.toml | 1 + crates/llm_gateway/src/stream_context.rs | 19 ++++++++------- 6 files changed, 35 insertions(+), 14 deletions(-) diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index 92db9f3b..f710f908 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -773,7 +773,7 @@ static_resources: - endpoint: address: socket_address: - address: 0.0.0.0 + address: host.docker.internal port_value: 9091 hostname: localhost diff --git a/arch/tools/cli/docker_cli.py b/arch/tools/cli/docker_cli.py index e8a12a13..ba9ef92c 100644 --- a/arch/tools/cli/docker_cli.py +++ b/arch/tools/cli/docker_cli.py @@ -64,6 +64,8 @@ def docker_start_archgw_detached( item for volume in volume_mappings for item in ("-v", volume) ] + print("using custom release path") + options = [ "docker", "run", @@ -76,6 +78,7 @@ def docker_start_archgw_detached( "--add-host", "host.docker.internal:host-gateway", ARCHGW_DOCKER_IMAGE, + "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro", ] result = subprocess.run(options, capture_output=True, text=True, check=False) diff --git a/crates/Cargo.lock b/crates/Cargo.lock index ba5d3796..44ceb6c9 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -1615,6 +1615,7 @@ dependencies = [ "common", "derivative", "governor", + "hermesllm", "http 1.1.0", "log", "md5", diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs index 880d31d1..7f276b49 100644 --- a/crates/hermesllm/src/providers/openai/types.rs +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -57,6 +57,12 @@ pub struct Message { pub content: Option, } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamOptions { + pub include_usage: bool, +} + #[skip_serializing_none] #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionsRequest { @@ -70,6 +76,7 @@ pub struct ChatCompletionsRequest { pub stop: Option>, pub presence_penalty: Option, pub frequency_penalty: Option, + pub stream_options: Option, } impl Default for ChatCompletionsRequest { @@ -85,6 +92,7 @@ impl Default for ChatCompletionsRequest { stop: None, presence_penalty: None, frequency_penalty: None, + stream_options: None, } } } @@ -110,9 +118,9 @@ pub struct Choice { #[skip_serializing_none] #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, + pub prompt_tokens: usize, + pub completion_tokens: usize, + pub total_tokens: usize, } #[derive(Debug, Clone)] @@ -127,6 +135,7 @@ pub struct OpenAIRequestBuilder { stop: Option>, presence_penalty: Option, frequency_penalty: Option, + stream_options: Option, } impl OpenAIRequestBuilder { @@ -142,6 +151,7 @@ impl OpenAIRequestBuilder { stop: None, presence_penalty: None, frequency_penalty: None, + stream_options: None, } } @@ -185,6 +195,12 @@ impl OpenAIRequestBuilder { self } + pub fn stream_options(mut self, include_usage: bool) -> Self { + self.stream = Some(true); + self.stream_options = Some(StreamOptions { include_usage }); + self + } + pub fn build(self) -> Result { let request = ChatCompletionsRequest { model: self.model, @@ -197,6 +213,7 @@ impl OpenAIRequestBuilder { stop: self.stop, presence_penalty: self.presence_penalty, frequency_penalty: self.frequency_penalty, + stream_options: self.stream_options, }; Ok(request) } diff --git a/crates/llm_gateway/Cargo.toml b/crates/llm_gateway/Cargo.toml index 73d62c3d..b65b57b8 100644 --- a/crates/llm_gateway/Cargo.toml +++ b/crates/llm_gateway/Cargo.toml @@ -22,6 +22,7 @@ rand = "0.8.5" thiserror = "1.0.64" derivative = "2.2.0" sha2 = "0.10.8" +hermesllm = { version = "0.1.0", path = "../hermesllm" } [dev-dependencies] proxy-wasm-test-framework = { git = "https://github.com/katanemo/test-framework.git", branch = "new" } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 7ca3a99b..7f4620ba 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,8 +1,5 @@ use crate::metrics::Metrics; -use common::api::open_ai::{ - ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse, - ContentType, Message, StreamOptions, -}; +use common::api::open_ai::ChatCompletionStreamResponseServerEvents; use common::configuration::{LlmProvider, LlmProviderType, Overrides}; use common::consts::{ ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, @@ -14,6 +11,10 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; +use hermesllm::providers::openai::types::ChatCompletionsRequest; +use hermesllm::providers::openai::types::{ + ChatCompletionsResponse, ContentType, Message, StreamOptions, +}; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -302,10 +303,6 @@ impl HttpContext for StreamContext { } }; - for message in deserialized_body.messages.iter_mut() { - message.model = None; - } - self.user_message = deserialized_body .messages .iter() @@ -355,10 +352,12 @@ impl HttpContext for StreamContext { chat_completion_request_str ); - if deserialized_body.stream { + if deserialized_body.stream.unwrap_or_default() { self.streaming_response = true; } - if deserialized_body.stream && deserialized_body.stream_options.is_none() { + if deserialized_body.stream.unwrap_or_default() + && deserialized_body.stream_options.is_none() + { deserialized_body.stream_options = Some(StreamOptions { include_usage: true, });