mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Add workflow logic for weather forecast demo (#24)
This commit is contained in:
parent
7ef68eccfb
commit
33f9dd22e6
32 changed files with 1902 additions and 459 deletions
6
.gitignore
vendored
6
.gitignore
vendored
|
|
@ -1,3 +1,9 @@
|
|||
envoyfilter/target
|
||||
envoyfilter/qdrant_data/
|
||||
embedding-server/venv/
|
||||
chatbot-ui/venv/
|
||||
__pycache__
|
||||
grafana-data
|
||||
prom_data
|
||||
.env
|
||||
qdrant_data
|
||||
|
|
|
|||
|
|
@ -16,11 +16,11 @@ repos:
|
|||
name: cargo-clippy
|
||||
language: system
|
||||
types: [file, rust]
|
||||
entry: bash -c "cd envoyfilter && cargo clippy --all"
|
||||
entry: bash -c "cd envoyfilter && cargo clippy -p intelligent-prompt-gateway --all"
|
||||
- id: cargo-test
|
||||
name: cargo-test
|
||||
language: system
|
||||
types: [file, rust]
|
||||
# --lib is to only test the library, since when integration tests are made,
|
||||
# they will be in a separate tests directory
|
||||
entry: bash -c "cd envoyfilter && cargo test --lib"
|
||||
entry: bash -c "cd envoyfilter && cargo test -p intelligent-prompt-gateway --lib"
|
||||
|
|
|
|||
24
chatbot-ui/Dockerfile
Normal file
24
chatbot-ui/Dockerfile
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
# copied from https://github.com/bergos/embedding-server
|
||||
|
||||
FROM python:3 AS base
|
||||
|
||||
#
|
||||
# builder
|
||||
#
|
||||
FROM base AS builder
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
COPY requirements.txt /src/
|
||||
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
|
||||
|
||||
COPY . /src
|
||||
|
||||
FROM python:3-slim AS output
|
||||
|
||||
COPY --from=builder /runtime /usr/local
|
||||
|
||||
COPY /app /app
|
||||
WORKDIR /app
|
||||
|
||||
CMD ["python", "run.py"]
|
||||
81
chatbot-ui/app/run.py
Normal file
81
chatbot-ui/app/run.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import gradio as gr
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import async_timeout
|
||||
|
||||
from loguru import logger
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import os
|
||||
load_dotenv()
|
||||
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1/chat/completions")
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
async def make_completion(messages:List[Message], nb_retries:int=3, delay:int=30) -> Optional[str]:
|
||||
"""
|
||||
Sends a request to the ChatGPT API to retrieve a response based on a list of previous messages.
|
||||
"""
|
||||
header = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {OPENAI_API_KEY}"
|
||||
}
|
||||
try:
|
||||
async with async_timeout.timeout(delay=delay):
|
||||
async with httpx.AsyncClient(headers=header) as aio_client:
|
||||
counter = 0
|
||||
keep_loop = True
|
||||
while keep_loop:
|
||||
logger.debug(f"Chat/Completions Nb Retries : {counter}")
|
||||
try:
|
||||
resp = await aio_client.post(
|
||||
url = CHAT_COMPLETION_ENDPOINT,
|
||||
json = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": messages
|
||||
}
|
||||
)
|
||||
logger.debug(f"Status Code : {resp.status_code}")
|
||||
if resp.status_code == 200:
|
||||
return resp.json()["choices"][0]["message"]["content"]
|
||||
else:
|
||||
logger.warning(resp.content)
|
||||
keep_loop = False
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
counter = counter + 1
|
||||
keep_loop = counter < nb_retries
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.error(f"Timeout {delay} seconds !")
|
||||
return None
|
||||
|
||||
async def predict(input, history):
|
||||
"""
|
||||
Predict the response of the chatbot and complete a running list of chat history.
|
||||
"""
|
||||
history.append({"role": "user", "content": input})
|
||||
print(history)
|
||||
response = await make_completion(history)
|
||||
history.append({"role": "assistant", "content": response})
|
||||
messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
|
||||
return messages, history
|
||||
|
||||
"""
|
||||
Gradio Blocks low-level API that allows to create custom web applications (here our chat app)
|
||||
"""
|
||||
with gr.Blocks() as demo:
|
||||
logger.info("Starting Demo...")
|
||||
chatbot = gr.Chatbot(label="WebGPT")
|
||||
state = gr.State([])
|
||||
with gr.Row():
|
||||
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter")
|
||||
txt.submit(predict, [txt, state], [chatbot, state])
|
||||
|
||||
demo.launch(server_name="0.0.0.0", server_port=8080)
|
||||
6
chatbot-ui/requirements.txt
Normal file
6
chatbot-ui/requirements.txt
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
gradio==4.39.0
|
||||
async_timeout==4.0.3
|
||||
loguru==0.7.2
|
||||
asyncio==3.4.3
|
||||
httpx==0.27.0
|
||||
python-dotenv==1.0.1
|
||||
15
demos/weather-forecast/README.md
Normal file
15
demos/weather-forecast/README.md
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# Weather forecasting
|
||||
This demo shows how you can use the intelligent prompt gateway to provide a real-time weather forecast.
|
||||
|
||||
# Starting the demo
|
||||
1. Create `.env` file and set OpenAI key using env var `OPENAI_API_KEY`
|
||||
1. Start services
|
||||
```sh
|
||||
$ docker compose up
|
||||
```
|
||||
1. Navigate to http://localhost:18080/
|
||||
1. You can type in queries like "how is the weather in Seattle"
|
||||
1. You can also ask follow up questions like "show me sunny days"
|
||||
2. To see metrics navigate to "http://localhost:3000/" (use admin/grafana for login)
|
||||
1. Open up the dashboard named "Intelligent Gateway Overview"
|
||||
2. On this dashboard you can see request latency and number of requests
|
||||
85
demos/weather-forecast/docker-compose.yaml
Normal file
85
demos/weather-forecast/docker-compose.yaml
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
services:
|
||||
envoy:
|
||||
build:
|
||||
context: ../../
|
||||
dockerfile: envoyfilter/Dockerfile
|
||||
hostname: envoy
|
||||
ports:
|
||||
- "10000:10000"
|
||||
- "19901:9901"
|
||||
volumes:
|
||||
- ./envoy.yaml:/etc/envoy/envoy.yaml
|
||||
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
|
||||
networks:
|
||||
- envoymesh
|
||||
depends_on:
|
||||
embeddingserver:
|
||||
condition: service_healthy
|
||||
|
||||
embeddingserver:
|
||||
build:
|
||||
context: ../../embedding-server
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18081:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
networks:
|
||||
- envoymesh
|
||||
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
hostname: vector-db
|
||||
ports:
|
||||
- 16333:6333
|
||||
- 16334:6334
|
||||
networks:
|
||||
- envoymesh
|
||||
|
||||
chatbot-ui:
|
||||
build:
|
||||
context: ../../chatbot-ui
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18080:8080"
|
||||
networks:
|
||||
- envoymesh
|
||||
environment:
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- CHAT_COMPLETION_ENDPOINT=http://envoy:10000/v1/chat/completions
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus
|
||||
container_name: prometheus
|
||||
command:
|
||||
- '--config.file=/etc/prometheus/prometheus.yaml'
|
||||
ports:
|
||||
- 9090:9090
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./prometheus:/etc/prometheus
|
||||
- ./prom_data:/prometheus
|
||||
networks:
|
||||
- envoymesh
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana
|
||||
container_name: grafana
|
||||
ports:
|
||||
- 3000:3000
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- GF_SECURITY_ADMIN_USER=admin
|
||||
- GF_SECURITY_ADMIN_PASSWORD=grafana
|
||||
volumes:
|
||||
- ./grafana:/etc/grafana/provisioning/datasources
|
||||
- ./grafana/dashboard.yaml:/etc/grafana/provisioning/dashboards/main.yaml
|
||||
- ./grafana/dashboards:/var/lib/grafana/dashboards
|
||||
# - ./grafana-data:/var/lib/grafana
|
||||
networks:
|
||||
- envoymesh
|
||||
|
||||
networks:
|
||||
envoymesh: {}
|
||||
197
demos/weather-forecast/envoy.yaml
Normal file
197
demos/weather-forecast/envoy.yaml
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
admin:
|
||||
address:
|
||||
socket_address: { address: 0.0.0.0, port_value: 9901 }
|
||||
static_resources:
|
||||
listeners:
|
||||
address:
|
||||
socket_address:
|
||||
address: 0.0.0.0
|
||||
port_value: 10000
|
||||
filter_chains:
|
||||
- filters:
|
||||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
stat_prefix: ingress_http
|
||||
codec_type: AUTO
|
||||
scheme_header_transformation:
|
||||
scheme_to_overwrite: https
|
||||
route_config:
|
||||
name: local_routes
|
||||
virtual_hosts:
|
||||
- name: openai
|
||||
domains:
|
||||
- "api.openai.com"
|
||||
routes:
|
||||
- match:
|
||||
prefix: "/"
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: openai
|
||||
- name: local_service
|
||||
domains:
|
||||
- "*"
|
||||
routes:
|
||||
- match:
|
||||
prefix: "/v1/chat/completions"
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: openai
|
||||
- match:
|
||||
prefix: "/embeddings"
|
||||
route:
|
||||
cluster: embeddingserver
|
||||
- match:
|
||||
prefix: "/"
|
||||
direct_response:
|
||||
status: 200
|
||||
body:
|
||||
inline_string: "Inspect the HTTP header: custom-header.\n"
|
||||
http_filters:
|
||||
- name: envoy.filters.http.wasm
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
|
||||
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
|
||||
value:
|
||||
config:
|
||||
name: "http_config"
|
||||
configuration:
|
||||
"@type": "type.googleapis.com/google.protobuf.StringValue"
|
||||
value: |
|
||||
katanemo-prompt-config:
|
||||
default-prompt-endpoint: "127.0.0.1"
|
||||
load-balancing: "round-robin"
|
||||
timeout-ms: 5000
|
||||
|
||||
embedding-provider:
|
||||
name: "SentenceTransformer"
|
||||
model: "all-MiniLM-L6-v2"
|
||||
|
||||
llm-providers:
|
||||
|
||||
- name: "open-ai-gpt-4"
|
||||
api-key: "$OPEN_AI_API_KEY"
|
||||
model: gpt-4
|
||||
|
||||
prompt-targets:
|
||||
|
||||
- type: context-resolver
|
||||
name: weather-forecast
|
||||
few-shot-examples:
|
||||
- what is the weather in New York?
|
||||
- how is the weather in San Francisco?
|
||||
- what is the forecast in Seattle?
|
||||
entities:
|
||||
- name: city
|
||||
required: true
|
||||
- name: days
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
cache-response: true
|
||||
cache-response-settings:
|
||||
- cache-ttl-secs: 3600 # cache expiry in seconds
|
||||
- cache-max-size: 1000 # in number of items
|
||||
- cache-eviction-strategy: LRU
|
||||
system-prompt: |
|
||||
You are a helpful weather forecaster. Use the weather data that is provided to you. Please follow the following guidelines when responding to user queries:
|
||||
- Use Fahrenheit for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
vm_config:
|
||||
runtime: "envoy.wasm.runtime.v8"
|
||||
code:
|
||||
local:
|
||||
filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
|
||||
- name: envoy.filters.http.router
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
clusters:
|
||||
# LLM Host
|
||||
# Embedding Providers
|
||||
# External LLM Providers
|
||||
- name: openai
|
||||
connect_timeout: 5s
|
||||
type: LOGICAL_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
typed_extension_protocol_options:
|
||||
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
|
||||
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
|
||||
explicit_http_config:
|
||||
http2_protocol_options: {}
|
||||
load_assignment:
|
||||
cluster_name: openai
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: api.openai.com
|
||||
port_value: 443
|
||||
hostname: "api.openai.com"
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
sni: api.openai.com
|
||||
common_tls_context:
|
||||
tls_params:
|
||||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
|
||||
- name: embeddingserver
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: embeddingserver
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: embeddingserver
|
||||
port_value: 80
|
||||
hostname: "embeddingserver"
|
||||
- name: weatherhost
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: weatherhost
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: embeddingserver
|
||||
port_value: 80
|
||||
hostname: "embeddingserver"
|
||||
- name: nerhost
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: nerhost
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: embeddingserver
|
||||
port_value: 80
|
||||
hostname: "embeddingserver"
|
||||
- name: qdrant
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: qdrant
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: qdrant
|
||||
port_value: 6333
|
||||
hostname: "qdrant"
|
||||
12
demos/weather-forecast/grafana/dashboard.yaml
Normal file
12
demos/weather-forecast/grafana/dashboard.yaml
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
apiVersion: 1
|
||||
|
||||
providers:
|
||||
- name: "Dashboard provider"
|
||||
orgId: 1
|
||||
type: file
|
||||
disableDeletion: false
|
||||
updateIntervalSeconds: 10
|
||||
allowUiUpdates: false
|
||||
options:
|
||||
path: /var/lib/grafana/dashboards
|
||||
foldersFromFilesStructure: true
|
||||
355
demos/weather-forecast/grafana/dashboards/envoy_overview.json
Normal file
355
demos/weather-forecast/grafana/dashboards/envoy_overview.json
Normal file
|
|
@ -0,0 +1,355 @@
|
|||
{
|
||||
"annotations": {
|
||||
"list": [
|
||||
{
|
||||
"builtIn": 1,
|
||||
"datasource": {
|
||||
"type": "grafana",
|
||||
"uid": "-- Grafana --"
|
||||
},
|
||||
"enable": true,
|
||||
"hide": true,
|
||||
"iconColor": "rgba(0, 211, 255, 1)",
|
||||
"name": "Annotations & Alerts",
|
||||
"type": "dashboard"
|
||||
}
|
||||
]
|
||||
},
|
||||
"editable": true,
|
||||
"fiscalYearStartMonth": 0,
|
||||
"graphTooltip": 1,
|
||||
"links": [],
|
||||
"panels": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "palette-classic"
|
||||
},
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisColorMode": "text",
|
||||
"axisLabel": "",
|
||||
"axisPlacement": "auto",
|
||||
"barAlignment": 0,
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"hideFrom": {
|
||||
"legend": false,
|
||||
"tooltip": false,
|
||||
"viz": false
|
||||
},
|
||||
"insertNulls": false,
|
||||
"lineInterpolation": "linear",
|
||||
"lineWidth": 1,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"showPoints": "auto",
|
||||
"spanNulls": false,
|
||||
"stacking": {
|
||||
"group": "A",
|
||||
"mode": "none"
|
||||
},
|
||||
"thresholdsStyle": {
|
||||
"mode": "off"
|
||||
}
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 0
|
||||
},
|
||||
"id": 2,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"showLegend": true
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "single",
|
||||
"sort": "none"
|
||||
}
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "code",
|
||||
"expr": "avg(rate(envoy_cluster_internal_upstream_rq_time_sum[1m]) / rate(envoy_cluster_internal_upstream_rq_time_count[1m])) by (envoy_cluster_name)",
|
||||
"fullMetaSearch": false,
|
||||
"hide": false,
|
||||
"includeNullMetadata": true,
|
||||
"instant": false,
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "request latency - internal (ms)",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "palette-classic"
|
||||
},
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisColorMode": "text",
|
||||
"axisLabel": "",
|
||||
"axisPlacement": "auto",
|
||||
"barAlignment": 0,
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"hideFrom": {
|
||||
"legend": false,
|
||||
"tooltip": false,
|
||||
"viz": false
|
||||
},
|
||||
"insertNulls": false,
|
||||
"lineInterpolation": "linear",
|
||||
"lineWidth": 1,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"showPoints": "auto",
|
||||
"spanNulls": false,
|
||||
"stacking": {
|
||||
"group": "A",
|
||||
"mode": "none"
|
||||
},
|
||||
"thresholdsStyle": {
|
||||
"mode": "off"
|
||||
}
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 12,
|
||||
"y": 0
|
||||
},
|
||||
"id": 1,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"showLegend": true
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "single",
|
||||
"sort": "none"
|
||||
}
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "code",
|
||||
"expr": "avg(rate(envoy_cluster_external_upstream_rq_time_sum[1m]) / rate(envoy_cluster_external_upstream_rq_time_count[1m])) by (envoy_cluster_name)",
|
||||
"fullMetaSearch": false,
|
||||
"hide": false,
|
||||
"includeNullMetadata": true,
|
||||
"instant": false,
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "request latency - external (ms)",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": {
|
||||
"mode": "palette-classic"
|
||||
},
|
||||
"custom": {
|
||||
"axisBorderShow": false,
|
||||
"axisCenteredZero": false,
|
||||
"axisColorMode": "text",
|
||||
"axisLabel": "",
|
||||
"axisPlacement": "auto",
|
||||
"barAlignment": 0,
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"hideFrom": {
|
||||
"legend": false,
|
||||
"tooltip": false,
|
||||
"viz": false
|
||||
},
|
||||
"insertNulls": false,
|
||||
"lineInterpolation": "linear",
|
||||
"lineWidth": 1,
|
||||
"pointSize": 5,
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"showPoints": "auto",
|
||||
"spanNulls": false,
|
||||
"stacking": {
|
||||
"group": "A",
|
||||
"mode": "none"
|
||||
},
|
||||
"thresholdsStyle": {
|
||||
"mode": "off"
|
||||
}
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 8
|
||||
},
|
||||
"id": 3,
|
||||
"options": {
|
||||
"legend": {
|
||||
"calcs": [],
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"showLegend": true
|
||||
},
|
||||
"tooltip": {
|
||||
"mode": "single",
|
||||
"sort": "none"
|
||||
}
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "code",
|
||||
"expr": "avg(rate(envoy_cluster_internal_upstream_rq_completed[1m])) by (envoy_cluster_name)",
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"instant": false,
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "PBFA97CFB590B2093"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "code",
|
||||
"expr": "avg(rate(envoy_cluster_external_upstream_rq_completed[1m])) by (envoy_cluster_name)",
|
||||
"fullMetaSearch": false,
|
||||
"hide": false,
|
||||
"includeNullMetadata": true,
|
||||
"instant": false,
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "B",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "Upstream request count",
|
||||
"type": "timeseries"
|
||||
}
|
||||
],
|
||||
"schemaVersion": 39,
|
||||
"tags": [],
|
||||
"templating": {
|
||||
"list": []
|
||||
},
|
||||
"time": {
|
||||
"from": "now-15m",
|
||||
"to": "now"
|
||||
},
|
||||
"timepicker": {},
|
||||
"timezone": "browser",
|
||||
"title": "Intelligent Gateway Overview",
|
||||
"uid": "adt6uhx5lk8aob",
|
||||
"version": 3,
|
||||
"weekStart": ""
|
||||
}
|
||||
9
demos/weather-forecast/grafana/datasource.yaml
Normal file
9
demos/weather-forecast/grafana/datasource.yaml
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
apiVersion: 1
|
||||
|
||||
datasources:
|
||||
- name: Prometheus
|
||||
type: prometheus
|
||||
url: http://prometheus:9090
|
||||
isDefault: true
|
||||
access: proxy
|
||||
editable: true
|
||||
23
demos/weather-forecast/prometheus/prometheus.yaml
Normal file
23
demos/weather-forecast/prometheus/prometheus.yaml
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
global:
|
||||
scrape_interval: 15s
|
||||
scrape_timeout: 10s
|
||||
evaluation_interval: 15s
|
||||
alerting:
|
||||
alertmanagers:
|
||||
- static_configs:
|
||||
- targets: []
|
||||
scheme: http
|
||||
timeout: 10s
|
||||
api_version: v1
|
||||
scrape_configs:
|
||||
- job_name: envoy
|
||||
honor_timestamps: true
|
||||
scrape_interval: 15s
|
||||
scrape_timeout: 10s
|
||||
metrics_path: /stats
|
||||
scheme: http
|
||||
static_configs:
|
||||
- targets:
|
||||
- envoy:9901
|
||||
params:
|
||||
format: ['prometheus']
|
||||
|
|
@ -24,6 +24,7 @@ FROM python:3-slim AS output
|
|||
# following models have been tested to work with this image
|
||||
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
|
||||
ENV MODELS="BAAI/bge-large-en-v1.5"
|
||||
ENV NER_MODELS="urchade/gliner_large-v2.1"
|
||||
|
||||
COPY --from=builder /runtime /usr/local
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
from load_transformers import load_transformers
|
||||
from load_models import load_transformers, load_ner_models
|
||||
|
||||
print('installing transformers')
|
||||
load_transformers()
|
||||
print('installing ner models')
|
||||
load_ner_models()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import sentence_transformers
|
||||
from gliner import GLiNER
|
||||
|
||||
def load_transformers(models = os.getenv("MODELS", "sentence-transformers/all-MiniLM-L6-v2")):
|
||||
transformers = {}
|
||||
|
|
@ -8,3 +9,11 @@ def load_transformers(models = os.getenv("MODELS", "sentence-transformers/all-Mi
|
|||
transformers[model] = sentence_transformers.SentenceTransformer(model)
|
||||
|
||||
return transformers
|
||||
|
||||
def load_ner_models(models = os.getenv("NER_MODELS", "urchade/gliner_large-v2.1")):
|
||||
ner_models = {}
|
||||
|
||||
for model in models.split(','):
|
||||
ner_models[model] = GLiNER.from_pretrained(model)
|
||||
|
||||
return ner_models
|
||||
|
|
@ -1,8 +1,11 @@
|
|||
import random
|
||||
from fastapi import FastAPI, Response, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from load_transformers import load_transformers
|
||||
from load_models import load_ner_models, load_transformers
|
||||
from datetime import date, timedelta
|
||||
|
||||
transformers = load_transformers()
|
||||
ner_models = load_ner_models()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
@ -10,6 +13,12 @@ class EmbeddingRequest(BaseModel):
|
|||
input: str
|
||||
model: str
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz():
|
||||
return {
|
||||
"status": "ok"
|
||||
}
|
||||
|
||||
@app.get("/models")
|
||||
async def models():
|
||||
models = []
|
||||
|
|
@ -27,7 +36,7 @@ async def models():
|
|||
|
||||
@app.post("/embeddings")
|
||||
async def embedding(req: EmbeddingRequest, res: Response):
|
||||
if not req.model in transformers:
|
||||
if req.model not in transformers:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
embeddings = transformers[req.model].encode([req.input])
|
||||
|
|
@ -51,3 +60,48 @@ async def embedding(req: EmbeddingRequest, res: Response):
|
|||
"object": "list",
|
||||
"usage": usage
|
||||
}
|
||||
|
||||
class NERRequest(BaseModel):
|
||||
input: str
|
||||
labels: list[str]
|
||||
model: str
|
||||
|
||||
|
||||
@app.post("/ner")
|
||||
async def ner(req: NERRequest, res: Response):
|
||||
if req.model not in ner_models:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
model = ner_models[req.model]
|
||||
entities = model.predict_entities(req.input, req.labels)
|
||||
|
||||
return {
|
||||
"data": entities,
|
||||
"model": req.model,
|
||||
"object": "list",
|
||||
}
|
||||
|
||||
class WeatherRequest(BaseModel):
|
||||
city: str
|
||||
|
||||
|
||||
@app.post("/weather")
|
||||
async def weather(req: WeatherRequest, res: Response):
|
||||
|
||||
weather_forecast = {
|
||||
"city": req.city,
|
||||
"temperature": [],
|
||||
"unit": "F",
|
||||
}
|
||||
for i in range(7):
|
||||
min_temp = random.randrange(50,90)
|
||||
max_temp = random.randrange(min_temp+5, min_temp+20)
|
||||
weather_forecast["temperature"].append({
|
||||
"date": str(date.today() + timedelta(days=i)),
|
||||
"temperature": {
|
||||
"min": min_temp,
|
||||
"max": max_temp
|
||||
}
|
||||
})
|
||||
|
||||
return weather_forecast
|
||||
|
|
|
|||
|
|
@ -3,3 +3,4 @@ fastapi
|
|||
sentence-transformers
|
||||
torch
|
||||
uvicorn
|
||||
gliner
|
||||
|
|
|
|||
24
envoyfilter/Cargo.lock
generated
24
envoyfilter/Cargo.lock
generated
|
|
@ -1621,9 +1621,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "0.6.6"
|
||||
version = "0.6.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0"
|
||||
checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
|
@ -1936,9 +1936,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.8.15"
|
||||
version = "0.8.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac2caab0bf757388c6c0ae23b3293fdb463fee59434529014f85e3263b995c28"
|
||||
checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
|
|
@ -1948,18 +1948,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "toml_datetime"
|
||||
version = "0.6.6"
|
||||
version = "0.6.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf"
|
||||
checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_edit"
|
||||
version = "0.22.16"
|
||||
version = "0.22.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "278f3d518e152219c994ce877758516bca5e118eaed6996192a774fb9fbf0788"
|
||||
checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"serde",
|
||||
|
|
@ -2121,9 +2121,9 @@ checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
|
|||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
version = "0.9.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
|
|
@ -2730,9 +2730,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
|||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.6.15"
|
||||
version = "0.6.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "557404e450152cd6795bb558bca69e43c585055f4606e3bcae5894fc6dac9ba0"
|
||||
checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
|
|
|||
16
envoyfilter/Dockerfile
Normal file
16
envoyfilter/Dockerfile
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
# build filter using rust toolchain
|
||||
FROM rust:1.80.0 as builder
|
||||
WORKDIR /envoyfilter
|
||||
COPY envoyfilter/src /envoyfilter/src
|
||||
COPY envoyfilter/Cargo.toml /envoyfilter/
|
||||
COPY envoyfilter/Cargo.lock /envoyfilter/
|
||||
COPY open-message-format /open-message-format
|
||||
|
||||
RUN rustup -v target add wasm32-wasi
|
||||
RUN cargo build --release --target wasm32-wasi
|
||||
|
||||
# copy built filter into envoy image
|
||||
FROM envoyproxy/envoy:v1.30-latest
|
||||
COPY --from=builder /envoyfilter/target/wasm32-wasi/release/intelligent_prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm
|
||||
COPY envoyfilter/envoy.yaml /etc/envoy.yaml
|
||||
CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"]
|
||||
3
envoyfilter/build_filter.sh
Normal file
3
envoyfilter/build_filter.sh
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
RUST_VERSION=1.80.0
|
||||
docker run --rm -v rustup_cache:/usr/local/rustup/ rust:$RUST_VERSION rustup -v target add wasm32-wasi
|
||||
docker run --rm -v $PWD/../open-message-format:/code/open-message-format -v ~/.cargo:/root/.cargo -v $(pwd):/code/envoyfilter -w /code/envoyfilter -v rustup_cache:/usr/local/rustup/ rust:$RUST_VERSION cargo build --release --target wasm32-wasi
|
||||
|
|
@ -22,7 +22,7 @@ services:
|
|||
ports:
|
||||
- "18080:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost:80"]
|
||||
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
networks:
|
||||
|
|
|
|||
|
|
@ -41,10 +41,6 @@ static_resources:
|
|||
prefix: "/embeddings"
|
||||
route:
|
||||
cluster: embeddingserver
|
||||
- match:
|
||||
prefix: "/inline"
|
||||
route:
|
||||
cluster: httpbin
|
||||
- match:
|
||||
prefix: "/"
|
||||
direct_response:
|
||||
|
|
@ -78,7 +74,7 @@ static_resources:
|
|||
model: gpt-4
|
||||
|
||||
system-prompt: |
|
||||
You are a helpful weather forecaster. Please following following guidelines when responding to user queries:
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
|
|
@ -88,12 +84,22 @@ static_resources:
|
|||
name: weather-forecast
|
||||
few-shot-examples:
|
||||
- what is the weather in New York?
|
||||
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
|
||||
- how is the weather in San Francisco?
|
||||
- what is the forecast in Seattle?
|
||||
entities:
|
||||
- city
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
cache-response: true
|
||||
cache-response-settings:
|
||||
- cache-ttl-secs: 3600 # cache expiry in seconds
|
||||
- cache-max-size: 1000 # in number of items
|
||||
- cache-eviction-strategy: LRU
|
||||
system-prompt: |
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
- Use miles per hour for wind speed
|
||||
|
||||
vm_config:
|
||||
runtime: "envoy.wasm.runtime.v8"
|
||||
|
|
@ -136,20 +142,6 @@ static_resources:
|
|||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
|
||||
- name: httpbin
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: httpbin
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: httpbin.org
|
||||
port_value: 80
|
||||
hostname: "httpbin.org"
|
||||
- name: embeddingserver
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
|
|
@ -164,6 +156,34 @@ static_resources:
|
|||
address: embeddingserver
|
||||
port_value: 80
|
||||
hostname: "embeddingserver"
|
||||
- name: weatherhost
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: weatherhost
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: embeddingserver
|
||||
port_value: 80
|
||||
hostname: "embeddingserver"
|
||||
- name: nerhost
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: nerhost
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: embeddingserver
|
||||
port_value: 80
|
||||
hostname: "embeddingserver"
|
||||
- name: qdrant
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
|
|
|
|||
9
envoyfilter/grafana/datasource.yaml
Normal file
9
envoyfilter/grafana/datasource.yaml
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
apiVersion: 1
|
||||
|
||||
datasources:
|
||||
- name: Prometheus
|
||||
type: prometheus
|
||||
url: http://prometheus:9090
|
||||
isDefault: true
|
||||
access: proxy
|
||||
editable: true
|
||||
23
envoyfilter/prometheus/prometheus.yaml
Normal file
23
envoyfilter/prometheus/prometheus.yaml
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
global:
|
||||
scrape_interval: 15s
|
||||
scrape_timeout: 10s
|
||||
evaluation_interval: 15s
|
||||
alerting:
|
||||
alertmanagers:
|
||||
- static_configs:
|
||||
- targets: []
|
||||
scheme: http
|
||||
timeout: 10s
|
||||
api_version: v1
|
||||
scrape_configs:
|
||||
- job_name: envoy
|
||||
honor_timestamps: true
|
||||
scrape_interval: 15s
|
||||
scrape_timeout: 10s
|
||||
metrics_path: /stats
|
||||
scheme: http
|
||||
static_configs:
|
||||
- targets:
|
||||
- envoy:9901
|
||||
params:
|
||||
format: ['prometheus']
|
||||
|
|
@ -23,9 +23,53 @@ pub struct StoreVectorEmbeddingsRequest {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum CallContext {
|
||||
EmbeddingRequest(EmbeddingRequest),
|
||||
StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
|
||||
CreateVectorCollection(String),
|
||||
}
|
||||
|
||||
// https://api.qdrant.tech/master/api-reference/search/points
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointsRequest {
|
||||
pub vector: Vec<f64>,
|
||||
pub limit: i32,
|
||||
pub with_payload: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointResult {
|
||||
pub id: String,
|
||||
pub version: i32,
|
||||
pub score: f64,
|
||||
pub payload: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointsResponse {
|
||||
pub result: Vec<SearchPointResult>,
|
||||
pub status: String,
|
||||
pub time: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NERRequest {
|
||||
pub input: String,
|
||||
pub labels: Vec<String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Entity {
|
||||
pub text: String,
|
||||
pub label: String,
|
||||
pub score: f64,
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct NERResponse {
|
||||
pub data: Vec<Entity>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
pub mod open_ai {
|
||||
|
|
@ -41,6 +85,6 @@ pub mod open_ai {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ pub struct PromptConfig {
|
|||
pub timeout_ms: u64,
|
||||
pub embedding_provider: EmbeddingProviver,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub system_prompt: String,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_targets: Vec<PromptTarget>,
|
||||
}
|
||||
|
||||
|
|
@ -49,6 +49,29 @@ pub struct LlmProvider {
|
|||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct Endpoint {
|
||||
pub cluster: String,
|
||||
pub path: Option<String>,
|
||||
pub method: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct EntityDetail {
|
||||
pub name: String,
|
||||
pub required: Option<bool>,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum EntityType {
|
||||
Vec(Vec<String>),
|
||||
Struct(Vec<EntityDetail>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct PromptTarget {
|
||||
|
|
@ -56,7 +79,9 @@ pub struct PromptTarget {
|
|||
pub prompt_type: String,
|
||||
pub name: String,
|
||||
pub few_shot_examples: Vec<String>,
|
||||
pub endpoint: String,
|
||||
pub entities: Option<EntityType>,
|
||||
pub endpoint: Option<Endpoint>,
|
||||
pub system_prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -88,13 +113,23 @@ katanemo-prompt-config:
|
|||
name: weather-forecast
|
||||
few-shot-examples:
|
||||
- what is the weather in New York?
|
||||
endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT"
|
||||
cache-response: true
|
||||
cache-response-settings:
|
||||
- cache-ttl-secs: 3600 # cache expiry in seconds
|
||||
- cache-max-size: 1000 # in number of items
|
||||
- cache-eviction-strategy: LRU
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- name: location
|
||||
required: true
|
||||
description: "The location for which the weather is requested"
|
||||
|
||||
- type: context-resolver
|
||||
name: weather-forecast-2
|
||||
few-shot-examples:
|
||||
- what is the weather in New York?
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- city
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -1 +1,7 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
||||
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
|
||||
pub const DEFAULT_NER_MODEL: &str = "urchade/gliner_large-v2.1";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
|
||||
pub const DEFAULT_NER_THRESHOLD: f64 = 0.6;
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
|
|
|
|||
289
envoyfilter/src/filter_context.rs
Normal file
289
envoyfilter/src/filter_context.rs
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
use common_types::{CallContext, EmbeddingRequest};
|
||||
use configuration::PromptTarget;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use stats::RecordingMetric;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
use crate::common_types;
|
||||
use crate::configuration;
|
||||
use crate::consts;
|
||||
use crate::stats;
|
||||
use crate::stream_context::StreamContext;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
active_http_calls: stats::Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: stats::Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FilterContext {
|
||||
metrics: WasmMetrics,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, common_types::CallContext>,
|
||||
config: Option<configuration::Configuration>,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: WasmMetrics::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets {
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(
|
||||
few_shot_example.to_string(),
|
||||
)),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
let json_data: String = to_string(&embeddings_input).unwrap();
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
let embedding_request = EmbeddingRequest {
|
||||
create_embedding_request: embeddings_input,
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext::EmbeddingRequest(embedding_request)
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_request_handler(
|
||||
&mut self,
|
||||
body_size: usize,
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: PromptTarget,
|
||||
) {
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let embedding_response: CreateEmbeddingResponse =
|
||||
serde_json::from_slice(&body).unwrap();
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
to_string(&prompt_target).unwrap(),
|
||||
);
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(input) => {
|
||||
id = Some(md5::compute(&input));
|
||||
payload.insert("input".to_string(), input);
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
};
|
||||
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_vector_store_points_handler(&self, body_size: usize) {
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
info!(
|
||||
"response body: len {:?}",
|
||||
String::from_utf8(body).unwrap().len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: run once per envoy instance, right now it runs once per worker
|
||||
fn init_vector_store(&mut self) {
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for init-vector-store: {:?}", e);
|
||||
}
|
||||
};
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::CreateVectorCollection("prompt_vector_store".to_string()),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
// self.metrics
|
||||
// .active_http_calls
|
||||
// .record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
|
||||
match callout_data {
|
||||
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
|
||||
}
|
||||
common_types::CallContext::StoreVectorEmbeddings(_) => {
|
||||
self.create_vector_store_points_handler(body_size)
|
||||
}
|
||||
common_types::CallContext::CreateVectorCollection(_) => {
|
||||
let mut http_status_code = "Nil".to_string();
|
||||
self.get_http_call_response_headers()
|
||||
.iter()
|
||||
.for_each(|(k, v)| {
|
||||
if k == ":status" {
|
||||
http_status_code.clone_from(v);
|
||||
}
|
||||
});
|
||||
info!("CreateVectorCollection response: {}", http_status_code);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
Some(Box::new(StreamContext {
|
||||
host_header: None,
|
||||
callouts: HashMap::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
// initialize vector store
|
||||
self.init_vector_store();
|
||||
self.process_prompt_targets();
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,22 +1,13 @@
|
|||
use common_types::{CallContext, EmbeddingRequest};
|
||||
use configuration::PromptTarget;
|
||||
use http::StatusCode;
|
||||
use log::error;
|
||||
use log::info;
|
||||
use md5::Digest;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use filter_context::FilterContext;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use stats::{Gauge, RecordingMetric};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
mod common_types;
|
||||
mod configuration;
|
||||
mod consts;
|
||||
mod filter_context;
|
||||
mod stats;
|
||||
mod stream_context;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
|
|
@ -24,397 +15,3 @@ proxy_wasm::main! {{
|
|||
Box::new(FilterContext::new())
|
||||
});
|
||||
}}
|
||||
|
||||
struct StreamContext {
|
||||
host_header: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
fn save_host_header(&mut self) {
|
||||
// Save the host header to be used by filter logic later on.
|
||||
self.host_header = self.get_http_request_header(":host");
|
||||
}
|
||||
|
||||
fn delete_content_length_header(&mut self) {
|
||||
// Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
|
||||
// Server's generally throw away requests whose body length do not match the Content-Length header.
|
||||
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
|
||||
// manipulate the body in benign ways e.g., compression.
|
||||
self.set_http_request_header("content-length", None);
|
||||
}
|
||||
|
||||
fn modify_path_header(&mut self) {
|
||||
match self.get_http_request_header(":path") {
|
||||
// The gateway can start gathering information necessary for routing. For now change the path to an
|
||||
// OpenAI API path.
|
||||
Some(path) if path == "/llmrouting" => {
|
||||
self.set_http_request_header(":path", Some("/v1/chat/completions"));
|
||||
}
|
||||
// Otherwise let the filter continue.
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
impl HttpContext for StreamContext {
|
||||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
self.save_host_header();
|
||||
self.delete_content_length_header();
|
||||
self.modify_path_header();
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
// Let the client send the gateway all the data before sending to the LLM_provider.
|
||||
// TODO: consider a streaming API.
|
||||
if !end_of_stream {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
if body_size == 0 {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: common_types::open_ai::ChatCompletions =
|
||||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
// Modify JSON payload
|
||||
deserialized_body.model = String::from("gpt-3.5-turbo");
|
||||
|
||||
match serde_json::to_string(&deserialized_body) {
|
||||
Ok(json_string) => {
|
||||
self.set_http_request_body(0, body_size, &json_string.into_bytes());
|
||||
Action::Continue
|
||||
}
|
||||
Err(error) => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!("Failed to serialize body: {}", error);
|
||||
Action::Pause
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for StreamContext {}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct WasmMetrics {
|
||||
active_http_calls: Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct FilterContext {
|
||||
metrics: WasmMetrics,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, common_types::CallContext>,
|
||||
config: Option<configuration::Configuration>,
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: WasmMetrics::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
for prompt_target in &self
|
||||
.config
|
||||
.as_ref()
|
||||
.expect("Gateway configuration cannot be non-existent")
|
||||
.prompt_config
|
||||
.prompt_targets
|
||||
{
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
info!("few_shot_example: {:?}", few_shot_example);
|
||||
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(
|
||||
few_shot_example.to_string(),
|
||||
)),
|
||||
model: String::from(consts::DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching embedding server HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
let embedding_request = EmbeddingRequest {
|
||||
create_embedding_request: embeddings_input,
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext::EmbeddingRequest(embedding_request)
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_request_handler(
|
||||
&mut self,
|
||||
body_size: usize,
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: PromptTarget,
|
||||
) {
|
||||
info!("response received for CreateEmbeddingRequest");
|
||||
let body = match self.get_http_call_response_body(0, body_size) {
|
||||
Some(body) => body,
|
||||
None => {
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if body.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let (json_data, create_vector_store_points) =
|
||||
match build_qdrant_data(&body, create_embedding_request, &prompt_target) {
|
||||
Ok(tuple) => tuple,
|
||||
Err(error) => {
|
||||
panic!(
|
||||
"Error building qdrant data from embedding response {}",
|
||||
error
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching qdrant HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
info!("on_tick: dispatched HTTP call with token_id = {}", token_id);
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id={} in qdrant requests", token_id)
|
||||
}
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
|
||||
// TODO: @adilhafeez implement.
|
||||
fn create_vector_store_points_handler(&self, body_size: usize) {
|
||||
info!("response received for CreateVectorStorePoints");
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
info!("response body: {:?}", String::from_utf8(body).unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
info!("on_http_call_response: token_id = {}", token_id);
|
||||
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
match callout_data {
|
||||
common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
|
||||
}
|
||||
common_types::CallContext::StoreVectorEmbeddings(_) => {
|
||||
self.create_vector_store_points_handler(body_size)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
fn on_configure(&mut self, _: usize) -> bool {
|
||||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = match serde_yaml::from_slice(&config_bytes) {
|
||||
Ok(config) => config,
|
||||
Err(error) => {
|
||||
panic!("Failed to deserialize plugin configuration: {}", error);
|
||||
}
|
||||
};
|
||||
info!("on_configure: plugin configuration loaded");
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
Some(Box::new(StreamContext { host_header: None }))
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
info!("on_vm_start: setting up tick timeout");
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
info!("on_tick: starting to process prompt targets");
|
||||
self.process_prompt_targets();
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
}
|
||||
|
||||
fn build_qdrant_data(
|
||||
embedding_data: &[u8],
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: &PromptTarget,
|
||||
) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(embedding_data) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(error) => {
|
||||
panic!("Failed to deserialize embedding response: {}", error);
|
||||
}
|
||||
};
|
||||
info!(
|
||||
"embedding_response model: {}, vector len: {}",
|
||||
embedding_response.model,
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
serde_json::to_string(&prompt_target)?,
|
||||
);
|
||||
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(ref input) => {
|
||||
id = Some(md5::compute(input));
|
||||
payload.insert("input".to_string(), input.clone());
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest {
|
||||
points: vec![common_types::VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
}],
|
||||
};
|
||||
let json_data = serde_json::to_string(&create_vector_store_points)?;
|
||||
info!(
|
||||
"create_vector_store_points: points length: {}",
|
||||
embedding_response.data[0].embedding.len()
|
||||
);
|
||||
|
||||
Ok((json_data, create_vector_store_points))
|
||||
}
|
||||
|
|
|
|||
520
envoyfilter/src/stream_context.rs
Normal file
520
envoyfilter/src/stream_context.rs
Normal file
|
|
@ -0,0 +1,520 @@
|
|||
use http::StatusCode;
|
||||
use log::error;
|
||||
use log::info;
|
||||
use log::warn;
|
||||
use open_message_format::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
use consts::{
|
||||
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
|
||||
use crate::common_types;
|
||||
use crate::common_types::open_ai::Message;
|
||||
use crate::common_types::SearchPointsResponse;
|
||||
use crate::configuration::EntityDetail;
|
||||
use crate::configuration::EntityType;
|
||||
use crate::configuration::PromptTarget;
|
||||
use crate::consts;
|
||||
|
||||
/// Identifies which upstream callout a response belongs to, so that
/// `on_http_call_response` can route the body to the matching handler.
enum RequestType {
    // Response from the embedding server for the user prompt.
    GetEmbedding,
    // Response from qdrant's vector similarity search.
    SearchPoints,
    // Response from the named-entity-recognition (NER) host.
    Ner,
    // Response from the prompt target's context-resolver endpoint.
    ContextResolver,
}
|
||||
|
||||
/// State carried across an async HTTP callout; stored in
/// `StreamContext::callouts` keyed by the token id returned by
/// `dispatch_http_call`.
pub struct CallContext {
    // Which stage of the pipeline dispatched the callout.
    request_type: RequestType,
    // The last user message extracted from the chat-completions body.
    user_message: String,
    // Matched prompt target; populated after the similarity-search stage.
    prompt_target: Option<PromptTarget>,
    // The original chat-completions request body, augmented as the
    // pipeline progresses.
    request_body: common_types::open_ai::ChatCompletions,
}
|
||||
|
||||
/// Per-HTTP-stream filter state.
pub struct StreamContext {
    // Value of the `:host` header, saved for later filter logic.
    pub host_header: Option<String>,
    // In-flight callouts, keyed by the token id returned by
    // `dispatch_http_call`.
    pub callouts: HashMap<u32, CallContext>,
}
|
||||
|
||||
impl StreamContext {
|
||||
/// Capture the `:host` request header for use by later filter logic.
fn save_host_header(&mut self) {
    // Save the host header to be used by filter logic later on.
    self.host_header = self.get_http_request_header(":host");
}
|
||||
|
||||
/// Drop the Content-Length header before the gateway rewrites the body.
fn delete_content_length_header(&mut self) {
    // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
    // Server's generally throw away requests whose body length do not match the Content-Length header.
    // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
    // manipulate the body in benign ways e.g., compression.
    self.set_http_request_header("content-length", None);
}
|
||||
|
||||
fn modify_path_header(&mut self) {
|
||||
match self.get_http_request_header(":path") {
|
||||
// The gateway can start gathering information necessary for routing. For now change the path to an
|
||||
// OpenAI API path.
|
||||
Some(path) if path == "/llmrouting" => {
|
||||
self.set_http_request_header(":path", Some("/v1/chat/completions"));
|
||||
}
|
||||
// Otherwise let the filter continue.
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing embedding response: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let search_points_request = common_types::SearchPointsRequest {
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
limit: 10,
|
||||
with_payload: true,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&search_points_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(e) => {
|
||||
warn!("Error serializing search_points_request: {:?}", e);
|
||||
self.reset_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME);
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", &path),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
|
||||
}
|
||||
};
|
||||
|
||||
callout_context.request_type = RequestType::SearchPoints;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
|
||||
fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(search_points_response) => search_points_response,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing search_points_response: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let search_results = &search_points_response.result;
|
||||
|
||||
if search_results.is_empty() {
|
||||
info!("No prompt target matched");
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
|
||||
info!("similarity score: {}", search_results[0].score);
|
||||
|
||||
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD {
|
||||
info!(
|
||||
"prompt target below threshold: {}",
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap();
|
||||
let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes())
|
||||
{
|
||||
Ok(prompt_target) => prompt_target,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing prompt_target: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
info!("prompt_target name: {:?}", prompt_target.name);
|
||||
|
||||
// only extract entity names
|
||||
let entity_names = get_entity_details(&prompt_target)
|
||||
.iter()
|
||||
.map(|entity| entity.name.clone())
|
||||
.collect();
|
||||
let user_message = callout_context.user_message.clone();
|
||||
let ner_request = common_types::NERRequest {
|
||||
input: user_message,
|
||||
labels: entity_names,
|
||||
model: DEFAULT_NER_MODEL.to_string(),
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&ner_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(e) => {
|
||||
warn!("Error serializing ner_request: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"nerhost",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/ner"),
|
||||
(":authority", "nerhost"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
|
||||
}
|
||||
};
|
||||
callout_context.request_type = RequestType::Ner;
|
||||
callout_context.prompt_target = Some(prompt_target);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
|
||||
fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let ner_response: common_types::NERResponse = match serde_json::from_slice(&body) {
|
||||
Ok(ner_response) => ner_response,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing ner_response: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
info!("ner_response: {:?}", ner_response);
|
||||
|
||||
let mut request_params: HashMap<String, String> = HashMap::new();
|
||||
for entity in ner_response.data.iter() {
|
||||
if entity.score < DEFAULT_NER_THRESHOLD {
|
||||
warn!(
|
||||
"score of entity was too low entity name: {}, score: {}",
|
||||
entity.label, entity.score
|
||||
);
|
||||
continue;
|
||||
}
|
||||
request_params.insert(entity.label.clone(), entity.text.clone());
|
||||
}
|
||||
|
||||
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
|
||||
let entity_details = get_entity_details(prompt_target);
|
||||
for entity in entity_details {
|
||||
if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) {
|
||||
warn!(
|
||||
"required entity missing or score of entity was too low: {}",
|
||||
entity.name
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let req_param_str = match serde_json::to_string(&request_params) {
|
||||
Ok(req_param_str) => req_param_str,
|
||||
Err(e) => {
|
||||
warn!("Error serializing request_params: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let endpoint = callout_context
|
||||
.prompt_target
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.endpoint
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
|
||||
let http_path = match &endpoint.path {
|
||||
Some(path) => path.clone(),
|
||||
None => "/".to_string(),
|
||||
};
|
||||
|
||||
let http_method = match &endpoint.method {
|
||||
Some(method) => method.clone(),
|
||||
None => http::Method::POST.to_string(),
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
&endpoint.cluster.clone(),
|
||||
vec![
|
||||
(":method", http_method.as_str()),
|
||||
(":path", http_path.as_str()),
|
||||
(":authority", endpoint.cluster.as_str()),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(req_param_str.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for context-resolver: {:?}", e);
|
||||
}
|
||||
};
|
||||
callout_context.request_type = RequestType::ContextResolver;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
|
||||
fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
info!("response received for context-resolver");
|
||||
let body_string = String::from_utf8(body);
|
||||
let prompt_target = callout_context.prompt_target.unwrap();
|
||||
let mut request_body = callout_context.request_body;
|
||||
match prompt_target.system_prompt {
|
||||
None => {}
|
||||
Some(system_prompt) => {
|
||||
let system_prompt_message: Message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt),
|
||||
};
|
||||
request_body.messages.push(system_prompt_message);
|
||||
}
|
||||
}
|
||||
match body_string {
|
||||
Ok(body_string) => {
|
||||
info!("context-resolver response: {}", body_string);
|
||||
let context_resolver_response = Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(body_string),
|
||||
};
|
||||
request_body.messages.push(context_resolver_response);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error converting response to string: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let json_string = match serde_json::to_string(&request_body) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
warn!("Error serializing request_body: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
info!("sending request to openai: msg len: {}", json_string.len());
|
||||
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
impl HttpContext for StreamContext {
|
||||
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
|
||||
// the lifecycle of the http request and response.
|
||||
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
self.save_host_header();
|
||||
self.delete_content_length_header();
|
||||
self.modify_path_header();
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
// Let the client send the gateway all the data before sending to the LLM_provider.
|
||||
// TODO: consider a streaming API.
|
||||
if !end_of_stream {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
if body_size == 0 {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let deserialized_body: common_types::open_ai::ChatCompletions =
|
||||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let user_message = match deserialized_body
|
||||
.messages
|
||||
.last()
|
||||
.and_then(|last_message| last_message.content.as_ref())
|
||||
{
|
||||
Some(content) => content,
|
||||
None => {
|
||||
info!("No messages in the request body");
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
let get_embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
let call_context = CallContext {
|
||||
request_type: RequestType::GetEmbedding,
|
||||
user_message: user_message.clone(),
|
||||
prompt_target: None,
|
||||
request_body: deserialized_body,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
|
||||
Action::Pause
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for StreamContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_context = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
let resp = self.get_http_call_response_body(0, body_size);
|
||||
|
||||
if resp.is_none() {
|
||||
warn!("No response body");
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
|
||||
let body = match resp {
|
||||
Some(body) => body,
|
||||
None => {
|
||||
warn!("Empty response body");
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match callout_context.request_type {
|
||||
RequestType::GetEmbedding => {
|
||||
self.embeddings_handler(body, callout_context);
|
||||
}
|
||||
|
||||
RequestType::SearchPoints => {
|
||||
self.search_points_handler(body, callout_context);
|
||||
}
|
||||
RequestType::Ner => {
|
||||
self.ner_handler(body, callout_context);
|
||||
}
|
||||
RequestType::ContextResolver => {
|
||||
self.context_resolver_handler(body, callout_context);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_entity_details(prompt_target: &PromptTarget) -> Vec<EntityDetail> {
|
||||
match prompt_target.entities.as_ref() {
|
||||
Some(EntityType::Vec(entity_names)) => {
|
||||
let mut entity_details: Vec<EntityDetail> = Vec::new();
|
||||
for entity_name in entity_names {
|
||||
entity_details.push(EntityDetail {
|
||||
name: entity_name.clone(),
|
||||
required: Some(true),
|
||||
description: None,
|
||||
});
|
||||
}
|
||||
entity_details
|
||||
}
|
||||
Some(EntityType::Struct(entity_details)) => entity_details.clone(),
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
|
@ -90,7 +90,7 @@ fn successful_request_to_open_ai_chat_completions() {
|
|||
.returning(Some(chat_completions_request_body))
|
||||
// TODO: assert that the model field was added.
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,17 +13,17 @@
|
|||
"path": "embedding-server"
|
||||
},
|
||||
{
|
||||
"name": "chatbot-client",
|
||||
"path": "chatbot-client"
|
||||
"name": "chatbot-ui",
|
||||
"path": "chatbot-ui"
|
||||
},
|
||||
{
|
||||
"name": "open-message-format",
|
||||
"path": "open-message-format"
|
||||
},
|
||||
{
|
||||
"name": "demos",
|
||||
"path": "./demos"
|
||||
},
|
||||
"name": "demos/weather-forecast",
|
||||
"path": "./demos/weather-forecast",
|
||||
}
|
||||
],
|
||||
"settings": {}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue