Add workflow logic for weather forecast demo (#24)

This commit is contained in:
Adil Hafeez 2024-07-30 16:23:23 -07:00 committed by GitHub
parent 7ef68eccfb
commit 33f9dd22e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 1902 additions and 459 deletions

6
.gitignore vendored
View file

@ -1,3 +1,9 @@
envoyfilter/target
envoyfilter/qdrant_data/
embedding-server/venv/
chatbot-ui/venv/
__pycache__
grafana-data
prom_data
.env
qdrant_data

View file

@ -16,11 +16,11 @@ repos:
name: cargo-clippy
language: system
types: [file, rust]
entry: bash -c "cd envoyfilter && cargo clippy --all"
entry: bash -c "cd envoyfilter && cargo clippy -p intelligent-prompt-gateway --all"
- id: cargo-test
name: cargo-test
language: system
types: [file, rust]
# --lib is to only test the library, since when integration tests are made,
# they will be in a separate tests directory
entry: bash -c "cd envoyfilter && cargo test --lib"
entry: bash -c "cd envoyfilter && cargo test -p intelligent-prompt-gateway --lib"

24
chatbot-ui/Dockerfile Normal file
View file

@ -0,0 +1,24 @@
# copied from https://github.com/bergos/embedding-server
# Multi-stage build for the Gradio chatbot UI: install dependencies into
# /runtime in a full python image, then copy them into a slim runtime image.
FROM python:3 AS base
#
# builder
#
FROM base AS builder
WORKDIR /src
COPY requirements.txt /src/
# --prefix=/runtime keeps the installed packages in a single tree that the
# final stage can copy wholesale into /usr/local.
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
COPY . /src
FROM python:3-slim AS output
COPY --from=builder /runtime /usr/local
# NOTE(review): this copies ./app from the *build context*, not from the
# builder stage (whose `COPY . /src` is unused in the final image) — confirm
# this is intended.
COPY /app /app
WORKDIR /app
CMD ["python", "run.py"]

81
chatbot-ui/app/run.py Normal file
View file

@ -0,0 +1,81 @@
import gradio as gr
import asyncio
import httpx
import async_timeout
from loguru import logger
from typing import Optional, List
from pydantic import BaseModel
from dotenv import load_dotenv
import os

load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv(
    "CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1/chat/completions"
)


class Message(BaseModel):
    # One chat turn in OpenAI chat-completion format.
    role: str
    content: str


async def make_completion(messages: List[Message], nb_retries: int = 3, delay: int = 30) -> Optional[str]:
    """
    Send the conversation to the chat-completions endpoint and return the
    assistant's reply.

    Args:
        messages: conversation so far (role/content dicts or Message models).
        nb_retries: maximum attempts on transport-level failures.
        delay: overall timeout in seconds covering all attempts.

    Returns:
        The assistant message content, or None on timeout, a non-200 response,
        or when retries are exhausted.
    """
    header = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {OPENAI_API_KEY}",
    }
    try:
        async with async_timeout.timeout(delay=delay):
            async with httpx.AsyncClient(headers=header) as aio_client:
                counter = 0
                keep_loop = True
                while keep_loop:
                    logger.debug(f"Chat/Completions Nb Retries : {counter}")
                    try:
                        resp = await aio_client.post(
                            url=CHAT_COMPLETION_ENDPOINT,
                            json={
                                "model": "gpt-3.5-turbo",
                                "messages": messages,
                            },
                        )
                        logger.debug(f"Status Code : {resp.status_code}")
                        if resp.status_code == 200:
                            return resp.json()["choices"][0]["message"]["content"]
                        # A non-200 answer is treated as fatal (no retry); the
                        # body is logged for diagnosis.
                        logger.warning(resp.content)
                        keep_loop = False
                    except Exception as e:
                        # Transport-level failure: retry until nb_retries is hit.
                        logger.error(e)
                        counter = counter + 1
                        keep_loop = counter < nb_retries
    except asyncio.TimeoutError:
        logger.error(f"Timeout {delay} seconds !")
    return None


async def predict(input, history):
    """
    Append the user's turn to `history`, fetch the assistant's reply, and
    return (chatbot_pairs, history) for the Gradio UI.
    """
    history.append({"role": "user", "content": input})
    logger.debug(history)
    response = await make_completion(history)
    if response is None:
        # Surface a visible error instead of a `None` bubble in the chat UI.
        response = "Sorry, something went wrong. Please try again."
    history.append({"role": "assistant", "content": response})
    # Pair up (user, assistant) turns for gr.Chatbot.
    messages = [
        (history[i]["content"], history[i + 1]["content"])
        for i in range(0, len(history) - 1, 2)
    ]
    return messages, history


# Gradio Blocks low-level API used to build the custom chat web app.
with gr.Blocks() as demo:
    logger.info("Starting Demo...")
    chatbot = gr.Chatbot(label="WebGPT")
    state = gr.State([])
    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter")
    txt.submit(predict, [txt, state], [chatbot, state])

demo.launch(server_name="0.0.0.0", server_port=8080)

View file

@ -0,0 +1,6 @@
gradio==4.39.0
async_timeout==4.0.3
loguru==0.7.2
# NOTE(review): `asyncio` on PyPI is an obsolete Python-3.3-era backport, not
# the stdlib module — this pin is almost certainly unnecessary and can shadow
# stdlib behavior on modern interpreters; confirm and consider removing.
asyncio==3.4.3
httpx==0.27.0
python-dotenv==1.0.1

View file

@ -0,0 +1,15 @@
# Weather forecasting
This demo shows how you can use intelligent prompt gateway to provide realtime weather forecast.
# Starting the demo
1. Create `.env` file and set OpenAI key using env var `OPENAI_API_KEY`
1. Start services
```sh
$ docker compose up
```
1. Navigate to http://localhost:18080/
1. You can type in queries like "how is the weather in Seattle"
1. You can also ask follow up questions like "show me sunny days"
2. To see metrics navigate to "http://localhost:3000/" (use admin/grafana for login)
1. Open up the dashboard named "Intelligent Gateway Overview"
2. On this dashboard you can see request latency and number of requests

View file

