mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
add support for using custom upstream llm (#365)
This commit is contained in:
parent
3fc21de60c
commit
07ef3149b8
13 changed files with 263 additions and 52 deletions
|
|
@ -190,7 +190,7 @@ llm_providers:
|
|||
|
||||
- name: ministral-3b
|
||||
access_key: $MISTRAL_API_KEY
|
||||
provider: mistral
|
||||
provider: openai
|
||||
model: ministral-3b-latest
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -43,19 +43,27 @@ properties:
|
|||
properties:
|
||||
name:
|
||||
type: string
|
||||
# this field is deprecated, use provider_interface instead
|
||||
provider:
|
||||
type: string
|
||||
enum:
|
||||
- openai
|
||||
provider_interface:
|
||||
type: string
|
||||
enum:
|
||||
- openai
|
||||
- mistral
|
||||
access_key:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
default:
|
||||
type: boolean
|
||||
endpoint:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- name
|
||||
- provider
|
||||
- access_key
|
||||
- model
|
||||
overrides:
|
||||
type: object
|
||||
|
|
|
|||
|
|
@ -125,15 +125,21 @@ static_resources:
|
|||
- "*"
|
||||
routes:
|
||||
{% for provider in arch_llm_providers %}
|
||||
# if endpoint is set then use custom cluster for upstream llm
|
||||
{% if provider.endpoint %}
|
||||
{% set llm_cluster_name = provider.name %}
|
||||
{% else %}
|
||||
{% set llm_cluster_name = provider.provider_interface %}
|
||||
{% endif %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
headers:
|
||||
- name: "x-arch-llm-provider"
|
||||
string_match:
|
||||
exact: {{ provider.name }}
|
||||
exact: {{ llm_cluster_name }}
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: {{ provider.provider }}
|
||||
cluster: {{ llm_cluster_name }}
|
||||
timeout: 60s
|
||||
{% endfor %}
|
||||
http_filters:
|
||||
|
|
@ -237,16 +243,16 @@ static_resources:
|
|||
domains:
|
||||
- "*"
|
||||
routes:
|
||||
{% for internal_clustrer in ["arch_fc", "model_server"] %}
|
||||
{% for internal_cluster in ["arch_fc", "model_server"] %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
headers:
|
||||
- name: "x-arch-upstream"
|
||||
string_match:
|
||||
exact: {{ internal_clustrer }}
|
||||
exact: {{ internal_cluster }}
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: {{ internal_clustrer }}
|
||||
cluster: {{ internal_cluster }}
|
||||
timeout: 60s
|
||||
{% endfor %}
|
||||
|
||||
|
|
@ -370,15 +376,21 @@ static_resources:
|
|||
cluster: openai
|
||||
timeout: 60s
|
||||
{% for provider in arch_llm_providers %}
|
||||
# if endpoint is set then use custom cluster for upstream llm
|
||||
{% if provider.endpoint %}
|
||||
{% set llm_cluster_name = provider.name %}
|
||||
{% else %}
|
||||
{% set llm_cluster_name = provider.provider_interface %}
|
||||
{% endif %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
headers:
|
||||
- name: "x-arch-llm-provider"
|
||||
string_match:
|
||||
exact: {{ provider.name }}
|
||||
exact: {{ llm_cluster_name }}
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: {{ provider.provider }}
|
||||
cluster: {{ llm_cluster_name }}
|
||||
timeout: 60s
|
||||
{% endfor %}
|
||||
- match:
|
||||
|
|
@ -538,6 +550,24 @@ static_resources:
|
|||
tls_maximum_protocol_version: TLSv1_3
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
{% for local_llm_provider in local_llms %}
|
||||
- name: {{ local_llm_provider.name }}
|
||||
connect_timeout: 5s
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: {{ local_llm_provider.name }}
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: {{ local_llm_provider.endpoint }}
|
||||
port_value: {{ local_llm_provider.port }}
|
||||
hostname: {{ local_llm_provider.endpoint }}
|
||||
{% endfor %}
|
||||
- name: arch_internal
|
||||
connect_timeout: 5s
|
||||
type: LOGICAL_DNS
|
||||
|
|
|
|||
|
|
@ -16,18 +16,6 @@ ARCH_CONFIG_SCHEMA_FILE = os.getenv(
|
|||
)
|
||||
|
||||
|
||||
def add_secret_key_to_llm_providers(config_yaml):
|
||||
llm_providers = []
|
||||
for llm_provider in config_yaml.get("llm_providers", []):
|
||||
access_key_env_var = llm_provider.get("access_key", False)
|
||||
access_key_value = os.getenv(access_key_env_var, False)
|
||||
if access_key_env_var and access_key_value:
|
||||
llm_provider["access_key"] = access_key_value
|
||||
llm_providers.append(llm_provider)
|
||||
config_yaml["llm_providers"] = llm_providers
|
||||
return config_yaml
|
||||
|
||||
|
||||
def validate_and_render_schema():
|
||||
env = Environment(loader=FileSystemLoader("./"))
|
||||
template = env.get_template("envoy.template.yaml")
|
||||
|
|
@ -70,18 +58,42 @@ def validate_and_render_schema():
|
|||
f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file"
|
||||
)
|
||||
|
||||
arch_llm_providers = config_yaml["llm_providers"]
|
||||
arch_tracing = config_yaml.get("tracing", {})
|
||||
|
||||
llms_with_endpoint = []
|
||||
|
||||
updated_llm_providers = []
|
||||
for llm_provider in config_yaml["llm_providers"]:
|
||||
provider = None
|
||||
if llm_provider.get("provider") and llm_provider.get("provider_interface"):
|
||||
raise Exception(
|
||||
"Please provide either provider or provider_interface, not both"
|
||||
)
|
||||
if llm_provider.get("provider"):
|
||||
provider = llm_provider["provider"]
|
||||
llm_provider["provider_interface"] = provider
|
||||
del llm_provider["provider"]
|
||||
updated_llm_providers.append(llm_provider)
|
||||
|
||||
if llm_provider.get("endpoint", None):
|
||||
endpoint = llm_provider["endpoint"]
|
||||
if len(endpoint.split(":")) > 1:
|
||||
llm_provider["endpoint"] = endpoint.split(":")[0]
|
||||
llm_provider["port"] = int(endpoint.split(":")[1])
|
||||
llms_with_endpoint.append(llm_provider)
|
||||
|
||||
config_yaml["llm_providers"] = updated_llm_providers
|
||||
|
||||
arch_config_string = yaml.dump(config_yaml)
|
||||
config_yaml["mode"] = "llm"
|
||||
arch_llm_config_string = yaml.dump(config_yaml)
|
||||
|
||||
data = {
|
||||
"arch_config": arch_config_string,
|
||||
"arch_llm_config": arch_llm_config_string,
|
||||
"arch_clusters": inferred_clusters,
|
||||
"arch_llm_providers": arch_llm_providers,
|
||||
"arch_llm_providers": config_yaml["llm_providers"],
|
||||
"arch_tracing": arch_tracing,
|
||||
"local_llms": llms_with_endpoint,
|
||||
}
|
||||
|
||||
rendered = template.render(data)
|
||||
|
|
|
|||
|
|
@ -162,15 +162,34 @@ pub struct EmbeddingProviver {
|
|||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum LlmProviderType {
|
||||
#[serde(rename = "openai")]
|
||||
OpenAI,
|
||||
#[serde(rename = "mistral")]
|
||||
Mistral,
|
||||
}
|
||||
|
||||
impl Display for LlmProviderType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
LlmProviderType::OpenAI => write!(f, "openai"),
|
||||
LlmProviderType::Mistral => write!(f, "mistral"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct LlmProvider {
|
||||
pub name: String,
|
||||
pub provider: String,
|
||||
pub provider_interface: LlmProviderType,
|
||||
pub access_key: Option<String>,
|
||||
pub model: String,
|
||||
pub default: Option<bool>,
|
||||
pub stream: Option<bool>,
|
||||
pub endpoint: Option<String>,
|
||||
pub port: Option<u16>,
|
||||
pub rate_limits: Option<LlmRatelimit>,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ impl StreamContext {
|
|||
fn select_llm_provider(&mut self) {
|
||||
let provider_hint = self
|
||||
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.map(|provider_name| provider_name.into());
|
||||
.map(|llm_name| llm_name.into());
|
||||
|
||||
debug!("llm provider hint: {:?}", provider_hint);
|
||||
self.llm_provider = Some(routing::get_llm_provider(
|
||||
|
|
@ -174,10 +174,22 @@ impl HttpContext for StreamContext {
|
|||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
self.select_llm_provider();
|
||||
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
|
||||
|
||||
// if endpoint is not set then use provider name as routing header so envoy can resolve the cluster name
|
||||
if self.llm_provider().endpoint.is_none() {
|
||||
self.add_http_request_header(
|
||||
ARCH_ROUTING_HEADER,
|
||||
&self.llm_provider().provider_interface.to_string(),
|
||||
);
|
||||
} else {
|
||||
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
|
||||
}
|
||||
|
||||
if let Err(error) = self.modify_auth_headers() {
|
||||
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
|
||||
// ensure that the provider has an endpoint if the access key is missing else return a bad request
|
||||
if self.llm_provider.as_ref().unwrap().endpoint.is_none() {
|
||||
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
self.delete_content_length_header();
|
||||
self.save_ratelimit_header();
|
||||
|
|
@ -334,16 +346,18 @@ impl HttpContext for StreamContext {
|
|||
// Record the latency to the latency histogram
|
||||
self.metrics.request_latency.record(duration_ms as u64);
|
||||
|
||||
// Compute the time per output token
|
||||
let tpot = duration_ms as u64 / self.response_tokens as u64;
|
||||
if self.response_tokens > 0 {
|
||||
// Compute the time per output token
|
||||
let tpot = duration_ms as u64 / self.response_tokens as u64;
|
||||
|
||||
debug!("Time per output token: {} milliseconds", tpot);
|
||||
// Record the time per output token
|
||||
self.metrics.time_per_output_token.record(tpot);
|
||||
debug!("Time per output token: {} milliseconds", tpot);
|
||||
// Record the time per output token
|
||||
self.metrics.time_per_output_token.record(tpot);
|
||||
|
||||
debug!("Tokens per second: {}", 1000 / tpot);
|
||||
// Record the tokens per second
|
||||
self.metrics.tokens_per_second.record(1000 / tpot);
|
||||
debug!("Tokens per second: {}", 1000 / tpot);
|
||||
// Record the tokens per second
|
||||
self.metrics.tokens_per_second.record(1000 / tpot);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("SystemTime error: {:?}", e);
|
||||
|
|
@ -381,11 +395,13 @@ impl HttpContext for StreamContext {
|
|||
self.llm_provider().name.to_string(),
|
||||
);
|
||||
|
||||
llm_span.add_event(Event::new(
|
||||
"time_to_first_token".to_string(),
|
||||
self.ttft_time.unwrap(),
|
||||
));
|
||||
trace_data.add_span(llm_span);
|
||||
if self.ttft_time.is_some() {
|
||||
llm_span.add_event(Event::new(
|
||||
"time_to_first_token".to_string(),
|
||||
self.ttft_time.unwrap(),
|
||||
));
|
||||
trace_data.add_span(llm_span);
|
||||
}
|
||||
|
||||
self.traces_queue.lock().unwrap().push_back(trace_data);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,11 +23,15 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
|||
Some("x-arch-llm-provider-hint"),
|
||||
)
|
||||
.returning(Some("default"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(
|
||||
Some(LogLevel::Debug),
|
||||
Some("llm provider hint: Some(Default)"),
|
||||
)
|
||||
.expect_log(Some(LogLevel::Debug), Some("selected llm: open-ai-gpt-4"))
|
||||
.expect_add_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
Some("x-arch-llm-provider"),
|
||||
Some("open-ai-gpt-4"),
|
||||
Some("openai"),
|
||||
)
|
||||
.expect_replace_header_map_value(
|
||||
Some(MapType::HttpRequestHeaders),
|
||||
|
|
@ -46,8 +50,6 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
|||
.returning(None)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
||||
.returning(Some("/v1/chat/completions"))
|
||||
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
|
||||
.returning(None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
|
||||
.returning(None)
|
||||
|
|
@ -110,12 +112,12 @@ endpoints:
|
|||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
provider: openai
|
||||
provider_interface: openai
|
||||
access_key: secret_key
|
||||
model: gpt-4
|
||||
default: true
|
||||
- name: open-ai-gpt-4o
|
||||
provider: openai
|
||||
provider_interface: openai
|
||||
access_key: secret_key
|
||||
model: gpt-4o
|
||||
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ endpoints:
|
|||
|
||||
llm_providers:
|
||||
- name: open-ai-gpt-4
|
||||
provider: openai
|
||||
provider_interface: openai
|
||||
access_key: secret_key
|
||||
model: gpt-4
|
||||
default: true
|
||||
|
|
|
|||
3
demos/currency_exchange_ollama/README.md
Normal file
3
demos/currency_exchange_ollama/README.md
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
This demo shows how you can use ollama as upstream LLM.
|
||||
|
||||
Before you can start the demo please make sure you have ollama up and running. You can use command `ollama run llama3.2` to start llama 3.2 (3b) model locally at port `11434`.
|
||||
53
demos/currency_exchange_ollama/arch_config.yaml
Normal file
53
demos/currency_exchange_ollama/arch_config.yaml
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
version: v0.1
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: local-llama
|
||||
provider_interface: openai
|
||||
model: llama3.2
|
||||
endpoint: host.docker.internal:11434
|
||||
default: true
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
|
||||
|
||||
prompt_targets:
|
||||
- name: currency_exchange
|
||||
description: Get currency exchange rate from USD to other currencies
|
||||
parameters:
|
||||
- name: currency_symbol
|
||||
description: the currency that needs conversion
|
||||
required: true
|
||||
type: str
|
||||
in_path: true
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/latest?base=USD&symbols={currency_symbol}
|
||||
system_prompt: |
|
||||
You are a helpful assistant. Show me the currency symbol you want to convert from USD.
|
||||
|
||||
- name: get_supported_currencies
|
||||
description: Get list of supported currencies for conversion
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/currencies
|
||||
|
||||
endpoints:
|
||||
frankfurther_api:
|
||||
endpoint: api.frankfurter.dev:443
|
||||
protocol: https
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
trace_arch_internal: true
|
||||
21
demos/currency_exchange_ollama/docker-compose.yaml
Normal file
21
demos/currency_exchange_ollama/docker-compose.yaml
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
services:
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../shared/chatbot_ui
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
# this is only because we are running the sample app in the same docker container environemtn as archgw
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
|
||||
jaeger:
|
||||
build:
|
||||
context: ../shared/jaeger
|
||||
ports:
|
||||
- "16686:16686"
|
||||
- "4317:4317"
|
||||
- "4318:4318"
|
||||
47
demos/currency_exchange_ollama/run_demo.sh
Normal file
47
demos/currency_exchange_ollama/run_demo.sh
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Function to start the demo
|
||||
start_demo() {
|
||||
# Step 1: Check if .env file exists
|
||||
if [ -f ".env" ]; then
|
||||
echo ".env file already exists. Skipping creation."
|
||||
else
|
||||
# Step 2: Create `.env` file and set OpenAI key
|
||||
if [ -z "$OPENAI_API_KEY" ]; then
|
||||
echo "Error: OPENAI_API_KEY environment variable is not set for the demo."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Creating .env file..."
|
||||
echo "OPENAI_API_KEY=$OPENAI_API_KEY" > .env
|
||||
echo ".env file created with OPENAI_API_KEY."
|
||||
fi
|
||||
|
||||
# Step 3: Start Arch
|
||||
echo "Starting Arch with arch_config.yaml..."
|
||||
archgw up arch_config.yaml
|
||||
|
||||
# Step 4: Start developer services
|
||||
echo "Starting Network Agent using Docker Compose..."
|
||||
docker compose up -d # Run in detached mode
|
||||
}
|
||||
|
||||
# Function to stop the demo
|
||||
stop_demo() {
|
||||
# Step 1: Stop Docker Compose services
|
||||
echo "Stopping Network Agent using Docker Compose..."
|
||||
docker compose down
|
||||
|
||||
# Step 2: Stop Arch
|
||||
echo "Stopping Arch..."
|
||||
archgw down
|
||||
}
|
||||
|
||||
# Main script logic
|
||||
if [ "$1" == "down" ]; then
|
||||
stop_demo
|
||||
else
|
||||
# Default action is to bring the demo up
|
||||
start_demo
|
||||
fi
|
||||
|
|
@ -31,7 +31,7 @@ endpoints:
|
|||
# Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
|
||||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
provider_interface: openai
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
|
|
@ -46,12 +46,12 @@ llm_providers:
|
|||
unit: minute
|
||||
|
||||
- name: Mistral8x7b
|
||||
provider: mistral
|
||||
provider_interface: openai
|
||||
access_key: $MISTRAL_API_KEY
|
||||
model: mistral-8x7b
|
||||
|
||||
- name: MistralLocal7b
|
||||
provider: local
|
||||
provider_interface: openai
|
||||
model: mistral-7b-instruct
|
||||
endpoint: mistral_local
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue