diff --git a/.github/workflows/e2e_archgw.yml b/.github/workflows/e2e_archgw.yml index 84ccff68..fdc43726 100644 --- a/.github/workflows/e2e_archgw.yml +++ b/.github/workflows/e2e_archgw.yml @@ -30,6 +30,7 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} + GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} run: | docker compose up | tee &> archgw.logs & @@ -55,5 +56,6 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} + GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} run: | docker compose down diff --git a/.github/workflows/e2e_test_demos.yml b/.github/workflows/e2e_test_demos.yml index 9033ca29..d353fa46 100644 --- a/.github/workflows/e2e_test_demos.yml +++ b/.github/workflows/e2e_test_demos.yml @@ -48,6 +48,7 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} + GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} run: | source venv/bin/activate cd demos/shared/test_runner && sh run_demo_tests.sh diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index f894b713..576a7fc3 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -29,6 +29,7 @@ jobs: env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} + GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} run: | python -mvenv venv source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index cd52220e..523f5781 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -11,7 +11,7 @@ pub const MODEL_SERVER_NAME: &str = "model_server"; pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider"; pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; -pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; +pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"]; pub const HEALTHZ_PATH: &str = "/healthz"; pub const ARCH_STATE_HEADER: &str = "x-arch-state"; pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B"; diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 78d7e21e..55ee53ce 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -89,6 +89,23 @@ impl StreamContext { provider_hint, )); + // Check if we need to modify the path based on the provider's base_url + let needs_openai_prefix = self + .llm_provider + .as_ref() + .and_then(|provider| provider.endpoint.as_ref()) + .map(|url| url.contains("api.groq.com")) + .unwrap_or(false); + + if needs_openai_prefix { + if let Some(path) = self.get_http_request_header(":path") { + if path.starts_with("/v1/") { + let new_path = format!("/openai{}", path); + self.set_http_request_header(":path", Some(new_path.as_str())); + } + } + } + debug!( "request received: llm provider hint: {}, selected llm: {}, model: {}", self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER) @@ -237,8 +254,8 @@ impl HttpContext for StreamContext { self.delete_content_length_header(); self.save_ratelimit_header(); - self.is_chat_completions_request = - self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; + let request_path = self.get_http_request_header(":path").unwrap_or_default(); + self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(&request_path.as_str()); self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 3a7dc7d9..6fc23921 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -61,7 +61,7 @@ impl HttpContext for StreamContext { return Action::Continue; } - self.is_chat_completions_request = request_path == CHAT_COMPLETIONS_PATH; + self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(&request_path.as_str()); debug!( "on_http_request_headers S[{}] req_headers={:?}", diff --git a/demos/samples_python/weather_forecast/arch_config.yaml b/demos/samples_python/weather_forecast/arch_config.yaml index 8b0f4ca0..db18eb85 100644 --- a/demos/samples_python/weather_forecast/arch_config.yaml +++ b/demos/samples_python/weather_forecast/arch_config.yaml @@ -17,17 +17,13 @@ overrides: prompt_target_intent_matching_threshold: 0.6 llm_providers: - - name: gpt-4o-mini - access_key: $OPENAI_API_KEY + - name: groq + access_key: $GROQ_API_KEY provider_interface: openai - model: gpt-4o-mini + model: llama-3.2-3b-preview + base_url: https://api.groq.com default: true - - name: gpt-3.5-turbo-0125 - access_key: $OPENAI_API_KEY - provider_interface: openai - model: gpt-3.5-turbo-0125 - - name: gpt-4o access_key: $OPENAI_API_KEY provider_interface: openai diff --git a/demos/samples_python/weather_forecast/docker-compose.yaml b/demos/samples_python/weather_forecast/docker-compose.yaml index 566dfa8d..abb69654 100644 --- a/demos/samples_python/weather_forecast/docker-compose.yaml +++ b/demos/samples_python/weather_forecast/docker-compose.yaml @@ -19,3 +19,5 @@ services: - CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1 extra_hosts: - "host.docker.internal:host-gateway" + volumes: + - ./arch_config.yaml:/app/arch_config.yaml diff --git a/tests/e2e/test_prompt_gateway.py b/tests/e2e/test_prompt_gateway.py index f122ad30..9b804bae 100644 --- a/tests/e2e/test_prompt_gateway.py +++ b/tests/e2e/test_prompt_gateway.py @@ -62,7 +62,7 @@ def test_prompt_gateway(stream): # third..end chunk is summarization (role = assistant) response_json = json.loads(chunks[2]) - assert response_json.get("model").startswith("gpt-4o-mini") + assert response_json.get("model").startswith("llama-3.2-3b-preview") choices = response_json.get("choices", []) assert len(choices) > 0 assert "role" in choices[0]["delta"] @@ -71,7 +71,7 @@ def test_prompt_gateway(stream): else: response_json = response.json() - assert response_json.get("model").startswith("gpt-4o-mini") + assert response_json.get("model").startswith("llama-3.2-3b-preview") choices = response_json.get("choices", []) assert len(choices) > 0 assert "role" in choices[0]["message"] @@ -231,7 +231,7 @@ def test_prompt_gateway_param_tool_call(stream): # third..end chunk is summarization (role = assistant) response_json = json.loads(chunks[2]) - assert response_json.get("model").startswith("gpt-4o-mini") + assert response_json.get("model").startswith("llama-3.2-3b-preview") choices = response_json.get("choices", []) assert len(choices) > 0 assert "role" in choices[0]["delta"] @@ -240,7 +240,7 @@ def test_prompt_gateway_param_tool_call(stream): else: response_json = response.json() - assert response_json.get("model").startswith("gpt-4o-mini") + assert response_json.get("model").startswith("llama-3.2-3b-preview") choices = response_json.get("choices", []) assert len(choices) > 0 assert "role" in choices[0]["message"]