@ -0,0 +1,85 @@
# Compose stack for the weather-forecast demo: Envoy (with the prompt-gateway
# wasm filter), the embedding/NER server, Qdrant, the chatbot UI, and
# Prometheus + Grafana for metrics.
services:
  envoy:
    build:
      context: ../../
      dockerfile: envoyfilter/Dockerfile
    hostname: envoy
    ports:
      - "10000:10000"
      - "19901:9901"
    volumes:
      - ./envoy.yaml:/etc/envoy/envoy.yaml
      - /etc/ssl/cert.pem:/etc/ssl/cert.pem
    networks:
      - envoymesh
    depends_on:
      embeddingserver:
        condition: service_healthy
  embeddingserver:
    build:
      context: ../../embedding-server
      dockerfile: Dockerfile
    ports:
      - "18081:80"
    healthcheck:
      test: ["CMD", "curl" ,"http://localhost:80/healthz"]
      interval: 5s
      retries: 20
    networks:
      - envoymesh
  qdrant:
    image: qdrant/qdrant
    hostname: vector-db
    ports:
      - 16333:6333
      - 16334:6334
    networks:
      - envoymesh
  chatbot-ui:
    build:
      context: ../../chatbot-ui
      dockerfile: Dockerfile
    ports:
      - "18080:8080"
    networks:
      - envoymesh
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - CHAT_COMPLETION_ENDPOINT=http://envoy:10000/v1/chat/completions
  prometheus:
    image: prom/prometheus
    container_name: prometheus
    command:
      - '--config.file=/etc/prometheus/prometheus.yaml'
    ports:
      - 9090:9090
    restart: unless-stopped
    volumes:
      - ./prometheus:/etc/prometheus
      - ./prom_data:/prometheus
    networks:
      - envoymesh
  grafana:
    image: grafana/grafana
    container_name: grafana
    ports:
      - 3000:3000
    restart: unless-stopped
    environment:
      - GF_SECURITY_ADMIN_USER=admin
      - GF_SECURITY_ADMIN_PASSWORD=grafana
    volumes:
      - ./grafana:/etc/grafana/provisioning/datasources
      - ./grafana/dashboard.yaml:/etc/grafana/provisioning/dashboards/main.yaml
      - ./grafana/dashboards:/var/lib/grafana/dashboards
      # - ./grafana-data:/var/lib/grafana
    networks:
      - envoymesh
networks:
  envoymesh: {}

View file

@ -0,0 +1,197 @@
admin:
address:
socket_address: { address: 0.0.0.0, port_value: 9901 }
static_resources:
listeners:
address:
socket_address:
address: 0.0.0.0
port_value: 10000
filter_chains:
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: ingress_http
codec_type: AUTO
scheme_header_transformation:
scheme_to_overwrite: https
route_config:
name: local_routes
virtual_hosts:
- name: openai
domains:
- "api.openai.com"
routes:
- match:
prefix: "/"
route:
auto_host_rewrite: true
cluster: openai
- name: local_service
domains:
- "*"
routes:
- match:
prefix: "/v1/chat/completions"
route:
auto_host_rewrite: true
cluster: openai
- match:
prefix: "/embeddings"
route:
cluster: embeddingserver
- match:
prefix: "/"
direct_response:
status: 200
body:
inline_string: "Inspect the HTTP header: custom-header.\n"
http_filters:
- name: envoy.filters.http.wasm
typed_config:
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
value:
config:
name: "http_config"
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: |
katanemo-prompt-config:
default-prompt-endpoint: "127.0.0.1"
load-balancing: "round-robin"
timeout-ms: 5000
embedding-provider:
name: "SentenceTransformer"
model: "all-MiniLM-L6-v2"
llm-providers:
- name: "open-ai-gpt-4"
api-key: "$OPEN_AI_API_KEY"
model: gpt-4
prompt-targets:
- type: context-resolver
name: weather-forecast
few-shot-examples:
- what is the weather in New York?
- how is the weather in San Francisco?
- what is the forecast in Seattle?
entities:
- name: city
required: true
- name: days
endpoint:
cluster: weatherhost
path: /weather
cache-response: true
cache-response-settings:
- cache-ttl-secs: 3600 # cache expiry in seconds
- cache-max-size: 1000 # in number of items
- cache-eviction-strategy: LRU
system-prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
vm_config:
runtime: "envoy.wasm.runtime.v8"
code:
local:
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
# LLM Host
# Embedding Providers
# External LLM Providers
- name: openai
connect_timeout: 5s
type: LOGICAL_DNS
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: openai
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: api.openai.com
port_value: 443
hostname: "api.openai.com"
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: api.openai.com
common_tls_context:
tls_params:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
- name: embeddingserver
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: embeddingserver
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: embeddingserver
port_value: 80
hostname: "embeddingserver"
- name: weatherhost
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: weatherhost
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: embeddingserver
port_value: 80
hostname: "embeddingserver"
- name: nerhost
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: nerhost
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: embeddingserver
port_value: 80
hostname: "embeddingserver"
- name: qdrant
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: qdrant
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: qdrant
port_value: 6333
hostname: "qdrant"

View file

@ -0,0 +1,12 @@
# Grafana dashboard provisioning: load every dashboard JSON found under
# /var/lib/grafana/dashboards (mounted by docker-compose).
apiVersion: 1
providers:
  - name: "Dashboard provider"
    orgId: 1
    type: file
    disableDeletion: false
    updateIntervalSeconds: 10
    allowUiUpdates: false
    options:
      path: /var/lib/grafana/dashboards
      foldersFromFilesStructure: true

View file

@ -0,0 +1,355 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"links": [],
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 0
},
"id": 2,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "avg(rate(envoy_cluster_internal_upstream_rq_time_sum[1m]) / rate(envoy_cluster_internal_upstream_rq_time_count[1m])) by (envoy_cluster_name)",
"fullMetaSearch": false,
"hide": false,
"includeNullMetadata": true,
"instant": false,
"legendFormat": "__auto",
"range": true,
"refId": "A",
"useBackend": false
}
],
"title": "request latency - internal (ms)",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 0
},
"id": 1,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "avg(rate(envoy_cluster_external_upstream_rq_time_sum[1m]) / rate(envoy_cluster_external_upstream_rq_time_count[1m])) by (envoy_cluster_name)",
"fullMetaSearch": false,
"hide": false,
"includeNullMetadata": true,
"instant": false,
"legendFormat": "__auto",
"range": true,
"refId": "A",
"useBackend": false
}
],
"title": "request latency - external (ms)",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 8
},
"id": 3,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "avg(rate(envoy_cluster_internal_upstream_rq_completed[1m])) by (envoy_cluster_name)",
"fullMetaSearch": false,
"includeNullMetadata": true,
"instant": false,
"legendFormat": "__auto",
"range": true,
"refId": "A",
"useBackend": false
},
{
"datasource": {
"type": "prometheus",
"uid": "PBFA97CFB590B2093"
},
"disableTextWrap": false,
"editorMode": "code",
"expr": "avg(rate(envoy_cluster_external_upstream_rq_completed[1m])) by (envoy_cluster_name)",
"fullMetaSearch": false,
"hide": false,
"includeNullMetadata": true,
"instant": false,
"legendFormat": "__auto",
"range": true,
"refId": "B",
"useBackend": false
}
],
"title": "Upstream request count",
"type": "timeseries"
}
],
"schemaVersion": 39,
"tags": [],
"templating": {
"list": []
},
"time": {
"from": "now-15m",
"to": "now"
},
"timepicker": {},
"timezone": "browser",
"title": "Intelligent Gateway Overview",
"uid": "adt6uhx5lk8aob",
"version": 3,
"weekStart": ""
}

View file

@ -0,0 +1,9 @@
apiVersion: 1
datasources:
- name: Prometheus
type: prometheus
url: http://prometheus:9090
isDefault: true
access: proxy
editable: true

View file

@ -0,0 +1,23 @@
# Prometheus configuration for the weather-forecast demo.
global:
  scrape_interval: 15s
  scrape_timeout: 10s
  evaluation_interval: 15s
alerting:
  alertmanagers:
    - static_configs:
        - targets: []
      scheme: http
      timeout: 10s
      api_version: v1
scrape_configs:
  # Scrape Envoy's admin endpoint; /stats with format=prometheus emits
  # the stats in Prometheus exposition format.
  - job_name: envoy
    honor_timestamps: true
    scrape_interval: 15s
    scrape_timeout: 10s
    metrics_path: /stats
    scheme: http
    static_configs:
      - targets:
          - envoy:9901
    params:
      format: ['prometheus']

View file

@ -24,6 +24,7 @@ FROM python:3-slim AS output
# following models have been tested to work with this image
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
ENV MODELS="BAAI/bge-large-en-v1.5"
ENV NER_MODELS="urchade/gliner_large-v2.1"
COPY --from=builder /runtime /usr/local

View file

@ -1,3 +1,6 @@
from load_transformers import load_transformers
from load_models import load_transformers, load_ner_models
print('installing transformers')
load_transformers()
print('installing ner models')
load_ner_models()

View file

@ -1,5 +1,6 @@
import os
import sentence_transformers
from gliner import GLiNER
def load_transformers(models = os.getenv("MODELS", "sentence-transformers/all-MiniLM-L6-v2")):
transformers = {}
@ -8,3 +9,11 @@ def load_transformers(models = os.getenv("MODELS", "sentence-transformers/all-Mi
transformers[model] = sentence_transformers.SentenceTransformer(model)
return transformers
def load_ner_models(models = os.getenv("NER_MODELS", "urchade/gliner_large-v2.1")):
ner_models = {}
for model in models.split(','):
ner_models[model] = GLiNER.from_pretrained(model)
return ner_models

View file

@ -1,8 +1,11 @@
import random
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
from load_transformers import load_transformers
from load_models import load_ner_models, load_transformers
from datetime import date, timedelta
transformers = load_transformers()
ner_models = load_ner_models()
app = FastAPI()
@ -10,6 +13,12 @@ class EmbeddingRequest(BaseModel):
input: str
model: str
@app.get("/healthz")
async def healthz():
return {
"status": "ok"
}
@app.get("/models")
async def models():
models = []
@ -27,7 +36,7 @@ async def models():
@app.post("/embeddings")
async def embedding(req: EmbeddingRequest, res: Response):
if not req.model in transformers:
if req.model not in transformers:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
embeddings = transformers[req.model].encode([req.input])
@ -51,3 +60,48 @@ async def embedding(req: EmbeddingRequest, res: Response):
"object": "list",
"usage": usage
}
class NERRequest(BaseModel):
input: str
labels: list[str]
model: str
@app.post("/ner")
async def ner(req: NERRequest, res: Response):
if req.model not in ner_models:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
model = ner_models[req.model]
entities = model.predict_entities(req.input, req.labels)
return {
"data": entities,
"model": req.model,
"object": "list",
}
class WeatherRequest(BaseModel):
city: str
@app.post("/weather")
async def weather(req: WeatherRequest, res: Response):
weather_forecast = {
"city": req.city,
"temperature": [],
"unit": "F",
}
for i in range(7):
min_temp = random.randrange(50,90)
max_temp = random.randrange(min_temp+5, min_temp+20)
weather_forecast["temperature"].append({
"date": str(date.today() + timedelta(days=i)),
"temperature": {
"min": min_temp,
"max": max_temp
}
})
return weather_forecast

View file

@ -3,3 +3,4 @@ fastapi
sentence-transformers
torch
uvicorn
gliner

24
envoyfilter/Cargo.lock generated
View file

@ -1621,9 +1621,9 @@ dependencies = [
[[package]]
name = "serde_spanned"
version = "0.6.6"
version = "0.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0"
checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d"
dependencies = [
"serde",
]
@ -1936,9 +1936,9 @@ dependencies = [
[[package]]
name = "toml"
version = "0.8.15"
version = "0.8.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac2caab0bf757388c6c0ae23b3293fdb463fee59434529014f85e3263b995c28"
checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c"
dependencies = [
"serde",
"serde_spanned",
@ -1948,18 +1948,18 @@ dependencies = [
[[package]]
name = "toml_datetime"
version = "0.6.6"
version = "0.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf"
checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db"
dependencies = [
"serde",
]
[[package]]
name = "toml_edit"
version = "0.22.16"
version = "0.22.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "278f3d518e152219c994ce877758516bca5e118eaed6996192a774fb9fbf0788"
checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16"
dependencies = [
"indexmap",
"serde",
@ -2121,9 +2121,9 @@ checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
[[package]]
name = "version_check"
version = "0.9.4"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "want"
@ -2730,9 +2730,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.6.15"
version = "0.6.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "557404e450152cd6795bb558bca69e43c585055f4606e3bcae5894fc6dac9ba0"
checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c"
dependencies = [
"memchr",
]

16
envoyfilter/Dockerfile Normal file
View file

@ -0,0 +1,16 @@
# build filter using rust toolchain
FROM rust:1.80.0 AS builder
WORKDIR /envoyfilter
COPY envoyfilter/src /envoyfilter/src
COPY envoyfilter/Cargo.toml /envoyfilter/
COPY envoyfilter/Cargo.lock /envoyfilter/
# Sibling crate referenced from the filter's Cargo.toml.
COPY open-message-format /open-message-format
RUN rustup -v target add wasm32-wasi
RUN cargo build --release --target wasm32-wasi
# copy built filter into envoy image
FROM envoyproxy/envoy:v1.30-latest
COPY --from=builder /envoyfilter/target/wasm32-wasi/release/intelligent_prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
# Fix: copy the config to /etc/envoy/envoy.yaml — the path CMD actually loads.
# It was previously copied to /etc/envoy.yaml, which Envoy never reads, so a
# standalone `docker run` of this image started with the base image's default
# config instead of ours.
COPY envoyfilter/envoy.yaml /etc/envoy/envoy.yaml
CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"]

View file

@ -0,0 +1,3 @@
# Build the wasm filter inside a Rust container (no local toolchain required).
RUST_VERSION=1.80.0
# Install the wasm32-wasi target once into a named volume so repeat builds are fast.
docker run --rm -v rustup_cache:/usr/local/rustup/ rust:$RUST_VERSION rustup -v target add wasm32-wasi
# Mount this crate plus the sibling open-message-format crate and the cargo/rustup
# caches; the build output lands in the host's target/wasm32-wasi/release/.
docker run --rm -v $PWD/../open-message-format:/code/open-message-format -v ~/.cargo:/root/.cargo -v $(pwd):/code/envoyfilter -w /code/envoyfilter -v rustup_cache:/usr/local/rustup/ rust:$RUST_VERSION cargo build --release --target wasm32-wasi

View file

@ -22,7 +22,7 @@ services:
ports:
- "18080:80"
healthcheck:
test: ["CMD", "curl" ,"http://localhost:80"]
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
interval: 5s
retries: 20
networks:

View file

@ -41,10 +41,6 @@ static_resources:
prefix: "/embeddings"
route:
cluster: embeddingserver
- match:
prefix: "/inline"
route:
cluster: httpbin
- match:
prefix: "/"
direct_response:
@ -78,7 +74,7 @@ static_resources:
model: gpt-4
system-prompt: |
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
@ -88,12 +84,22 @@ static_resources:
name: weather-forecast
few-shot-examples:
- what is the weather in New York?
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
- how is the weather in San Francisco?
- what is the forecast in Seattle?
entities:
- city
endpoint:
cluster: weatherhost
path: /weather
cache-response: true
cache-response-settings:
- cache-ttl-secs: 3600 # cache expiry in seconds
- cache-max-size: 1000 # in number of items
- cache-eviction-strategy: LRU
system-prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature
- Use miles per hour for wind speed
vm_config:
runtime: "envoy.wasm.runtime.v8"
@ -136,20 +142,6 @@ static_resources:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
- name: httpbin
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: httpbin
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: httpbin.org
port_value: 80
hostname: "httpbin.org"
- name: embeddingserver
connect_timeout: 5s
type: STRICT_DNS
@ -164,6 +156,34 @@ static_resources:
address: embeddingserver
port_value: 80
hostname: "embeddingserver"
- name: weatherhost
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: weatherhost
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: embeddingserver
port_value: 80
hostname: "embeddingserver"
- name: nerhost
connect_timeout: 5s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: nerhost
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: embeddingserver
port_value: 80
hostname: "embeddingserver"
- name: qdrant
connect_timeout: 5s
type: STRICT_DNS

View file

@ -0,0 +1,9 @@
apiVersion: 1
datasources:
- name: Prometheus
type: prometheus
url: http://prometheus:9090
isDefault: true
access: proxy
editable: true

View file

@ -0,0 +1,23 @@
global:
scrape_interval: 15s
scrape_timeout: 10s
evaluation_interval: 15s
alerting:
alertmanagers:
- static_configs:
- targets: []
scheme: http
timeout: 10s
api_version: v1
scrape_configs:
- job_name: envoy
honor_timestamps: true
scrape_interval: 15s
scrape_timeout: 10s
metrics_path: /stats
scheme: http
static_configs:
- targets:
- envoy:9901
params:
format: ['prometheus']

View file

@ -23,9 +23,53 @@ pub struct StoreVectorEmbeddingsRequest {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum CallContext {
EmbeddingRequest(EmbeddingRequest),
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
CreateVectorCollection(String),
}
// https://api.qdrant.tech/master/api-reference/search/points
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsRequest {
pub vector: Vec<f64>,
pub limit: i32,
pub with_payload: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointResult {
pub id: String,
pub version: i32,
pub score: f64,
pub payload: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointsResponse {
pub result: Vec<SearchPointResult>,
pub status: String,
pub time: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERRequest {
pub input: String,
pub labels: Vec<String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
pub text: String,
pub label: String,
pub score: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NERResponse {
pub data: Vec<Entity>,
pub model: String,
}
pub mod open_ai {
@ -41,6 +85,6 @@ pub mod open_ai {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: String,
pub content: Option<String>,
}
}

View file

@ -28,7 +28,7 @@ pub struct PromptConfig {
pub timeout_ms: u64,
pub embedding_provider: EmbeddingProviver,
pub llm_providers: Vec<LlmProvider>,
pub system_prompt: String,
pub system_prompt: Option<String>,
pub prompt_targets: Vec<PromptTarget>,
}
@ -49,6 +49,29 @@ pub struct LlmProvider {
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Endpoint {
pub cluster: String,
pub path: Option<String>,
pub method: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct EntityDetail {
pub name: String,
pub required: Option<bool>,
pub description: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EntityType {
Vec(Vec<String>),
Struct(Vec<EntityDetail>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct PromptTarget {
@ -56,7 +79,9 @@ pub struct PromptTarget {
pub prompt_type: String,
pub name: String,
pub few_shot_examples: Vec<String>,
pub endpoint: String,
pub entities: Option<EntityType>,
pub endpoint: Option<Endpoint>,
pub system_prompt: Option<String>,
}
#[cfg(test)]
@ -88,13 +113,23 @@ katanemo-prompt-config:
name: weather-forecast
few-shot-examples:
- what is the weather in New York?
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
cache-response: true
cache-response-settings:
- cache-ttl-secs: 3600 # cache expiry in seconds
- cache-max-size: 1000 # in number of items
- cache-eviction-strategy: LRU
endpoint:
cluster: weatherhost
path: /weather
entities:
- name: location
required: true
description: "The location for which the weather is requested"
- type: context-resolver
name: weather-forecast-2
few-shot-examples:
- what is the weather in New York?
endpoint:
cluster: weatherhost
path: /weather
entities:
- city
"#;
#[test]

View file

@ -1 +1,7 @@
/// Embedding model requested from the embedding server by default.
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
/// Qdrant collection that stores prompt-target vector embeddings.
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
/// Default named-entity-recognition model served by the embedding server.
pub const DEFAULT_NER_MODEL: &str = "urchade/gliner_large-v2.1";
// Score thresholds — presumably minimum match scores for prompt-target
// similarity and NER confidence respectively; confirm at the use sites.
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
pub const DEFAULT_NER_THRESHOLD: f64 = 0.6;
/// OpenAI chat-completion role names.
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";

View file

@ -0,0 +1,289 @@
use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget;
use log::info;
use md5::Digest;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::to_string;
use stats::RecordingMetric;
use std::collections::HashMap;
use std::time::Duration;
use consts::DEFAULT_EMBEDDING_MODEL;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use crate::common_types;
use crate::configuration;
use crate::consts;
use crate::stats;
use crate::stream_context::StreamContext;
/// Metrics emitted by the wasm filter.
#[derive(Copy, Clone)]
struct WasmMetrics {
    /// Number of HTTP callouts currently in flight.
    active_http_calls: stats::Gauge,
}

impl WasmMetrics {
    /// Builds the metric set, registering each metric under its stat name.
    fn new() -> Self {
        Self {
            active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
        }
    }
}
pub struct FilterContext {
metrics: WasmMetrics,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, common_types::CallContext>,
config: Option<configuration::Configuration>,
}
impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: HashMap::new(),
config: None,
metrics: WasmMetrics::new(),
}
}
fn process_prompt_targets(&mut self) {
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
for few_shot_example in &prompt_target.few_shot_examples {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(
few_shot_example.to_string(),
)),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
// TODO: Handle potential errors
let json_data: String = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
"embeddingserver",
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
let embedding_request = EmbeddingRequest {
create_embedding_request: embeddings_input,
prompt_target: prompt_target.clone(),
};
if self
.callouts
.insert(token_id, {
CallContext::EmbeddingRequest(embedding_request)
})
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
}
fn embedding_request_handler(
&mut self,
body_size: usize,
create_embedding_request: CreateEmbeddingRequest,
prompt_target: PromptTarget,
) {
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let embedding_response: CreateEmbeddingResponse =
serde_json::from_slice(&body).unwrap();
let mut payload: HashMap<String, String> = HashMap::new();
payload.insert(
"prompt-target".to_string(),
to_string(&prompt_target).unwrap(),
);
let id: Option<Digest>;
match *create_embedding_request.input {
CreateEmbeddingRequestInput::String(input) => {
id = Some(md5::compute(&input));
payload.insert("input".to_string(), input);
}
CreateEmbeddingRequestInput::Array(_) => todo!(),
}
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
points: vec![common_types::VectorPoint {
id: format!("{:x}", id.unwrap()),
payload,
vector: embedding_response.data[0].embedding.clone(),
}],
};
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store/points"),
(":authority", "qdrant"),
("content-type", "application/json"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
}
fn create_vector_store_points_handler(&self, body_size: usize) {
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
info!(
"response body: len {:?}",
String::from_utf8(body).unwrap().len()
);
}
}
}
//TODO: run once per envoy instance, right now it runs once per worker
fn init_vector_store(&mut self) {
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store"),
(":authority", "qdrant"),
("content-type", "application/json"),
],
Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for init-vector-store: {:?}", e);
}
};
if self
.callouts
.insert(
token_id,
CallContext::CreateVectorCollection("prompt_vector_store".to_string()),
)
.is_some()
{
panic!("duplicate token_id")
}
// self.metrics
// .active_http_calls
// .record(self.callouts.len().try_into().unwrap());
}
}
impl Context for FilterContext {
    /// Matches an upstream response to the callout that produced it and hands
    /// it to the appropriate handler.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
        self.metrics
            .active_http_calls
            .record(self.callouts.len().try_into().unwrap());
        match callout_data {
            common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
                create_embedding_request,
                prompt_target,
            }) => {
                self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
            }
            common_types::CallContext::StoreVectorEmbeddings(_) => {
                self.create_vector_store_points_handler(body_size)
            }
            common_types::CallContext::CreateVectorCollection(_) => {
                // Surface the HTTP status of the collection-creation call; "Nil"
                // when no :status pseudo-header was present.
                let mut http_status_code = String::from("Nil");
                for (header_name, header_value) in self.get_http_call_response_headers() {
                    if header_name == ":status" {
                        http_status_code = header_value;
                    }
                }
                info!("CreateVectorCollection response: {}", http_status_code);
            }
        }
    }
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
    /// Loads the YAML plugin configuration pushed down by Envoy.
    fn on_configure(&mut self, _: usize) -> bool {
        if let Some(config_bytes) = self.get_plugin_configuration() {
            // Panic with a descriptive message instead of a bare unwrap: a bad
            // config is unrecoverable, but the operator needs to know why.
            self.config = match serde_yaml::from_slice(&config_bytes) {
                Ok(config) => config,
                Err(error) => {
                    panic!("Failed to deserialize plugin configuration: {}", error);
                }
            };
        }
        true
    }

    /// Creates the per-stream HTTP context for each downstream request.
    fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
        Some(Box::new(StreamContext {
            host_header: None,
            callouts: HashMap::new(),
        }))
    }

    fn get_type(&self) -> Option<ContextType> {
        Some(ContextType::HttpContext)
    }

    /// Schedules a single tick so initialization runs after startup.
    fn on_vm_start(&mut self, _: usize) -> bool {
        self.set_tick_period(Duration::from_secs(1));
        true
    }

    /// One-shot initialization: create the vector collection, embed the
    /// configured prompt targets, then disable the timer (period 0).
    fn on_tick(&mut self) {
        // initialize vector store
        self.init_vector_store();
        self.process_prompt_targets();
        self.set_tick_period(Duration::from_secs(0));
    }
}

View file

@ -1,22 +1,13 @@
use common_types::{CallContext, EmbeddingRequest};
use configuration::PromptTarget;
use http::StatusCode;
use log::error;
use log::info;
use md5::Digest;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use filter_context::FilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use stats::{Gauge, RecordingMetric};
use std::collections::HashMap;
use std::time::Duration;
mod common_types;
mod configuration;
mod consts;
mod filter_context;
mod stats;
mod stream_context;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
@ -24,397 +15,3 @@ proxy_wasm::main! {{
Box::new(FilterContext::new())
});
}}
// Per-HTTP-stream state for the filter.
struct StreamContext {
    // ":host" header captured on request headers, for use by later filter logic.
    host_header: Option<String>,
}
impl StreamContext {
    /// Stash the ":host" header for use by filter logic later on.
    fn save_host_header(&mut self) {
        self.host_header = self.get_http_request_header(":host");
    }

    /// Drop Content-Length: later body rewrites in the gateway would invalidate it,
    /// and servers generally reject requests whose body length does not match the
    /// header. A missing Content-Length is acceptable, since intermediary hops may
    /// manipulate the body in benign ways (e.g., compression).
    fn delete_content_length_header(&mut self) {
        self.set_http_request_header("content-length", None);
    }

    /// Rewrite the demo routing path to the OpenAI chat-completions path; any
    /// other path is passed through untouched.
    fn modify_path_header(&mut self) {
        if let Some(path) = self.get_http_request_header(":path") {
            if path == "/llmrouting" {
                self.set_http_request_header(":path", Some("/v1/chat/completions"));
            }
        }
    }
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for StreamContext {
    // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
    // the lifecycle of the http request and response.
    // Capture routing state from the headers, then let the request proceed to the body phase.
    fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
        self.save_host_header();
        self.delete_content_length_header();
        self.modify_path_header();
        Action::Continue
    }
    // Buffers the full request body, parses it as an OpenAI chat-completions
    // request, pins the model field, and re-serializes it in place.
    fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
        // Let the client send the gateway all the data before sending to the LLM_provider.
        // TODO: consider a streaming API.
        if !end_of_stream {
            return Action::Pause;
        }
        // Nothing to rewrite for an empty body.
        if body_size == 0 {
            return Action::Continue;
        }
        // Deserialize body into spec.
        // Currently OpenAI API.
        let mut deserialized_body: common_types::open_ai::ChatCompletions =
            match self.get_http_request_body(0, body_size) {
                Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
                    Ok(deserialized) => deserialized,
                    Err(msg) => {
                        // Malformed client JSON: reject with 400 and stop the stream.
                        self.send_http_response(
                            StatusCode::BAD_REQUEST.as_u16().into(),
                            vec![],
                            Some(format!("Failed to deserialize: {}", msg).as_bytes()),
                        );
                        return Action::Pause;
                    }
                },
                None => {
                    // Host reported a non-zero body but returned none — internal error.
                    self.send_http_response(
                        StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
                        vec![],
                        None,
                    );
                    error!(
                        "Failed to obtain body bytes even though body_size is {}",
                        body_size
                    );
                    return Action::Pause;
                }
            };
        // Modify JSON payload
        deserialized_body.model = String::from("gpt-3.5-turbo");
        match serde_json::to_string(&deserialized_body) {
            Ok(json_string) => {
                self.set_http_request_body(0, body_size, &json_string.into_bytes());
                Action::Continue
            }
            Err(error) => {
                self.send_http_response(
                    StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
                    vec![],
                    None,
                );
                error!("Failed to serialize body: {}", error);
                Action::Pause
            }
        }
    }
}
// This stream context makes no outbound callouts, so the default Context impl suffices.
impl Context for StreamContext {}
// Metrics exported by the filter. Copy so the set can be passed around cheaply.
#[derive(Copy, Clone)]
struct WasmMetrics {
    // Gauge tracking the number of in-flight HTTP callouts.
    active_http_calls: Gauge,
}
impl WasmMetrics {
    /// Builds the metric set, registering the `active_http_calls` gauge.
    fn new() -> Self {
        Self {
            active_http_calls: Gauge::new("active_http_calls".to_string()),
        }
    }
}
// Root context of the filter: owns the plugin configuration, outstanding
// callout bookkeeping, and the exported metrics.
struct FilterContext {
    metrics: WasmMetrics,
    // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
    callouts: HashMap<u32, common_types::CallContext>,
    // Plugin configuration; populated by RootContext::on_configure.
    config: Option<configuration::Configuration>,
}
impl FilterContext {
    fn new() -> FilterContext {
        FilterContext {
            callouts: HashMap::new(),
            config: None,
            metrics: WasmMetrics::new(),
        }
    }
    // Dispatches one embedding-server callout per few-shot example of every
    // configured prompt target; responses are routed back through
    // #on_http_call_response into embedding_request_handler.
    fn process_prompt_targets(&mut self) {
        for prompt_target in &self
            .config
            .as_ref()
            .expect("Gateway configuration cannot be non-existent")
            .prompt_config
            .prompt_targets
        {
            for few_shot_example in &prompt_target.few_shot_examples {
                info!("few_shot_example: {:?}", few_shot_example);
                let embeddings_input = CreateEmbeddingRequest {
                    input: Box::new(CreateEmbeddingRequestInput::String(
                        few_shot_example.to_string(),
                    )),
                    model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
                    encoding_format: None,
                    dimensions: None,
                    user: None,
                };
                let json_data: String = match serde_json::to_string(&embeddings_input) {
                    Ok(json_data) => json_data,
                    Err(error) => {
                        panic!("Error serializing embeddings input: {}", error);
                    }
                };
                let token_id = match self.dispatch_http_call(
                    "embeddingserver",
                    vec![
                        (":method", "POST"),
                        (":path", "/embeddings"),
                        (":authority", "embeddingserver"),
                        ("content-type", "application/json"),
                        ("x-envoy-max-retries", "3"),
                    ],
                    Some(json_data.as_bytes()),
                    vec![],
                    Duration::from_secs(5),
                ) {
                    Ok(token_id) => token_id,
                    Err(e) => {
                        panic!("Error dispatching embedding server HTTP call: {:?}", e);
                    }
                };
                info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
                // Remember the request so the response can be matched back to it.
                let embedding_request = EmbeddingRequest {
                    create_embedding_request: embeddings_input,
                    prompt_target: prompt_target.clone(),
                };
                if self
                    .callouts
                    .insert(token_id, {
                        CallContext::EmbeddingRequest(embedding_request)
                    })
                    .is_some()
                {
                    panic!(
                        "duplicate token_id={} in embedding server requests",
                        token_id
                    )
                }
                self.metrics
                    .active_http_calls
                    .record(self.callouts.len().try_into().unwrap());
            }
        }
    }
    // Handles the embedding-server response for a few-shot example: builds the
    // qdrant upsert payload (via build_qdrant_data) and dispatches it.
    fn embedding_request_handler(
        &mut self,
        body_size: usize,
        create_embedding_request: CreateEmbeddingRequest,
        prompt_target: PromptTarget,
    ) {
        info!("response received for CreateEmbeddingRequest");
        let body = match self.get_http_call_response_body(0, body_size) {
            Some(body) => body,
            None => {
                return;
            }
        };
        if body.is_empty() {
            return;
        }
        let (json_data, create_vector_store_points) =
            match build_qdrant_data(&body, create_embedding_request, &prompt_target) {
                Ok(tuple) => tuple,
                Err(error) => {
                    panic!(
                        "Error building qdrant data from embedding response {}",
                        error
                    );
                }
            };
        let token_id = match self.dispatch_http_call(
            "qdrant",
            vec![
                (":method", "PUT"),
                (":path", "/collections/prompt_vector_store/points"),
                (":authority", "qdrant"),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
            ],
            Some(json_data.as_bytes()),
            vec![],
            Duration::from_secs(5),
        ) {
            Ok(token_id) => token_id,
            Err(e) => {
                panic!("Error dispatching qdrant HTTP call: {:?}", e);
            }
        };
        info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
        if self
            .callouts
            .insert(
                token_id,
                CallContext::StoreVectorEmbeddings(create_vector_store_points),
            )
            .is_some()
        {
            panic!("duplicate token_id={} in qdrant requests", token_id)
        }
        self.metrics
            .active_http_calls
            .record(self.callouts.len().try_into().unwrap());
    }
    // TODO: @adilhafeez implement.
    // Currently only logs the qdrant upsert acknowledgement.
    fn create_vector_store_points_handler(&self, body_size: usize) {
        info!("response received for CreateVectorStorePoints");
        if let Some(body) = self.get_http_call_response_body(0, body_size) {
            if !body.is_empty() {
                info!("response body: {:?}", String::from_utf8(body).unwrap());
            }
        }
    }
}
impl Context for FilterContext {
    /// Matches an upstream response to the callout that produced it and routes
    /// it to the matching handler.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        info!("on_http_call_response: token_id = {}", token_id);
        let pending = self.callouts.remove(&token_id).expect("invalid token_id");
        self.metrics
            .active_http_calls
            .record(self.callouts.len().try_into().unwrap());
        match pending {
            common_types::CallContext::EmbeddingRequest(embedding_request) => {
                let common_types::EmbeddingRequest {
                    create_embedding_request,
                    prompt_target,
                } = embedding_request;
                self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
            }
            common_types::CallContext::StoreVectorEmbeddings(_) => {
                self.create_vector_store_points_handler(body_size)
            }
        }
    }
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = match serde_yaml::from_slice(&config_bytes) {
Ok(config) => config,
Err(error) => {
panic!("Failed to deserialize plugin configuration: {}", error);
}
};
info!("on_configure: plugin configuration loaded");
}
true
}
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
Some(Box::new(StreamContext { host_header: None }))
}
fn get_type(&self) -> Option<ContextType> {
Some(ContextType::HttpContext)
}
fn on_vm_start(&mut self, _: usize) -> bool {
info!("on_vm_start: setting up tick timeout");
self.set_tick_period(Duration::from_secs(1));
true
}
fn on_tick(&mut self) {
info!("on_tick: starting to process prompt targets");
self.process_prompt_targets();
self.set_tick_period(Duration::from_secs(0));
}
}
/// Converts an embedding-server response into a qdrant upsert request.
///
/// The point id is the md5 digest of the embedded input text (so re-upserting
/// the same example is idempotent), and the payload carries both the raw input
/// and the serialized prompt target for later retrieval.
///
/// Returns the serialized request body together with the typed request.
///
/// # Errors
/// Returns any `serde_json::Error` from deserializing the embedding response or
/// serializing the upsert payload — propagated with `?` instead of panicking,
/// honoring the `Result` contract this function already declares.
fn build_qdrant_data(
    embedding_data: &[u8],
    create_embedding_request: CreateEmbeddingRequest,
    prompt_target: &PromptTarget,
) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> {
    let embedding_response: CreateEmbeddingResponse = serde_json::from_slice(embedding_data)?;
    info!(
        "embedding_response model: {}, vector len: {}",
        embedding_response.model,
        embedding_response.data[0].embedding.len()
    );
    let mut payload: HashMap<String, String> = HashMap::new();
    payload.insert(
        "prompt-target".to_string(),
        serde_json::to_string(&prompt_target)?,
    );
    // Compute the digest directly from the match — no Option dance needed.
    let digest: Digest = match *create_embedding_request.input {
        CreateEmbeddingRequestInput::String(ref input) => {
            payload.insert("input".to_string(), input.clone());
            md5::compute(input)
        }
        CreateEmbeddingRequestInput::Array(_) => todo!(),
    };
    let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
        points: vec![common_types::VectorPoint {
            id: format!("{:x}", digest),
            payload,
            vector: embedding_response.data[0].embedding.clone(),
        }],
    };
    let json_data = serde_json::to_string(&create_vector_store_points)?;
    info!(
        "create_vector_store_points: points length: {}",
        embedding_response.data[0].embedding.len()
    );
    Ok((json_data, create_vector_store_points))
}

