add retry

This commit is contained in:
Adil Hafeez 2025-09-17 17:53:19 -07:00
parent 71658ddbd9
commit a3f93de85d
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
7 changed files with 29 additions and 9 deletions

View file

@ -1,7 +1,7 @@
use std::collections::HashMap;
use common::configuration::{Agent, AgentPipeline};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::{ARCH_UPSTREAM_HOST_HEADER, ENVOY_RETRY_HEADER};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use hyper::header::HeaderMap;
use tracing::{debug, warn};
@ -112,6 +112,11 @@ impl PipelineProcessor {
.map_err(|_| PipelineError::AgentNotFound(agent.name.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let response = self
.client
.post(&self.llm_endpoint)
@ -160,6 +165,11 @@ impl PipelineProcessor {
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.name.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let response = self
.client
.post(&self.llm_endpoint)

View file

@ -29,3 +29,4 @@ pub const HALLUCINATION_TEMPLATE: &str =
/// Identifier string `opentelemetry_collector_http`.
/// NOTE(review): presumably the upstream/cluster name of the OTEL collector — confirm against usage.
pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
/// Request path `/v1/traces`.
/// NOTE(review): presumably the path traces are POSTed to on the collector — confirm against usage.
pub const OTEL_POST_PATH: &str = "/v1/traces";
/// Name of the `x-arch-llm-route` HTTP header.
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
/// Name of the `x-envoy-max-retries` HTTP header.
/// NOTE(review): presumably Envoy's per-request retry limit header — confirm against the Envoy router docs.
pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries";

View file

@ -44,7 +44,7 @@ listeners:
- access_key: $OPENAI_API_KEY
model: openai/gpt-4o-mini
address: 0.0.0.0
port: 9000
port: 12000
tracing:
random_sampling: 100

View file

@ -6,3 +6,12 @@ services:
- "16686:16686"
- "4317:4317"
- "4318:4318"
open-web-ui:
image: ghcr.io/open-webui/open-webui:main
restart: always
ports:
- "8080:8080"
environment:
- DEFAULT_MODEL=gpt-4o-mini
- ENABLE_OPENAI_API=true
- OPENAI_API_BASE_URL=http://host.docker.internal:8001/v1

View file

@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
# Configuration for archgw LLM gateway
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:9000/v1")
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
RAG_MODEL = "gpt-4o-mini"
# Initialize OpenAI client for archgw
@ -91,7 +91,7 @@ async def find_relevant_passages(
logger.info(f"Calling archgw to find relevant passages for query: '{query}'")
# Prepare extra headers if traceparent is provided
extra_headers = {}
extra_headers = {"x-envoy-max-retries": "3"}
if traceparent:
extra_headers["traceparent"] = traceparent

View file

@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
# Configuration for archgw LLM gateway
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:9000/v1")
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
QUERY_REWRITE_MODEL = "gpt-4o-mini"
# Initialize OpenAI client for archgw
@ -50,7 +50,7 @@ async def rewrite_query_with_archgw(
try:
# Call archgw using OpenAI client
extra_headers = {}
extra_headers = {"x-envoy-max-retries": "3"}
if traceparent_header:
extra_headers["traceparent"] = traceparent_header
logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to rewrite query")

View file

@ -20,7 +20,7 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration for archgw LLM gateway
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:9000/v1")
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
RESPONSE_MODEL = "gpt-4o"
# System prompt for response generation
@ -94,7 +94,7 @@ async def stream_chat_completions(
)
# Prepare extra headers if traceparent is provided
extra_headers = {}
extra_headers = {"x-envoy-max-retries": "3"}
if traceparent_header:
extra_headers["traceparent"] = traceparent_header
@ -191,7 +191,7 @@ async def non_streaming_chat_completions(
logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate response")
# Prepare extra headers if traceparent is provided
extra_headers = {}
extra_headers = {"x-envoy-max-retries": "3"}
if traceparent_header:
extra_headers["traceparent"] = traceparent_header