diff --git a/.gitignore b/.gitignore index f67b2130..317c0d11 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,9 @@ envoyfilter/target envoyfilter/qdrant_data/ embedding-server/venv/ +chatbot-ui/venv/ +__pycache__ +grafana-data +prom_data +.env +qdrant_data diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d23b654f..c6548745 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,11 +16,11 @@ repos: name: cargo-clippy language: system types: [file, rust] - entry: bash -c "cd envoyfilter && cargo clippy --all" + entry: bash -c "cd envoyfilter && cargo clippy -p intelligent-prompt-gateway --all" - id: cargo-test name: cargo-test language: system types: [file, rust] # --lib is to only test the library, since when integration tests are made, # they will be in a seperate tests directory - entry: bash -c "cd envoyfilter && cargo test --lib" + entry: bash -c "cd envoyfilter && cargo test -p intelligent-prompt-gateway --lib" diff --git a/chatbot-ui/Dockerfile b/chatbot-ui/Dockerfile new file mode 100644 index 00000000..f05aad18 --- /dev/null +++ b/chatbot-ui/Dockerfile @@ -0,0 +1,24 @@ +# copied from https://github.com/bergos/embedding-server + +FROM python:3 AS base + +# +# builder +# +FROM base AS builder + +WORKDIR /src + +COPY requirements.txt /src/ +RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt + +COPY . /src + +FROM python:3-slim AS output + +COPY --from=builder /runtime /usr/local + +COPY /app /app +WORKDIR /app + +CMD ["python", "run.py"] diff --git a/chatbot-ui/app/run.py b/chatbot-ui/app/run.py new file mode 100644 index 00000000..6dc8af9d --- /dev/null +++ b/chatbot-ui/app/run.py @@ -0,0 +1,81 @@ +import gradio as gr + +import asyncio +import httpx +import async_timeout + +from loguru import logger +from typing import Optional, List +from pydantic import BaseModel +from dotenv import load_dotenv + +import os +load_dotenv() + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1/chat/completions") + +class Message(BaseModel): + role: str + content: str + +async def make_completion(messages:List[Message], nb_retries:int=3, delay:int=30) -> Optional[str]: + """ + Sends a request to the ChatGPT API to retrieve a response based on a list of previous messages. + """ + header = { + "Content-Type": "application/json", + "Authorization": f"Bearer {OPENAI_API_KEY}" + } + try: + async with async_timeout.timeout(delay=delay): + async with httpx.AsyncClient(headers=header) as aio_client: + counter = 0 + keep_loop = True + while keep_loop: + logger.debug(f"Chat/Completions Nb Retries : {counter}") + try: + resp = await aio_client.post( + url = CHAT_COMPLETION_ENDPOINT, + json = { + "model": "gpt-3.5-turbo", + "messages": messages + } + ) + logger.debug(f"Status Code : {resp.status_code}") + if resp.status_code == 200: + return resp.json()["choices"][0]["message"]["content"] + else: + logger.warning(resp.content) + keep_loop = False + except Exception as e: + logger.error(e) + counter = counter + 1 + keep_loop = counter < nb_retries + except asyncio.TimeoutError as e: + logger.error(f"Timeout {delay} seconds !") + return None + +async def predict(input, history): + """ + Predict the response of the chatbot and complete a running list of chat history. 
+ """ + history.append({"role": "user", "content": input}) + print(history) + response = await make_completion(history) + history.append({"role": "assistant", "content": response}) + messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)] + return messages, history + +""" +Gradio Blocks low-level API that allows to create custom web applications (here our chat app) +""" +with gr.Blocks() as demo: + logger.info("Starting Demo...") + chatbot = gr.Chatbot(label="WebGPT") + state = gr.State([]) + with gr.Row(): + txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter") + txt.submit(predict, [txt, state], [chatbot, state]) + +demo.launch(server_name="0.0.0.0", server_port=8080) diff --git a/chatbot-ui/requirements.txt b/chatbot-ui/requirements.txt new file mode 100644 index 00000000..732b77c8 --- /dev/null +++ b/chatbot-ui/requirements.txt @@ -0,0 +1,6 @@ +gradio==4.39.0 +async_timeout==4.0.3 +loguru==0.7.2 +asyncio==3.4.3 +httpx==0.27.0 +python-dotenv==1.0.1 diff --git a/demos/weather-forecast/README.md b/demos/weather-forecast/README.md new file mode 100644 index 00000000..e0127bf1 --- /dev/null +++ b/demos/weather-forecast/README.md @@ -0,0 +1,15 @@ +# Weather forecasting +This demo shows how you can use intelligent prompt gateway to provide realtime weather forecast. + +# Startig the demo +1. Create `.env` file and set OpenAI key using env var `OPENAI_API_KEY` +1. Start services + ```sh + $ docker compose up + ``` +1. Navigate to http://localhost:18080/ +1. You can type in queries like "how is the weather in Seattle" + 1. You can also ask follow up questions like "show me sunny days" +2. To see metrics navigate to "http://localhost:3000/" (use admin/grafana for login) + 1. Open up dahsboard named "Intelligent Gateway Overview" + 2. 
diff --git a/demos/weather-forecast/docker-compose.yaml b/demos/weather-forecast/docker-compose.yaml
new file mode 100644
index 00000000..e4e64350
--- /dev/null
+++ b/demos/weather-forecast/docker-compose.yaml
@@ -0,0 +1,85 @@
+services:
+  envoy:
+    build:
+      context: ../../
+      dockerfile: envoyfilter/Dockerfile
+    hostname: envoy
+    ports:
+      - "10000:10000"
+      - "19901:9901"
+    volumes:
+      - ./envoy.yaml:/etc/envoy/envoy.yaml
+      - /etc/ssl/cert.pem:/etc/ssl/cert.pem
+    networks:
+      - envoymesh
+    depends_on:
+      embeddingserver:
+        condition: service_healthy
+
+  embeddingserver:
+    build:
+      context: ../../embedding-server
+      dockerfile: Dockerfile
+    ports:
+      - "18081:80"
+    healthcheck:
+      test: ["CMD", "curl" ,"http://localhost:80/healthz"]
+      interval: 5s
+      retries: 20
+    networks:
+      - envoymesh
+
+  qdrant:
+    image: qdrant/qdrant
+    hostname: vector-db
+    ports:
+      - 16333:6333
+      - 16334:6334
+    networks:
+      - envoymesh
+
+  chatbot-ui:
+    build:
+      context: ../../chatbot-ui
+      dockerfile: Dockerfile
+    ports:
+      - "18080:8080"
+    networks:
+      - envoymesh
+    environment:
+      - OPENAI_API_KEY=${OPENAI_API_KEY}
+      - CHAT_COMPLETION_ENDPOINT=http://envoy:10000/v1/chat/completions
+
+  prometheus:
+    image: prom/prometheus
+    container_name: prometheus
+    command:
+      - '--config.file=/etc/prometheus/prometheus.yaml'
+    ports:
+      - 9090:9090
+    restart: unless-stopped
+    volumes:
+      - ./prometheus:/etc/prometheus
+      - ./prom_data:/prometheus
+    networks:
+      - envoymesh
+
+  grafana:
+    image: grafana/grafana
+    container_name: grafana
+    ports:
+      - 3000:3000
+    restart: unless-stopped
+    environment:
+      - GF_SECURITY_ADMIN_USER=admin
+      - GF_SECURITY_ADMIN_PASSWORD=grafana
+    volumes:
+      - ./grafana:/etc/grafana/provisioning/datasources
+      - ./grafana/dashboard.yaml:/etc/grafana/provisioning/dashboards/main.yaml
+      - ./grafana/dashboards:/var/lib/grafana/dashboards
+      # - ./grafana-data:/var/lib/grafana
+    networks:
+      - envoymesh
+
+networks:
+  envoymesh: {}
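Before opening the UI, it can help to confirm the published ports respond. The following is a hypothetical smoke test: the URLs are derived from the host port mappings above, and the health paths are assumptions based on each service's conventional endpoint, not something this diff defines.

```python
# Hypothetical smoke test for the demo stack; ports come from the compose file above.
import httpx

CHECKS = {
    "chatbot-ui": "http://localhost:18080/",
    "embedding server": "http://localhost:18081/healthz",
    "qdrant": "http://localhost:16333/healthz",
    "prometheus": "http://localhost:9090/-/healthy",
    "grafana": "http://localhost:3000/api/health",
}

for name, url in CHECKS.items():
    try:
        print(f"{name}: HTTP {httpx.get(url, timeout=5.0).status_code}")
    except httpx.HTTPError as exc:
        print(f"{name}: unreachable ({exc})")
```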
diff --git a/demos/weather-forecast/envoy.yaml b/demos/weather-forecast/envoy.yaml
new file mode 100644
index 00000000..3da846ea
--- /dev/null
+++ b/demos/weather-forecast/envoy.yaml
@@ -0,0 +1,197 @@
+admin:
+  address:
+    socket_address: { address: 0.0.0.0, port_value: 9901 }
+static_resources:
+  listeners:
+  - address:
+      socket_address:
+        address: 0.0.0.0
+        port_value: 10000
+    filter_chains:
+    - filters:
+      - name: envoy.filters.network.http_connection_manager
+        typed_config:
+          "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
+          stat_prefix: ingress_http
+          codec_type: AUTO
+          scheme_header_transformation:
+            scheme_to_overwrite: https
+          route_config:
+            name: local_routes
+            virtual_hosts:
+            - name: openai
+              domains:
+              - "api.openai.com"
+              routes:
+              - match:
+                  prefix: "/"
+                route:
+                  auto_host_rewrite: true
+                  cluster: openai
+            - name: local_service
+              domains:
+              - "*"
+              routes:
+              - match:
+                  prefix: "/v1/chat/completions"
+                route:
+                  auto_host_rewrite: true
+                  cluster: openai
+              - match:
+                  prefix: "/embeddings"
+                route:
+                  cluster: embeddingserver
+              - match:
+                  prefix: "/"
+                direct_response:
+                  status: 200
+                  body:
+                    inline_string: "Inspect the HTTP header: custom-header.\n"
+          http_filters:
+          - name: envoy.filters.http.wasm
+            typed_config:
+              "@type": type.googleapis.com/udpa.type.v1.TypedStruct
+              type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
+              value:
+                config:
+                  name: "http_config"
+                  configuration:
+                    "@type": "type.googleapis.com/google.protobuf.StringValue"
+                    value: |
+                      katanemo-prompt-config:
+                        default-prompt-endpoint: "127.0.0.1"
+                        load-balancing: "round-robin"
+                        timeout-ms: 5000
+
+                        embedding-provider:
+                          name: "SentenceTransformer"
+                          model: "all-MiniLM-L6-v2"
+
+                        llm-providers:
+
+                          - name: "open-ai-gpt-4"
+                            api-key: "$OPEN_AI_API_KEY"
+                            model: gpt-4
+
+                        prompt-targets:
+
+                          - type: context-resolver
+                            name: weather-forecast
+                            few-shot-examples:
+                              - what is the weather in New York?
+                              - how is the weather in San Francisco?
+                              - what is the forecast in Seattle?
+                            entities:
+                              - name: city
+                                required: true
+                              - name: days
+                            endpoint:
+                              cluster: weatherhost
+                              path: /weather
+                            cache-response: true
+                            cache-response-settings:
+                              - cache-ttl-secs: 3600 # cache expiry in seconds
+                              - cache-max-size: 1000 # in number of items
+                              - cache-eviction-strategy: LRU
+                            system-prompt: |
+                              You are a helpful weather forecaster. Use weather data that is provided to you. Please follow these guidelines when responding to user queries:
+                              - Use Fahrenheit for temperature
+                              - Use miles per hour for wind speed
+
+                  vm_config:
+                    runtime: "envoy.wasm.runtime.v8"
+                    code:
+                      local:
+                        filename: "/etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm"
+          - name: envoy.filters.http.router
+            typed_config:
+              "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
+  clusters:
+  # LLM Host
+  # Embedding Providers
+  # External LLM Providers
+  - name: openai
+    connect_timeout: 5s
+    type: LOGICAL_DNS
+    lb_policy: ROUND_ROBIN
+    typed_extension_protocol_options:
+      envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
+        "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
+        explicit_http_config:
+          http2_protocol_options: {}
+    load_assignment:
+      cluster_name: openai
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: api.openai.com
+                port_value: 443
+            hostname: "api.openai.com"
+    transport_socket:
+      name: envoy.transport_sockets.tls
+      typed_config:
+        "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
+        sni: api.openai.com
+        common_tls_context:
+          tls_params:
+            tls_minimum_protocol_version: TLSv1_2
+            tls_maximum_protocol_version: TLSv1_3
+
+  - name: embeddingserver
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: embeddingserver
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: embeddingserver
+                port_value: 80
+            hostname: "embeddingserver"
+  - name: weatherhost
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: weatherhost
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: embeddingserver
+                port_value: 80
+            hostname: "embeddingserver"
+  - name: nerhost
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: nerhost
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: embeddingserver
+                port_value: 80
+            hostname: "embeddingserver"
+  - name: qdrant
+    connect_timeout: 5s
+    type: STRICT_DNS
+    lb_policy: ROUND_ROBIN
+    load_assignment:
+      cluster_name: qdrant
+      endpoints:
+      - lb_endpoints:
+        - endpoint:
+            address:
+              socket_address:
+                address: qdrant
+                port_value: 6333
+            hostname: "qdrant"
diff --git a/demos/weather-forecast/grafana/dashboard.yaml b/demos/weather-forecast/grafana/dashboard.yaml
new file mode 100644
index 00000000..fd66a479
--- /dev/null
+++ b/demos/weather-forecast/grafana/dashboard.yaml
@@ -0,0 +1,12 @@
+apiVersion: 1 + +providers: + - name: "Dashboard provider" + orgId: 1 + type: file + disableDeletion: false + updateIntervalSeconds: 10 + allowUiUpdates: false + options: + path: /var/lib/grafana/dashboards + foldersFromFilesStructure: true diff --git a/demos/weather-forecast/grafana/dashboards/envoy_overview.json b/demos/weather-forecast/grafana/dashboards/envoy_overview.json new file mode 100644 index 00000000..51bff777 --- /dev/null +++ b/demos/weather-forecast/grafana/dashboards/envoy_overview.json @@ -0,0 +1,355 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 1, + "links": [], + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, + "id": 2, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "code", + "expr": "avg(rate(envoy_cluster_internal_upstream_rq_time_sum[1m]) / rate(envoy_cluster_internal_upstream_rq_time_count[1m])) by (envoy_cluster_name)", + "fullMetaSearch": false, + "hide": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "request latency - internal (ms)", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + 
"value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 0 + }, + "id": 1, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "code", + "expr": "avg(rate(envoy_cluster_external_upstream_rq_time_sum[1m]) / rate(envoy_cluster_external_upstream_rq_time_count[1m])) by (envoy_cluster_name)", + "fullMetaSearch": false, + "hide": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "request latency - external (ms)", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "id": 3, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "code", + "expr": "avg(rate(envoy_cluster_internal_upstream_rq_completed[1m])) by (envoy_cluster_name)", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "code", + "expr": "avg(rate(envoy_cluster_external_upstream_rq_completed[1m])) by (envoy_cluster_name)", + "fullMetaSearch": false, + "hide": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "__auto", + "range": true, + "refId": "B", + "useBackend": false + } + ], + "title": "Upstream request count", + "type": "timeseries" + } + ], + "schemaVersion": 39, + "tags": [], + "templating": { + "list": [] + }, + "time": { + "from": "now-15m", + "to": "now" + }, + "timepicker": {}, + "timezone": "browser", + "title": "Intelligent Gateway Overview", + "uid": "adt6uhx5lk8aob", + "version": 3, + "weekStart": "" +} diff --git a/demos/weather-forecast/grafana/datasource.yaml b/demos/weather-forecast/grafana/datasource.yaml new file mode 100644 index 00000000..4870174e --- /dev/null +++ b/demos/weather-forecast/grafana/datasource.yaml @@ -0,0 +1,9 @@ +apiVersion: 1 + +datasources: +- name: Prometheus + type: prometheus + 
url: http://prometheus:9090 + isDefault: true + access: proxy + editable: true diff --git a/demos/weather-forecast/prometheus/prometheus.yaml b/demos/weather-forecast/prometheus/prometheus.yaml new file mode 100644 index 00000000..5aa25e0d --- /dev/null +++ b/demos/weather-forecast/prometheus/prometheus.yaml @@ -0,0 +1,23 @@ +global: + scrape_interval: 15s + scrape_timeout: 10s + evaluation_interval: 15s +alerting: + alertmanagers: + - static_configs: + - targets: [] + scheme: http + timeout: 10s + api_version: v1 +scrape_configs: +- job_name: envoy + honor_timestamps: true + scrape_interval: 15s + scrape_timeout: 10s + metrics_path: /stats + scheme: http + static_configs: + - targets: + - envoy:9901 + params: + format: ['prometheus'] diff --git a/embedding-server/Dockerfile b/embedding-server/Dockerfile index 0ec28ba7..85c0018f 100644 --- a/embedding-server/Dockerfile +++ b/embedding-server/Dockerfile @@ -24,6 +24,7 @@ FROM python:3-slim AS output # following models have been tested to work with this image # "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small" ENV MODELS="BAAI/bge-large-en-v1.5" +ENV NER_MODELS="urchade/gliner_large-v2.1" COPY --from=builder /runtime /usr/local diff --git a/embedding-server/app/install.py b/embedding-server/app/install.py index 15cacc91..ad6ecb10 100644 --- a/embedding-server/app/install.py +++ b/embedding-server/app/install.py @@ -1,3 +1,6 @@ -from load_transformers import load_transformers +from load_models import load_transformers, load_ner_models +print('installing transformers') load_transformers() +print('installing ner models') +load_ner_models() diff --git a/embedding-server/app/load_transformers.py b/embedding-server/app/load_models.py similarity index 54% rename from embedding-server/app/load_transformers.py rename to embedding-server/app/load_models.py index 052f7e0e..1e8bc7cf 100644 --- a/embedding-server/app/load_transformers.py +++ b/embedding-server/app/load_models.py @@ -1,5 +1,6 @@ import os import sentence_transformers +from gliner import GLiNER def load_transformers(models = os.getenv("MODELS", "sentence-transformers/all-MiniLM-L6-v2")): transformers = {} @@ -8,3 +9,11 @@ def load_transformers(models = os.getenv("MODELS", "sentence-transformers/all-Mi transformers[model] = sentence_transformers.SentenceTransformer(model) return transformers + +def load_ner_models(models = os.getenv("NER_MODELS", "urchade/gliner_large-v2.1")): + ner_models = {} + + for model in models.split(','): + ner_models[model] = GLiNER.from_pretrained(model) + + return ner_models diff --git a/embedding-server/app/main.py b/embedding-server/app/main.py index 05843a45..1328872f 100644 --- a/embedding-server/app/main.py +++ b/embedding-server/app/main.py @@ -1,8 +1,11 @@ +import random from fastapi import FastAPI, Response, HTTPException from pydantic import BaseModel -from load_transformers import load_transformers +from load_models import load_ner_models, load_transformers +from datetime import date, timedelta transformers = load_transformers() +ner_models = load_ner_models() app = FastAPI() @@ -10,6 +13,12 @@ class EmbeddingRequest(BaseModel): input: str model: str +@app.get("/healthz") +async def healthz(): + return { + "status": "ok" + } + @app.get("/models") async def models(): models = [] @@ -27,7 +36,7 @@ async def models(): @app.post("/embeddings") async def embedding(req: EmbeddingRequest, res: Response): - if not req.model in transformers: + if req.model not in transformers: 
raise HTTPException(status_code=400, detail="unknown model: " + req.model) embeddings = transformers[req.model].encode([req.input]) @@ -51,3 +60,48 @@ async def embedding(req: EmbeddingRequest, res: Response): "object": "list", "usage": usage } + +class NERRequest(BaseModel): + input: str + labels: list[str] + model: str + + +@app.post("/ner") +async def ner(req: NERRequest, res: Response): + if req.model not in ner_models: + raise HTTPException(status_code=400, detail="unknown model: " + req.model) + + model = ner_models[req.model] + entities = model.predict_entities(req.input, req.labels) + + return { + "data": entities, + "model": req.model, + "object": "list", + } + +class WeatherRequest(BaseModel): + city: str + + +@app.post("/weather") +async def weather(req: WeatherRequest, res: Response): + + weather_forecast = { + "city": req.city, + "temperature": [], + "unit": "F", + } + for i in range(7): + min_temp = random.randrange(50,90) + max_temp = random.randrange(min_temp+5, min_temp+20) + weather_forecast["temperature"].append({ + "date": str(date.today() + timedelta(days=i)), + "temperature": { + "min": min_temp, + "max": max_temp + } + }) + + return weather_forecast diff --git a/embedding-server/requirements.txt b/embedding-server/requirements.txt index 613aa60e..231aca9c 100644 --- a/embedding-server/requirements.txt +++ b/embedding-server/requirements.txt @@ -3,3 +3,4 @@ fastapi sentence-transformers torch uvicorn +gliner diff --git a/envoyfilter/Cargo.lock b/envoyfilter/Cargo.lock index 93c9e116..fcfbea66 100644 --- a/envoyfilter/Cargo.lock +++ b/envoyfilter/Cargo.lock @@ -1621,9 +1621,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" dependencies = [ "serde", ] @@ -1936,9 +1936,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.15" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac2caab0bf757388c6c0ae23b3293fdb463fee59434529014f85e3263b995c28" +checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c" dependencies = [ "serde", "serde_spanned", @@ -1948,18 +1948,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.16" +version = "0.22.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "278f3d518e152219c994ce877758516bca5e118eaed6996192a774fb9fbf0788" +checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16" dependencies = [ "indexmap", "serde", @@ -2121,9 +2121,9 @@ checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "want" @@ -2730,9 +2730,9 @@ checksum = 
"589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.15" +version = "0.6.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "557404e450152cd6795bb558bca69e43c585055f4606e3bcae5894fc6dac9ba0" +checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c" dependencies = [ "memchr", ] diff --git a/envoyfilter/Dockerfile b/envoyfilter/Dockerfile new file mode 100644 index 00000000..2fa3a064 --- /dev/null +++ b/envoyfilter/Dockerfile @@ -0,0 +1,16 @@ +# build filter using rust toolchain +FROM rust:1.80.0 as builder +WORKDIR /envoyfilter +COPY envoyfilter/src /envoyfilter/src +COPY envoyfilter/Cargo.toml /envoyfilter/ +COPY envoyfilter/Cargo.lock /envoyfilter/ +COPY open-message-format /open-message-format + +RUN rustup -v target add wasm32-wasi +RUN cargo build --release --target wasm32-wasi + +# copy built filter into envoy image +FROM envoyproxy/envoy:v1.30-latest +COPY --from=builder /envoyfilter/target/wasm32-wasi/release/intelligent_prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/intelligent_prompt_gateway.wasm +COPY envoyfilter/envoy.yaml /etc/envoy.yaml +CMD ["envoy", "-c", "/etc/envoy/envoy.yaml"] diff --git a/envoyfilter/build_filter.sh b/envoyfilter/build_filter.sh new file mode 100644 index 00000000..ff42dede --- /dev/null +++ b/envoyfilter/build_filter.sh @@ -0,0 +1,3 @@ +RUST_VERSION=1.80.0 +docker run --rm -v rustup_cache:/usr/local/rustup/ rust:$RUST_VERSION rustup -v target add wasm32-wasi +docker run --rm -v $PWD/../open-message-format:/code/open-message-format -v ~/.cargo:/root/.cargo -v $(pwd):/code/envoyfilter -w /code/envoyfilter -v rustup_cache:/usr/local/rustup/ rust:$RUST_VERSION cargo build --release --target wasm32-wasi diff --git a/envoyfilter/docker-compose.yaml b/envoyfilter/docker-compose.yaml index 426ac4d5..0198a883 100644 --- a/envoyfilter/docker-compose.yaml +++ b/envoyfilter/docker-compose.yaml @@ -22,7 +22,7 @@ services: ports: - "18080:80" healthcheck: - test: ["CMD", "curl" ,"http://localhost:80"] + test: ["CMD", "curl" ,"http://localhost:80/healthz"] interval: 5s retries: 20 networks: diff --git a/envoyfilter/envoy.yaml b/envoyfilter/envoy.yaml index 65ce61eb..3e2a7d11 100644 --- a/envoyfilter/envoy.yaml +++ b/envoyfilter/envoy.yaml @@ -41,10 +41,6 @@ static_resources: prefix: "/embeddings" route: cluster: embeddingserver - - match: - prefix: "/inline" - route: - cluster: httpbin - match: prefix: "/" direct_response: @@ -78,7 +74,7 @@ static_resources: model: gpt-4 system-prompt: | - You are a helpful weather forecaster. Please following following guidelines when responding to user queries: + You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries: - Use farenheight for temperature - Use miles per hour for wind speed @@ -88,12 +84,22 @@ static_resources: name: weather-forecast few-shot-examples: - what is the weather in New York? - endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT" + - how is the weather in San Francisco? + - what is the forecast in Seattle? + entities: + - city + endpoint: + cluster: weatherhost + path: /weather cache-response: true cache-response-settings: - cache-ttl-secs: 3600 # cache expiry in seconds - cache-max-size: 1000 # in number of items - cache-eviction-strategy: LRU + system-prompt: | + You are a helpful weather forecaster. Use weater data that is provided to you. 
diff --git a/envoyfilter/grafana/datasource.yaml b/envoyfilter/grafana/datasource.yaml
new file mode 100644
index 00000000..4870174e
--- /dev/null
+++ b/envoyfilter/grafana/datasource.yaml
@@ -0,0 +1,9 @@
+apiVersion: 1
+
+datasources:
+- name: Prometheus
+  type: prometheus
+  url: http://prometheus:9090
+  isDefault: true
+  access: proxy
+  editable: true
diff --git a/envoyfilter/prometheus/prometheus.yaml b/envoyfilter/prometheus/prometheus.yaml
new file mode 100644
index 00000000..5aa25e0d
--- /dev/null
+++ b/envoyfilter/prometheus/prometheus.yaml
@@ -0,0 +1,23 @@
+global:
+  scrape_interval: 15s
+  scrape_timeout: 10s
+  evaluation_interval: 15s
+alerting:
+  alertmanagers:
+  - static_configs:
+    - targets: []
+    scheme: http
+    timeout: 10s
+    api_version: v1
+scrape_configs:
+- job_name: envoy
+  honor_timestamps: true
+  scrape_interval: 15s
+  scrape_timeout: 10s
+  metrics_path: /stats
+  scheme: http
+  static_configs:
+  - targets:
+    - envoy:9901
+  params:
+    format: ['prometheus']
diff --git a/envoyfilter/src/common_types.rs b/envoyfilter/src/common_types.rs
index 4ee8f7fa..3d82eca4 100644
--- a/envoyfilter/src/common_types.rs
+++ b/envoyfilter/src/common_types.rs
@@ -23,9 +23,53 @@ pub struct StoreVectorEmbeddingsRequest {
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
+#[allow(clippy::large_enum_variant)]
 pub enum CallContext {
     EmbeddingRequest(EmbeddingRequest),
     StoreVectorEmbeddings(StoreVectorEmbeddingsRequest),
+    CreateVectorCollection(String),
+}
+
+// https://api.qdrant.tech/master/api-reference/search/points
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SearchPointsRequest {
+    pub vector: Vec<f64>,
+    pub limit: i32,
+    pub with_payload: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SearchPointResult {
+    pub id: String,
+    pub version: i32,
+    pub score: f64,
+    pub payload: HashMap<String, String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SearchPointsResponse {
+    pub result: Vec<SearchPointResult>,
+    pub status: String,
+    pub time: f64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct NERRequest {
+    pub input: String,
+    pub labels: Vec<String>,
+    pub model: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct 
Entity { + pub text: String, + pub label: String, + pub score: f64, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NERResponse { + pub data: Vec, + pub model: String, } pub mod open_ai { @@ -41,6 +85,6 @@ pub mod open_ai { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { pub role: String, - pub content: String, + pub content: Option, } } diff --git a/envoyfilter/src/configuration.rs b/envoyfilter/src/configuration.rs index e01162f3..ea4c5f4b 100644 --- a/envoyfilter/src/configuration.rs +++ b/envoyfilter/src/configuration.rs @@ -28,7 +28,7 @@ pub struct PromptConfig { pub timeout_ms: u64, pub embedding_provider: EmbeddingProviver, pub llm_providers: Vec, - pub system_prompt: String, + pub system_prompt: Option, pub prompt_targets: Vec, } @@ -49,6 +49,29 @@ pub struct LlmProvider { pub model: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct Endpoint { + pub cluster: String, + pub path: Option, + pub method: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct EntityDetail { + pub name: String, + pub required: Option, + pub description: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum EntityType { + Vec(Vec), + Struct(Vec), +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct PromptTarget { @@ -56,7 +79,9 @@ pub struct PromptTarget { pub prompt_type: String, pub name: String, pub few_shot_examples: Vec, - pub endpoint: String, + pub entities: Option, + pub endpoint: Option, + pub system_prompt: Option, } #[cfg(test)] @@ -88,13 +113,23 @@ katanemo-prompt-config: name: weather-forecast few-shot-examples: - what is the weather in New York? - endpoint: "POST:$WEATHER_FORECAST_API_ENDPOINT" - cache-response: true - cache-response-settings: - - cache-ttl-secs: 3600 # cache expiry in seconds - - cache-max-size: 1000 # in number of items - - cache-eviction-strategy: LRU + endpoint: + cluster: weatherhost + path: /weather + entities: + - name: location + required: true + description: "The location for which the weather is requested" + - type: context-resolver + name: weather-forecast-2 + few-shot-examples: + - what is the weather in New York? 
+ endpoint: + cluster: weatherhost + path: /weather + entities: + - city "#; #[test] diff --git a/envoyfilter/src/consts.rs b/envoyfilter/src/consts.rs index a403fbb8..0f844023 100644 --- a/envoyfilter/src/consts.rs +++ b/envoyfilter/src/consts.rs @@ -1 +1,7 @@ pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5"; +pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store"; +pub const DEFAULT_NER_MODEL: &str = "urchade/gliner_large-v2.1"; +pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6; +pub const DEFAULT_NER_THRESHOLD: f64 = 0.6; +pub const SYSTEM_ROLE: &str = "system"; +pub const USER_ROLE: &str = "user"; diff --git a/envoyfilter/src/filter_context.rs b/envoyfilter/src/filter_context.rs new file mode 100644 index 00000000..21654a8e --- /dev/null +++ b/envoyfilter/src/filter_context.rs @@ -0,0 +1,289 @@ +use common_types::{CallContext, EmbeddingRequest}; +use configuration::PromptTarget; +use log::info; +use md5::Digest; +use open_message_format::models::{ + CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, +}; +use serde_json::to_string; +use stats::RecordingMetric; +use std::collections::HashMap; +use std::time::Duration; + +use consts::DEFAULT_EMBEDDING_MODEL; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; + +use crate::common_types; +use crate::configuration; +use crate::consts; +use crate::stats; +use crate::stream_context::StreamContext; + +#[derive(Copy, Clone)] +struct WasmMetrics { + active_http_calls: stats::Gauge, +} + +impl WasmMetrics { + fn new() -> WasmMetrics { + WasmMetrics { + active_http_calls: stats::Gauge::new(String::from("active_http_calls")), + } + } +} + +pub struct FilterContext { + metrics: WasmMetrics, + // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. 
+ callouts: HashMap, + config: Option, +} + +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { + callouts: HashMap::new(), + config: None, + metrics: WasmMetrics::new(), + } + } + + fn process_prompt_targets(&mut self) { + for prompt_target in &self.config.as_ref().unwrap().prompt_config.prompt_targets { + for few_shot_example in &prompt_target.few_shot_examples { + let embeddings_input = CreateEmbeddingRequest { + input: Box::new(CreateEmbeddingRequestInput::String( + few_shot_example.to_string(), + )), + model: String::from(DEFAULT_EMBEDDING_MODEL), + encoding_format: None, + dimensions: None, + user: None, + }; + + // TODO: Handle potential errors + let json_data: String = to_string(&embeddings_input).unwrap(); + + let token_id = match self.dispatch_http_call( + "embeddingserver", + vec![ + (":method", "POST"), + (":path", "/embeddings"), + (":authority", "embeddingserver"), + ("content-type", "application/json"), + ], + Some(json_data.as_bytes()), + vec![], + Duration::from_secs(5), + ) { + Ok(token_id) => token_id, + Err(e) => { + panic!("Error dispatching HTTP call: {:?}", e); + } + }; + let embedding_request = EmbeddingRequest { + create_embedding_request: embeddings_input, + prompt_target: prompt_target.clone(), + }; + if self + .callouts + .insert(token_id, { + CallContext::EmbeddingRequest(embedding_request) + }) + .is_some() + { + panic!("duplicate token_id") + } + self.metrics + .active_http_calls + .record(self.callouts.len().try_into().unwrap()); + } + } + } + + fn embedding_request_handler( + &mut self, + body_size: usize, + create_embedding_request: CreateEmbeddingRequest, + prompt_target: PromptTarget, + ) { + if let Some(body) = self.get_http_call_response_body(0, body_size) { + if !body.is_empty() { + let embedding_response: CreateEmbeddingResponse = + serde_json::from_slice(&body).unwrap(); + + let mut payload: HashMap = HashMap::new(); + payload.insert( + "prompt-target".to_string(), + to_string(&prompt_target).unwrap(), + ); + let id: Option; + match *create_embedding_request.input { + CreateEmbeddingRequestInput::String(input) => { + id = Some(md5::compute(&input)); + payload.insert("input".to_string(), input); + } + CreateEmbeddingRequestInput::Array(_) => todo!(), + } + + let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest { + points: vec![common_types::VectorPoint { + id: format!("{:x}", id.unwrap()), + payload, + vector: embedding_response.data[0].embedding.clone(), + }], + }; + let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors + let token_id = match self.dispatch_http_call( + "qdrant", + vec![ + (":method", "PUT"), + (":path", "/collections/prompt_vector_store/points"), + (":authority", "qdrant"), + ("content-type", "application/json"), + ], + Some(json_data.as_bytes()), + vec![], + Duration::from_secs(5), + ) { + Ok(token_id) => token_id, + Err(e) => { + panic!("Error dispatching HTTP call: {:?}", e); + } + }; + + if self + .callouts + .insert( + token_id, + CallContext::StoreVectorEmbeddings(create_vector_store_points), + ) + .is_some() + { + panic!("duplicate token_id") + } + self.metrics + .active_http_calls + .record(self.callouts.len().try_into().unwrap()); + } + } + } + + fn create_vector_store_points_handler(&self, body_size: usize) { + if let Some(body) = self.get_http_call_response_body(0, body_size) { + if !body.is_empty() { + info!( + "response body: len {:?}", + String::from_utf8(body).unwrap().len() + ); + } + } + } + + //TODO: run once per envoy instance, 
right now it runs once per worker + fn init_vector_store(&mut self) { + let token_id = match self.dispatch_http_call( + "qdrant", + vec![ + (":method", "PUT"), + (":path", "/collections/prompt_vector_store"), + (":authority", "qdrant"), + ("content-type", "application/json"), + ], + Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"), + vec![], + Duration::from_secs(5), + ) { + Ok(token_id) => token_id, + Err(e) => { + panic!("Error dispatching HTTP call for init-vector-store: {:?}", e); + } + }; + if self + .callouts + .insert( + token_id, + CallContext::CreateVectorCollection("prompt_vector_store".to_string()), + ) + .is_some() + { + panic!("duplicate token_id") + } + // self.metrics + // .active_http_calls + // .record(self.callouts.len().try_into().unwrap()); + } +} + +impl Context for FilterContext { + fn on_http_call_response( + &mut self, + token_id: u32, + _num_headers: usize, + body_size: usize, + _num_trailers: usize, + ) { + let callout_data = self.callouts.remove(&token_id).expect("invalid token_id"); + + self.metrics + .active_http_calls + .record(self.callouts.len().try_into().unwrap()); + + match callout_data { + common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest { + create_embedding_request, + prompt_target, + }) => { + self.embedding_request_handler(body_size, create_embedding_request, prompt_target) + } + common_types::CallContext::StoreVectorEmbeddings(_) => { + self.create_vector_store_points_handler(body_size) + } + common_types::CallContext::CreateVectorCollection(_) => { + let mut http_status_code = "Nil".to_string(); + self.get_http_call_response_headers() + .iter() + .for_each(|(k, v)| { + if k == ":status" { + http_status_code.clone_from(v); + } + }); + info!("CreateVectorCollection response: {}", http_status_code); + } + } + } +} + +// RootContext allows the Rust code to reach into the Envoy Config +impl RootContext for FilterContext { + fn on_configure(&mut self, _: usize) -> bool { + if let Some(config_bytes) = self.get_plugin_configuration() { + self.config = serde_yaml::from_slice(&config_bytes).unwrap(); + } + true + } + + fn create_http_context(&self, _context_id: u32) -> Option> { + Some(Box::new(StreamContext { + host_header: None, + callouts: HashMap::new(), + })) + } + + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } + + fn on_vm_start(&mut self, _: usize) -> bool { + self.set_tick_period(Duration::from_secs(1)); + true + } + + fn on_tick(&mut self) { + // initialize vector store + self.init_vector_store(); + self.process_prompt_targets(); + self.set_tick_period(Duration::from_secs(0)); + } +} diff --git a/envoyfilter/src/lib.rs b/envoyfilter/src/lib.rs index 088e70e9..f62c6367 100644 --- a/envoyfilter/src/lib.rs +++ b/envoyfilter/src/lib.rs @@ -1,22 +1,13 @@ -use common_types::{CallContext, EmbeddingRequest}; -use configuration::PromptTarget; -use http::StatusCode; -use log::error; -use log::info; -use md5::Digest; -use open_message_format::models::{ - CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, -}; +use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -use stats::{Gauge, RecordingMetric}; -use std::collections::HashMap; -use std::time::Duration; mod common_types; mod configuration; mod consts; +mod filter_context; mod stats; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); @@ -24,397 +15,3 @@ proxy_wasm::main! 
{{ Box::new(FilterContext::new()) }); }} - -struct StreamContext { - host_header: Option, -} - -impl StreamContext { - fn save_host_header(&mut self) { - // Save the host header to be used by filter logic later on. - self.host_header = self.get_http_request_header(":host"); - } - - fn delete_content_length_header(&mut self) { - // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. - // Server's generally throw away requests whose body length do not match the Content-Length header. - // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could - // manipulate the body in benign ways e.g., compression. - self.set_http_request_header("content-length", None); - } - - fn modify_path_header(&mut self) { - match self.get_http_request_header(":path") { - // The gateway can start gathering information necessary for routing. For now change the path to an - // OpenAI API path. - Some(path) if path == "/llmrouting" => { - self.set_http_request_header(":path", Some("/v1/chat/completions")); - } - // Otherwise let the filter continue. - _ => (), - } - } -} - -// HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for StreamContext { - // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto - // the lifecycle of the http request and response. - fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - self.save_host_header(); - self.delete_content_length_header(); - self.modify_path_header(); - - Action::Continue - } - - fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { - // Let the client send the gateway all the data before sending to the LLM_provider. - // TODO: consider a streaming API. - if !end_of_stream { - return Action::Pause; - } - - if body_size == 0 { - return Action::Continue; - } - - // Deserialize body into spec. - // Currently OpenAI API. 
- let mut deserialized_body: common_types::open_ai::ChatCompletions = - match self.get_http_request_body(0, body_size) { - Some(body_bytes) => match serde_json::from_slice(&body_bytes) { - Ok(deserialized) => deserialized, - Err(msg) => { - self.send_http_response( - StatusCode::BAD_REQUEST.as_u16().into(), - vec![], - Some(format!("Failed to deserialize: {}", msg).as_bytes()), - ); - return Action::Pause; - } - }, - None => { - self.send_http_response( - StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), - vec![], - None, - ); - error!( - "Failed to obtain body bytes even though body_size is {}", - body_size - ); - return Action::Pause; - } - }; - - // Modify JSON payload - deserialized_body.model = String::from("gpt-3.5-turbo"); - - match serde_json::to_string(&deserialized_body) { - Ok(json_string) => { - self.set_http_request_body(0, body_size, &json_string.into_bytes()); - Action::Continue - } - Err(error) => { - self.send_http_response( - StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), - vec![], - None, - ); - error!("Failed to serialize body: {}", error); - Action::Pause - } - } - } -} - -impl Context for StreamContext {} - -#[derive(Copy, Clone)] -struct WasmMetrics { - active_http_calls: Gauge, -} - -impl WasmMetrics { - fn new() -> WasmMetrics { - WasmMetrics { - active_http_calls: Gauge::new(String::from("active_http_calls")), - } - } -} - -struct FilterContext { - metrics: WasmMetrics, - // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. - callouts: HashMap, - config: Option, -} - -impl FilterContext { - fn new() -> FilterContext { - FilterContext { - callouts: HashMap::new(), - config: None, - metrics: WasmMetrics::new(), - } - } - - fn process_prompt_targets(&mut self) { - for prompt_target in &self - .config - .as_ref() - .expect("Gateway configuration cannot be non-existent") - .prompt_config - .prompt_targets - { - for few_shot_example in &prompt_target.few_shot_examples { - info!("few_shot_example: {:?}", few_shot_example); - - let embeddings_input = CreateEmbeddingRequest { - input: Box::new(CreateEmbeddingRequestInput::String( - few_shot_example.to_string(), - )), - model: String::from(consts::DEFAULT_EMBEDDING_MODEL), - encoding_format: None, - dimensions: None, - user: None, - }; - - let json_data: String = match serde_json::to_string(&embeddings_input) { - Ok(json_data) => json_data, - Err(error) => { - panic!("Error serializing embeddings input: {}", error); - } - }; - - let token_id = match self.dispatch_http_call( - "embeddingserver", - vec![ - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "embeddingserver"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ], - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - panic!("Error dispatching embedding server HTTP call: {:?}", e); - } - }; - info!("on_tick: dispatched HTTP call with token_id = {}", token_id); - let embedding_request = EmbeddingRequest { - create_embedding_request: embeddings_input, - prompt_target: prompt_target.clone(), - }; - if self - .callouts - .insert(token_id, { - CallContext::EmbeddingRequest(embedding_request) - }) - .is_some() - { - panic!( - "duplicate token_id={} in embedding server requests", - token_id - ) - } - self.metrics - .active_http_calls - .record(self.callouts.len().try_into().unwrap()); - } - } - } - - fn embedding_request_handler( - &mut self, - body_size: usize, - create_embedding_request: 
CreateEmbeddingRequest, - prompt_target: PromptTarget, - ) { - info!("response received for CreateEmbeddingRequest"); - let body = match self.get_http_call_response_body(0, body_size) { - Some(body) => body, - None => { - return; - } - }; - - if body.is_empty() { - return; - } - - let (json_data, create_vector_store_points) = - match build_qdrant_data(&body, create_embedding_request, &prompt_target) { - Ok(tuple) => tuple, - Err(error) => { - panic!( - "Error building qdrant data from embedding response {}", - error - ); - } - }; - - let token_id = match self.dispatch_http_call( - "qdrant", - vec![ - (":method", "PUT"), - (":path", "/collections/prompt_vector_store/points"), - (":authority", "qdrant"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ], - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - panic!("Error dispatching qdrant HTTP call: {:?}", e); - } - }; - info!("on_tick: dispatched HTTP call with token_id = {}", token_id); - - if self - .callouts - .insert( - token_id, - CallContext::StoreVectorEmbeddings(create_vector_store_points), - ) - .is_some() - { - panic!("duplicate token_id={} in qdrant requests", token_id) - } - - self.metrics - .active_http_calls - .record(self.callouts.len().try_into().unwrap()); - } - - // TODO: @adilhafeez implement. - fn create_vector_store_points_handler(&self, body_size: usize) { - info!("response received for CreateVectorStorePoints"); - if let Some(body) = self.get_http_call_response_body(0, body_size) { - if !body.is_empty() { - info!("response body: {:?}", String::from_utf8(body).unwrap()); - } - } - } -} - -impl Context for FilterContext { - fn on_http_call_response( - &mut self, - token_id: u32, - _num_headers: usize, - body_size: usize, - _num_trailers: usize, - ) { - info!("on_http_call_response: token_id = {}", token_id); - - let callout_data = self.callouts.remove(&token_id).expect("invalid token_id"); - - self.metrics - .active_http_calls - .record(self.callouts.len().try_into().unwrap()); - match callout_data { - common_types::CallContext::EmbeddingRequest(common_types::EmbeddingRequest { - create_embedding_request, - prompt_target, - }) => { - self.embedding_request_handler(body_size, create_embedding_request, prompt_target) - } - common_types::CallContext::StoreVectorEmbeddings(_) => { - self.create_vector_store_points_handler(body_size) - } - } - } -} - -// RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for FilterContext { - fn on_configure(&mut self, _: usize) -> bool { - if let Some(config_bytes) = self.get_plugin_configuration() { - self.config = match serde_yaml::from_slice(&config_bytes) { - Ok(config) => config, - Err(error) => { - panic!("Failed to deserialize plugin configuration: {}", error); - } - }; - info!("on_configure: plugin configuration loaded"); - } - true - } - - fn create_http_context(&self, _context_id: u32) -> Option> { - Some(Box::new(StreamContext { host_header: None })) - } - - fn get_type(&self) -> Option { - Some(ContextType::HttpContext) - } - - fn on_vm_start(&mut self, _: usize) -> bool { - info!("on_vm_start: setting up tick timeout"); - self.set_tick_period(Duration::from_secs(1)); - true - } - - fn on_tick(&mut self) { - info!("on_tick: starting to process prompt targets"); - self.process_prompt_targets(); - self.set_tick_period(Duration::from_secs(0)); - } -} - -fn build_qdrant_data( - embedding_data: &[u8], - create_embedding_request: CreateEmbeddingRequest, - 
prompt_target: &PromptTarget, -) -> Result<(String, common_types::StoreVectorEmbeddingsRequest), serde_json::Error> { - let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(embedding_data) { - Ok(embedding_response) => embedding_response, - Err(error) => { - panic!("Failed to deserialize embedding response: {}", error); - } - }; - info!( - "embedding_response model: {}, vector len: {}", - embedding_response.model, - embedding_response.data[0].embedding.len() - ); - - let mut payload: HashMap = HashMap::new(); - payload.insert( - "prompt-target".to_string(), - serde_json::to_string(&prompt_target)?, - ); - - let id: Option; - match *create_embedding_request.input { - CreateEmbeddingRequestInput::String(ref input) => { - id = Some(md5::compute(input)); - payload.insert("input".to_string(), input.clone()); - } - CreateEmbeddingRequestInput::Array(_) => todo!(), - } - - let create_vector_store_points = common_types::StoreVectorEmbeddingsRequest { - points: vec![common_types::VectorPoint { - id: format!("{:x}", id.unwrap()), - payload, - vector: embedding_response.data[0].embedding.clone(), - }], - }; - let json_data = serde_json::to_string(&create_vector_store_points)?; - info!( - "create_vector_store_points: points length: {}", - embedding_response.data[0].embedding.len() - ); - - Ok((json_data, create_vector_store_points)) -} diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs new file mode 100644 index 00000000..36d90edf --- /dev/null +++ b/envoyfilter/src/stream_context.rs @@ -0,0 +1,520 @@ +use http::StatusCode; +use log::error; +use log::info; +use log::warn; +use open_message_format::models::{ + CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, +}; +use std::collections::HashMap; +use std::time::Duration; + +use proxy_wasm::traits::*; +use proxy_wasm::types::*; + +use consts::{ + DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD, + DEFAULT_PROMPT_TARGET_THRESHOLD, SYSTEM_ROLE, USER_ROLE, +}; + +use crate::common_types; +use crate::common_types::open_ai::Message; +use crate::common_types::SearchPointsResponse; +use crate::configuration::EntityDetail; +use crate::configuration::EntityType; +use crate::configuration::PromptTarget; +use crate::consts; + +enum RequestType { + GetEmbedding, + SearchPoints, + Ner, + ContextResolver, +} + +pub struct CallContext { + request_type: RequestType, + user_message: String, + prompt_target: Option, + request_body: common_types::open_ai::ChatCompletions, +} + +pub struct StreamContext { + pub host_header: Option, + pub callouts: HashMap, +} + +impl StreamContext { + fn save_host_header(&mut self) { + // Save the host header to be used by filter logic later on. + self.host_header = self.get_http_request_header(":host"); + } + + fn delete_content_length_header(&mut self) { + // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. + // Server's generally throw away requests whose body length do not match the Content-Length header. + // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could + // manipulate the body in benign ways e.g., compression. + self.set_http_request_header("content-length", None); + } + + fn modify_path_header(&mut self) { + match self.get_http_request_header(":path") { + // The gateway can start gathering information necessary for routing. For now change the path to an + // OpenAI API path. 
+    fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
+        let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
+            Ok(embedding_response) => embedding_response,
+            Err(e) => {
+                warn!("Error deserializing embedding response: {:?}", e);
+                self.resume_http_request();
+                return;
+            }
+        };
+
+        let search_points_request = common_types::SearchPointsRequest {
+            vector: embedding_response.data[0].embedding.clone(),
+            limit: 10,
+            with_payload: true,
+        };
+
+        let json_data: String = match serde_json::to_string(&search_points_request) {
+            Ok(json_data) => json_data,
+            Err(e) => {
+                warn!("Error serializing search_points_request: {:?}", e);
+                self.reset_http_request();
+                return;
+            }
+        };
+
+        let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME);
+
+        let token_id = match self.dispatch_http_call(
+            "qdrant",
+            vec![
+                (":method", "POST"),
+                (":path", &path),
+                (":authority", "qdrant"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+            ],
+            Some(json_data.as_bytes()),
+            vec![],
+            Duration::from_secs(5),
+        ) {
+            Ok(token_id) => token_id,
+            Err(e) => {
+                panic!("Error dispatching HTTP call for search-points: {:?}", e);
+            }
+        };
+
+        callout_context.request_type = RequestType::SearchPoints;
+        if self.callouts.insert(token_id, callout_context).is_some() {
+            panic!("duplicate token_id")
+        }
+    }
+
+    // Handles the search response from Qdrant: picks the best-scoring prompt
+    // target above the similarity threshold and dispatches an NER callout to
+    // extract the entities that the prompt target declares.
+    fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
+        let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) {
+            Ok(search_points_response) => search_points_response,
+            Err(e) => {
+                warn!("Error deserializing search_points_response: {:?}", e);
+                self.resume_http_request();
+                return;
+            }
+        };
+
+        let search_results = &search_points_response.result;
+
+        if search_results.is_empty() {
+            info!("No prompt target matched");
+            self.resume_http_request();
+            return;
+        }
+
+        info!("similarity score: {}", search_results[0].score);
+
+        if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD {
+            info!(
+                "prompt target below threshold: {}",
+                DEFAULT_PROMPT_TARGET_THRESHOLD
+            );
+            self.resume_http_request();
+            return;
+        }
+        let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap();
+        let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes())
+        {
+            Ok(prompt_target) => prompt_target,
+            Err(e) => {
+                warn!("Error deserializing prompt_target: {:?}", e);
+                self.resume_http_request();
+                return;
+            }
+        };
+        info!("prompt_target name: {:?}", prompt_target.name);
+
+        // only extract entity names
+        let entity_names = get_entity_details(&prompt_target)
+            .iter()
+            .map(|entity| entity.name.clone())
+            .collect();
+        let user_message = callout_context.user_message.clone();
+        let ner_request = common_types::NERRequest {
+            input: user_message,
+            labels: entity_names,
+            model: DEFAULT_NER_MODEL.to_string(),
+        };
+
+        let json_data: String = match serde_json::to_string(&ner_request) {
+            Ok(json_data) => json_data,
+            Err(e) => {
+                warn!("Error serializing ner_request: {:?}", e);
+                self.resume_http_request();
+                return;
+            }
+        };
+
+        let token_id = match self.dispatch_http_call(
+            "nerhost",
+            vec![
+                (":method", "POST"),
+                (":path", "/ner"),
+                (":authority", "nerhost"),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+            ],
+            Some(json_data.as_bytes()),
+            vec![],
+            Duration::from_secs(5),
+        ) {
+            Ok(token_id) => token_id,
+            Err(e) => {
+                panic!("Error dispatching HTTP call for ner: {:?}", e);
+            }
+        };
+        callout_context.request_type = RequestType::Ner;
+        callout_context.prompt_target = Some(prompt_target);
+        if self.callouts.insert(token_id, callout_context).is_some() {
+            panic!("duplicate token_id")
+        }
+    }
+
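+    // Handles the NER response: keeps entities whose score clears the
+    // threshold, verifies that all required entities were extracted, and then
+    // calls the prompt target's endpoint (the context resolver) with the
+    // extracted parameters.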
+    fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
+        let ner_response: common_types::NERResponse = match serde_json::from_slice(&body) {
+            Ok(ner_response) => ner_response,
+            Err(e) => {
+                warn!("Error deserializing ner_response: {:?}", e);
+                self.resume_http_request();
+                return;
+            }
+        };
+        info!("ner_response: {:?}", ner_response);
+
+        let mut request_params: HashMap<String, String> = HashMap::new();
+        for entity in ner_response.data.iter() {
+            if entity.score < DEFAULT_NER_THRESHOLD {
+                warn!(
+                    "score of entity was too low; entity name: {}, score: {}",
+                    entity.label, entity.score
+                );
+                continue;
+            }
+            request_params.insert(entity.label.clone(), entity.text.clone());
+        }
+
+        let prompt_target = callout_context.prompt_target.as_ref().unwrap();
+        let entity_details = get_entity_details(prompt_target);
+        for entity in entity_details {
+            if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) {
+                warn!(
+                    "required entity missing or score of entity was too low: {}",
+                    entity.name
+                );
+                self.resume_http_request();
+                return;
+            }
+        }
+
+        let req_param_str = match serde_json::to_string(&request_params) {
+            Ok(req_param_str) => req_param_str,
+            Err(e) => {
+                warn!("Error serializing request_params: {:?}", e);
+                self.resume_http_request();
+                return;
+            }
+        };
+
+        let endpoint = callout_context
+            .prompt_target
+            .as_ref()
+            .unwrap()
+            .endpoint
+            .as_ref()
+            .unwrap();
+
+        let http_path = match &endpoint.path {
+            Some(path) => path.clone(),
+            None => "/".to_string(),
+        };
+
+        let http_method = match &endpoint.method {
+            Some(method) => method.clone(),
+            None => http::Method::POST.to_string(),
+        };
+
+        let token_id = match self.dispatch_http_call(
+            endpoint.cluster.as_str(),
+            vec![
+                (":method", http_method.as_str()),
+                (":path", http_path.as_str()),
+                (":authority", endpoint.cluster.as_str()),
+                ("content-type", "application/json"),
+                ("x-envoy-max-retries", "3"),
+            ],
+            Some(req_param_str.as_bytes()),
+            vec![],
+            Duration::from_secs(5),
+        ) {
+            Ok(token_id) => token_id,
+            Err(e) => {
+                panic!("Error dispatching HTTP call for context-resolver: {:?}", e);
+            }
+        };
+        callout_context.request_type = RequestType::ContextResolver;
+        if self.callouts.insert(token_id, callout_context).is_some() {
+            panic!("duplicate token_id")
+        }
+    }
+
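+    // Handles the context-resolver response: appends the prompt target's
+    // system prompt (if any) and the resolved context to the chat messages,
+    // rewrites the request body, and resumes the paused request so it can
+    // continue upstream to the LLM provider.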
warn!("Error converting response to string: {:?}", e); + self.resume_http_request(); + return; + } + } + + let json_string = match serde_json::to_string(&request_body) { + Ok(json_string) => json_string, + Err(e) => { + warn!("Error serializing request_body: {:?}", e); + self.resume_http_request(); + return; + } + }; + info!("sending request to openai: msg len: {}", json_string.len()); + self.set_http_request_body(0, json_string.len(), &json_string.into_bytes()); + self.resume_http_request(); + } +} + +// HttpContext is the trait that allows the Rust code to interact with HTTP objects. +impl HttpContext for StreamContext { + // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto + // the lifecycle of the http request and response. + fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + self.save_host_header(); + self.delete_content_length_header(); + self.modify_path_header(); + + Action::Continue + } + + fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + // Let the client send the gateway all the data before sending to the LLM_provider. + // TODO: consider a streaming API. + if !end_of_stream { + return Action::Pause; + } + + if body_size == 0 { + return Action::Continue; + } + + // Deserialize body into spec. + // Currently OpenAI API. + let deserialized_body: common_types::open_ai::ChatCompletions = + match self.get_http_request_body(0, body_size) { + Some(body_bytes) => match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(msg) => { + self.send_http_response( + StatusCode::BAD_REQUEST.as_u16().into(), + vec![], + Some(format!("Failed to deserialize: {}", msg).as_bytes()), + ); + return Action::Pause; + } + }, + None => { + self.send_http_response( + StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(), + vec![], + None, + ); + error!( + "Failed to obtain body bytes even though body_size is {}", + body_size + ); + return Action::Pause; + } + }; + + let user_message = match deserialized_body + .messages + .last() + .and_then(|last_message| last_message.content.as_ref()) + { + Some(content) => content, + None => { + info!("No messages in the request body"); + return Action::Continue; + } + }; + + let get_embeddings_input = CreateEmbeddingRequest { + input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())), + model: String::from(DEFAULT_EMBEDDING_MODEL), + encoding_format: None, + dimensions: None, + user: None, + }; + + let json_data: String = match serde_json::to_string(&get_embeddings_input) { + Ok(json_data) => json_data, + Err(error) => { + panic!("Error serializing embeddings input: {}", error); + } + }; + + let token_id = match self.dispatch_http_call( + "embeddingserver", + vec![ + (":method", "POST"), + (":path", "/embeddings"), + (":authority", "embeddingserver"), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ], + Some(json_data.as_bytes()), + vec![], + Duration::from_secs(5), + ) { + Ok(token_id) => token_id, + Err(e) => { + panic!( + "Error dispatching embedding server HTTP call for get-embeddings: {:?}", + e + ); + } + }; + let call_context = CallContext { + request_type: RequestType::GetEmbedding, + user_message: user_message.clone(), + prompt_target: None, + request_body: deserialized_body, + }; + if self.callouts.insert(token_id, call_context).is_some() { + panic!( + "duplicate token_id={} in embedding server requests", + token_id + ) + } + + Action::Pause + } +} + +impl Context for 
+impl Context for StreamContext {
+    fn on_http_call_response(
+        &mut self,
+        token_id: u32,
+        _num_headers: usize,
+        body_size: usize,
+        _num_trailers: usize,
+    ) {
+        let callout_context = self.callouts.remove(&token_id).expect("invalid token_id");
+
+        let body = match self.get_http_call_response_body(0, body_size) {
+            Some(body) => body,
+            None => {
+                warn!("No response body");
+                self.resume_http_request();
+                return;
+            }
+        };
+
+        match callout_context.request_type {
+            RequestType::GetEmbedding => {
+                self.embeddings_handler(body, callout_context);
+            }
+            RequestType::SearchPoints => {
+                self.search_points_handler(body, callout_context);
+            }
+            RequestType::Ner => {
+                self.ner_handler(body, callout_context);
+            }
+            RequestType::ContextResolver => {
+                self.context_resolver_handler(body, callout_context);
+            }
+        }
+    }
+}
+
+// Expands a prompt target's entity declarations into a uniform list:
+// bare entity names become required entities with no description.
+fn get_entity_details(prompt_target: &PromptTarget) -> Vec<EntityDetail> {
+    match prompt_target.entities.as_ref() {
+        Some(EntityType::Vec(entity_names)) => {
+            let mut entity_details: Vec<EntityDetail> = Vec::new();
+            for entity_name in entity_names {
+                entity_details.push(EntityDetail {
+                    name: entity_name.clone(),
+                    required: Some(true),
+                    description: None,
+                });
+            }
+            entity_details
+        }
+        Some(EntityType::Struct(entity_details)) => entity_details.clone(),
+        None => Vec::new(),
+    }
+}
diff --git a/envoyfilter/tests/integration.rs b/envoyfilter/tests/integration.rs
index 6dfb7675..35a6256e 100644
--- a/envoyfilter/tests/integration.rs
+++ b/envoyfilter/tests/integration.rs
@@ -90,7 +90,7 @@ fn successful_request_to_open_ai_chat_completions() {
         .returning(Some(chat_completions_request_body))
         // TODO: assert that the model field was added.
         .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
-        .execute_and_expect(ReturnType::Action(Action::Continue))
+        .execute_and_expect(ReturnType::Action(Action::Pause))
         .unwrap();
 }
diff --git a/gateway.code-workspace b/gateway.code-workspace
index bc33ccde..31fbf9e5 100644
--- a/gateway.code-workspace
+++ b/gateway.code-workspace
@@ -13,17 +13,17 @@
             "path": "embedding-server"
         },
         {
-            "name": "chatbot-client",
-            "path": "chatbot-client"
+            "name": "chatbot-ui",
+            "path": "chatbot-ui"
         },
         {
             "name": "open-message-format",
             "path": "open-message-format"
         },
         {
-            "name": "demos",
-            "path": "./demos"
-        },
+            "name": "demos/weather-forecast",
+            "path": "./demos/weather-forecast",
+        }
     ],
     "settings": {}
 }