View file

@ -0,0 +1,520 @@
use http::StatusCode;
use log::error;
use log::info;
use log::warn;
use open_message_format::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use std::collections::HashMap;
use std::time::Duration;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use consts::{
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
};
use crate::common_types;
use crate::common_types::open_ai::Message;
use crate::common_types::SearchPointsResponse;
use crate::configuration::EntityDetail;
use crate::configuration::EntityType;
use crate::configuration::PromptTarget;
use crate::consts;
// Identifies which upstream a callout was sent to, so the response can be
// routed to the matching handler in on_http_call_response.
enum RequestType {
    GetEmbedding,
    SearchPoints,
    Ner,
    ContextResolver,
}
// State threaded through the chain of callouts made for a single user request.
pub struct CallContext {
    // Which upstream the outstanding callout was sent to.
    request_type: RequestType,
    // The most recent user message extracted from the chat-completions request.
    user_message: String,
    // Prompt target matched by the vector search; None until then.
    prompt_target: Option<PromptTarget>,
    // The original deserialized chat-completions request body.
    request_body: common_types::open_ai::ChatCompletions,
}
// Per-HTTP-stream state for the filter.
pub struct StreamContext {
    // ":host" header captured on request headers, for use by later filter logic.
    pub host_header: Option<String>,
    // token_id -> in-flight callout state, matched up in on_http_call_response.
    pub callouts: HashMap<u32, CallContext>,
}
impl StreamContext {
fn save_host_header(&mut self) {
// Save the host header to be used by filter logic later on.
self.host_header = self.get_http_request_header(":host");
}
fn delete_content_length_header(&mut self) {
// Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
// Server's generally throw away requests whose body length do not match the Content-Length header.
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
// manipulate the body in benign ways e.g., compression.
self.set_http_request_header("content-length", None);
}
fn modify_path_header(&mut self) {
match self.get_http_request_header(":path") {
// The gateway can start gathering information necessary for routing. For now change the path to an
// OpenAI API path.
Some(path) if path == "/llmrouting" => {
self.set_http_request_header(":path", Some("/v1/chat/completions"));
}
// Otherwise let the filter continue.
_ => (),
}
}
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response,
Err(e) => {
warn!("Error deserializing embedding response: {:?}", e);
self.resume_http_request();
return;
}
};
let search_points_request = common_types::SearchPointsRequest {
vector: embedding_response.data[0].embedding.clone(),
limit: 10,
with_payload: true,
};
let json_data: String = match serde_json::to_string(&search_points_request) {
Ok(json_data) => json_data,
Err(e) => {
warn!("Error serializing search_points_request: {:?}", e);
self.reset_http_request();
return;
}
};
let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME);
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "POST"),
(":path", &path),
(":authority", "qdrant"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
}
};
callout_context.request_type = RequestType::SearchPoints;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
}
fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) {
Ok(search_points_response) => search_points_response,
Err(e) => {
warn!("Error deserializing search_points_response: {:?}", e);
self.resume_http_request();
return;
}
};
let search_results = &search_points_response.result;
if search_results.is_empty() {
info!("No prompt target matched");
self.resume_http_request();
return;
}
info!("similarity score: {}", search_results[0].score);
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD {
info!(
"prompt target below threshold: {}",
DEFAULT_PROMPT_TARGET_THRESHOLD
);
self.resume_http_request();
return;
}
let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap();
let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes())
{
Ok(prompt_target) => prompt_target,
Err(e) => {
warn!("Error deserializing prompt_target: {:?}", e);
self.resume_http_request();
return;
}
};
info!("prompt_target name: {:?}", prompt_target.name);
// only extract entity names
let entity_names = get_entity_details(&prompt_target)
.iter()
.map(|entity| entity.name.clone())
.collect();
let user_message = callout_context.user_message.clone();
let ner_request = common_types::NERRequest {
input: user_message,
labels: entity_names,
model: DEFAULT_NER_MODEL.to_string(),
};
let json_data: String = match serde_json::to_string(&ner_request) {
Ok(json_data) => json_data,
Err(e) => {
warn!("Error serializing ner_request: {:?}", e);
self.resume_http_request();
return;
}
};
let token_id = match self.dispatch_http_call(
"nerhost",
vec![
(":method", "POST"),
(":path", "/ner"),
(":authority", "nerhost"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
}
};
callout_context.request_type = RequestType::Ner;
callout_context.prompt_target = Some(prompt_target);
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
}
fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let ner_response: common_types::NERResponse = match serde_json::from_slice(&body) {
Ok(ner_response) => ner_response,
Err(e) => {
warn!("Error deserializing ner_response: {:?}", e);
self.resume_http_request();
return;
}
};
info!("ner_response: {:?}", ner_response);
let mut request_params: HashMap<String, String> = HashMap::new();
for entity in ner_response.data.iter() {
if entity.score < DEFAULT_NER_THRESHOLD {
warn!(
"score of entity was too low entity name: {}, score: {}",
entity.label, entity.score
);
continue;
}
request_params.insert(entity.label.clone(), entity.text.clone());
}
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
let entity_details = get_entity_details(prompt_target);
for entity in entity_details {
if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) {
warn!(
"required entity missing or score of entity was too low: {}",
entity.name
);
self.resume_http_request();
return;
}
}
let req_param_str = match serde_json::to_string(&request_params) {
Ok(req_param_str) => req_param_str,
Err(e) => {
warn!("Error serializing request_params: {:?}", e);
self.resume_http_request();
return;
}
};
let endpoint = callout_context
.prompt_target
.as_ref()
.unwrap()
.endpoint
.as_ref()
.unwrap();
let http_path = match &endpoint.path {
Some(path) => path.clone(),
None => "/".to_string(),
};
let http_method = match &endpoint.method {
Some(method) => method.clone(),
None => http::Method::POST.to_string(),
};
let token_id = match self.dispatch_http_call(
&endpoint.cluster.clone(),
vec![
(":method", http_method.as_str()),
(":path", http_path.as_str()),
(":authority", endpoint.cluster.as_str()),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(req_param_str.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for context-resolver: {:?}", e);
}
};
callout_context.request_type = RequestType::ContextResolver;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
}
fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
info!("response received for context-resolver");
let body_string = String::from_utf8(body);
let prompt_target = callout_context.prompt_target.unwrap();
let mut request_body = callout_context.request_body;
match prompt_target.system_prompt {
None => {}
Some(system_prompt) => {
let system_prompt_message: Message = Message {
role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt),
};
request_body.messages.push(system_prompt_message);
}
}
match body_string {
Ok(body_string) => {
info!("context-resolver response: {}", body_string);
let context_resolver_response = Message {
role: USER_ROLE.to_string(),
content: Some(body_string),
};
request_body.messages.push(context_resolver_response);
}
Err(e) => {
warn!("Error converting response to string: {:?}", e);
self.resume_http_request();
return;
}
}
let json_string = match serde_json::to_string(&request_body) {
Ok(json_string) => json_string,
Err(e) => {
warn!("Error serializing request_body: {:?}", e);
self.resume_http_request();
return;
}
};
info!("sending request to openai: msg len: {}", json_string.len());
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
self.resume_http_request();
}
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for StreamContext {
    // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
    // the lifecycle of the http request and response.
    // Capture routing state from the headers, then let the request proceed to the body phase.
    fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
        self.save_host_header();
        self.delete_content_length_header();
        self.modify_path_header();
        Action::Continue
    }
    // Buffers the request body, extracts the latest user message, and kicks off the
    // embedding callout chain; the stream is paused until the handlers in
    // Context::on_http_call_response resume it.
    fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
        // Let the client send the gateway all the data before sending to the LLM_provider.
        // TODO: consider a streaming API.
        if !end_of_stream {
            return Action::Pause;
        }
        // Nothing to route on for an empty body.
        if body_size == 0 {
            return Action::Continue;
        }
        // Deserialize body into spec.
        // Currently OpenAI API.
        let deserialized_body: common_types::open_ai::ChatCompletions =
            match self.get_http_request_body(0, body_size) {
                Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
                    Ok(deserialized) => deserialized,
                    Err(msg) => {
                        // Malformed client JSON: reject with 400 and stop the stream.
                        self.send_http_response(
                            StatusCode::BAD_REQUEST.as_u16().into(),
                            vec![],
                            Some(format!("Failed to deserialize: {}", msg).as_bytes()),
                        );
                        return Action::Pause;
                    }
                },
                None => {
                    // Host reported a non-zero body but returned none — internal error.
                    self.send_http_response(
                        StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
                        vec![],
                        None,
                    );
                    error!(
                        "Failed to obtain body bytes even though body_size is {}",
                        body_size
                    );
                    return Action::Pause;
                }
            };
        // Routing decisions are based on the most recent message with content.
        let user_message = match deserialized_body
            .messages
            .last()
            .and_then(|last_message| last_message.content.as_ref())
        {
            Some(content) => content,
            None => {
                info!("No messages in the request body");
                return Action::Continue;
            }
        };
        let get_embeddings_input = CreateEmbeddingRequest {
            input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
            model: String::from(DEFAULT_EMBEDDING_MODEL),
            encoding_format: None,
            dimensions: None,
            user: None,
        };
        let json_data: String = match serde_json::to_string(&get_embeddings_input) {
            Ok(json_data) => json_data,
            Err(error) => {
                panic!("Error serializing embeddings input: {}", error);
            }
        };
        let token_id = match self.dispatch_http_call(
            "embeddingserver",
            vec![
                (":method", "POST"),
                (":path", "/embeddings"),
                (":authority", "embeddingserver"),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
            ],
            Some(json_data.as_bytes()),
            vec![],
            Duration::from_secs(5),
        ) {
            Ok(token_id) => token_id,
            Err(e) => {
                panic!(
                    "Error dispatching embedding server HTTP call for get-embeddings: {:?}",
                    e
                );
            }
        };
        // Remember the original request so the callout handlers can rewrite it later.
        let call_context = CallContext {
            request_type: RequestType::GetEmbedding,
            user_message: user_message.clone(),
            prompt_target: None,
            request_body: deserialized_body,
        };
        if self.callouts.insert(token_id, call_context).is_some() {
            panic!(
                "duplicate token_id={} in embedding server requests",
                token_id
            )
        }
        // Pause the stream; it is resumed (or rewritten) by the callout handlers.
        Action::Pause
    }
}
impl Context for StreamContext {
    /// Demultiplexes callout responses to the handler for the stage that
    /// originated them: GetEmbedding -> SearchPoints -> Ner -> ContextResolver.
    fn on_http_call_response(
        &mut self,
        token_id: u32,
        _num_headers: usize,
        body_size: usize,
        _num_trailers: usize,
    ) {
        let callout_context = self.callouts.remove(&token_id).expect("invalid token_id");
        // A missing body means there is nothing to act on; resume the stream so the
        // request falls through to the provider. (This collapses the previous
        // duplicate is_none() check, whose second None arm was unreachable.)
        let body = match self.get_http_call_response_body(0, body_size) {
            Some(body) => body,
            None => {
                warn!("No response body");
                self.resume_http_request();
                return;
            }
        };
        match callout_context.request_type {
            RequestType::GetEmbedding => {
                self.embeddings_handler(body, callout_context);
            }
            RequestType::SearchPoints => {
                self.search_points_handler(body, callout_context);
            }
            RequestType::Ner => {
                self.ner_handler(body, callout_context);
            }
            RequestType::ContextResolver => {
                self.context_resolver_handler(body, callout_context);
            }
        }
    }
}
/// Normalizes a prompt target's entity configuration into a uniform list of
/// `EntityDetail` values: bare entity names become required entities with no
/// description, structured entries are returned as-is, and a missing
/// configuration yields an empty list.
fn get_entity_details(prompt_target: &PromptTarget) -> Vec<EntityDetail> {
    match prompt_target.entities.as_ref() {
        Some(EntityType::Vec(entity_names)) => entity_names
            .iter()
            .map(|entity_name| EntityDetail {
                name: entity_name.clone(),
                required: Some(true),
                description: None,
            })
            .collect(),
        Some(EntityType::Struct(entity_details)) => entity_details.clone(),
        None => Vec::new(),
    }
}

View file

@ -90,7 +90,7 @@ fn successful_request_to_open_ai_chat_completions() {
.returning(Some(chat_completions_request_body))
// TODO: assert that the model field was added.
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
}

View file

@ -13,17 +13,17 @@
"path": "embedding-server"
},
{
"name": "chatbot-ui",
"path": "chatbot-ui"
},
{
"name": "open-message-format",
"path": "open-message-format"
},
{
"name": "demos/weather-forecast",
"path": "./demos/weather-forecast"
}
],
"settings": {}
}