mirror of
https://github.com/katanemo/plano.git
synced 2026-06-11 15:05:14 +02:00
Use intent model from archfc to pick prompt gateway (#328)
This commit is contained in:
parent
67b8fd635e
commit
ba7279becb
151 changed files with 8642 additions and 10932 deletions
58
.github/workflows/e2e_archgw.yml
vendored
Normal file
58
.github/workflows/e2e_archgw.yml
vendored
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
name: e2e archgw tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest-m
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./tests/archgw
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: build arch docker image
|
||||
run: |
|
||||
cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw
|
||||
|
||||
- name: start archgw
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
run: |
|
||||
docker compose up | tee &> archgw.logs &
|
||||
|
||||
- name: wait for archgw to be healthy
|
||||
run: |
|
||||
source common.sh && wait_for_healthz http://localhost:10000/healthz
|
||||
|
||||
- name: install poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
|
||||
- name: install test dependencies
|
||||
run: |
|
||||
poetry install
|
||||
|
||||
- name: run archgw tests
|
||||
run: |
|
||||
poetry run pytest || tail -100 archgw.logs
|
||||
|
||||
- name: stop archgw docker container
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
run: |
|
||||
docker compose down
|
||||
40
.github/workflows/e2e_model_server.yml
vendored
Normal file
40
.github/workflows/e2e_model_server.yml
vendored
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
name: e2e model server tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest-m
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./tests/modelserver
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: install poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
|
||||
- name: install model server and start it
|
||||
run: |
|
||||
cd ../../model_server/ && poetry install && poetry run archgw_modelserver start
|
||||
|
||||
- name: install test dependencies
|
||||
run: |
|
||||
poetry install
|
||||
|
||||
- name: run tests
|
||||
run: |
|
||||
poetry run pytest
|
||||
47
.github/workflows/e2e_test_demos.yml
vendored
Normal file
47
.github/workflows/e2e_test_demos.yml
vendored
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
name: e2e demo tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest-m
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: build arch docker image
|
||||
run: |
|
||||
docker build -f arch/Dockerfile . -t katanemo/archgw
|
||||
|
||||
- name: install poetry
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
- name: setup python venv
|
||||
run: |
|
||||
python -m venv venv
|
||||
|
||||
- name: install model server, arch gateway and test dependencies
|
||||
run: |
|
||||
source venv/bin/activate
|
||||
cd model_server/ && echo "installing model server" && poetry install
|
||||
cd ../arch/tools && echo "installing archgw cli" && poetry install
|
||||
cd ../../demos/test_runner && echo "installing test dependencies" && poetry install
|
||||
|
||||
- name: run demo tests
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
run: |
|
||||
source venv/bin/activate
|
||||
cd demos/test_runner && sh run_demo_tests.sh
|
||||
5
.github/workflows/e2e_tests.yml
vendored
5
.github/workflows/e2e_tests.yml
vendored
|
|
@ -8,8 +8,7 @@ on:
|
|||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest-m
|
||||
# runs-on: gh-large-150gb-ssd
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
|
@ -31,4 +30,4 @@ jobs:
|
|||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
run: |
|
||||
python -mvenv venv
|
||||
source venv/bin/activate && cd e2e_tests && bash run_e2e_tests.sh
|
||||
source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh
|
||||
|
|
|
|||
2
.github/workflows/model-server-tests.yml
vendored
2
.github/workflows/model-server-tests.yml
vendored
|
|
@ -41,4 +41,4 @@ jobs:
|
|||
PYTHONPATH: model_server # Ensure the app's path is available
|
||||
run: |
|
||||
cd model_server
|
||||
poetry run pytest --maxfail=5 --disable-warnings
|
||||
poetry run pytest
|
||||
|
|
|
|||
157
.gitignore
vendored
157
.gitignore
vendored
|
|
@ -1,35 +1,142 @@
|
|||
arch/qdrant_data/
|
||||
/venv/
|
||||
__pycache__
|
||||
grafana-data
|
||||
prom_data
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
*.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
qdrant_data
|
||||
generated
|
||||
.DS_Store
|
||||
*.gguf
|
||||
venv
|
||||
demos/function_calling/ollama/models/
|
||||
demos/function_calling/ollama/id_ed*
|
||||
docs/build/
|
||||
demos/function_calling/open-webui/
|
||||
demos/employee_details_copilot/open-webui/
|
||||
demos/employee_details_copilot_arch/open-webui/
|
||||
demos/network_copilot/open-webui/
|
||||
demos/employee_details_copilot/ollama/models/
|
||||
demos/employee_details_copilot_arch/ollama/models/
|
||||
demos/network_copilot/ollama/models/
|
||||
arch_log/
|
||||
arch/tools/*.egg-info
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
# VSCode stuff:
|
||||
.vscode/
|
||||
|
||||
# MacOS Metadata
|
||||
*.DS_Store
|
||||
|
||||
|
||||
|
||||
# =========================================
|
||||
|
||||
# Arch
|
||||
arch/tools/config
|
||||
arch/tools/build
|
||||
model_server/model_server.egg-info
|
||||
|
||||
# Archgw - model_server
|
||||
model_server/venv_model_server
|
||||
model_server/build
|
||||
model_server/dist
|
||||
|
||||
# Archgw - Docs
|
||||
docs/build/
|
||||
|
||||
# Archgw - Demos
|
||||
demos/function_calling/ollama/models/
|
||||
demos/function_calling/ollama/id_ed*
|
||||
demos/function_calling/open-webui/
|
||||
demos/function_calling/open-webui/
|
||||
demos/shared/signoz/data
|
||||
|
||||
# Arch - Miscellaneous
|
||||
grafana-data
|
||||
prom_data
|
||||
arch_log/
|
||||
arch_logs/
|
||||
dist/
|
||||
crates/*/target/
|
||||
crates/target/
|
||||
build.log
|
||||
demos/shared/signoz/data
|
||||
|
||||
archgw.log
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ Arch's CLI allows you to manage and interact with the Arch gateway efficiently.
|
|||
```console
|
||||
$ python -m venv venv
|
||||
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
|
||||
$ pip install archgw==0.1.6
|
||||
$ pip install archgw==0.1.7
|
||||
```
|
||||
|
||||
### Build AI Agent with Arch Gateway
|
||||
|
|
|
|||
|
|
@ -99,6 +99,8 @@ properties:
|
|||
type: string
|
||||
in_path:
|
||||
type: boolean
|
||||
format:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- name
|
||||
|
|
|
|||
|
|
@ -22,3 +22,4 @@ services:
|
|||
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
|
||||
- MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}
|
||||
- OTEL_TRACING_HTTP_ENDPOINT=http://host.docker.internal:4318/v1/traces
|
||||
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-51000}
|
||||
|
|
|
|||
|
|
@ -237,8 +237,7 @@ static_resources:
|
|||
domains:
|
||||
- "*"
|
||||
routes:
|
||||
|
||||
{% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination"] %}
|
||||
{% for internal_clustrer in ["arch_fc", "model_server"] %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
headers:
|
||||
|
|
@ -251,16 +250,16 @@ static_resources:
|
|||
timeout: 60s
|
||||
{% endfor %}
|
||||
|
||||
{% for _, cluster in arch_clusters.items() %}
|
||||
{% for cluster_name, cluster in arch_clusters.items() %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
headers:
|
||||
- name: "x-arch-upstream"
|
||||
string_match:
|
||||
exact: {{ cluster.name }}
|
||||
exact: {{ cluster_name }}
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: {{ cluster.name }}
|
||||
cluster: {{ cluster_name }}
|
||||
timeout: 60s
|
||||
{% endfor %}
|
||||
http_filters:
|
||||
|
|
@ -475,7 +474,7 @@ static_resources:
|
|||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
sni: api.mistral.ai
|
||||
{% for internal_clustrer in ["embeddings", "zeroshot", "guard", "arch_fc", "hallucination"] %}
|
||||
{% for internal_clustrer in ["arch_fc", "model_server"] %}
|
||||
- name: {{ internal_clustrer }}
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
|
|
@ -489,7 +488,7 @@ static_resources:
|
|||
address:
|
||||
socket_address:
|
||||
address: host.docker.internal
|
||||
port_value: 51000
|
||||
port_value: $MODEL_SERVER_PORT
|
||||
hostname: {{ internal_clustrer }}
|
||||
{% endfor %}
|
||||
- name: mistral_7b_instruct
|
||||
|
|
@ -507,8 +506,8 @@ static_resources:
|
|||
address: mistral_7b_instruct
|
||||
port_value: 10001
|
||||
hostname: "mistral_7b_instruct"
|
||||
{% for _, cluster in arch_clusters.items() %}
|
||||
- name: {{ cluster.name }}
|
||||
{% for cluster_name, cluster in arch_clusters.items() %}
|
||||
- name: {{ cluster_name }}
|
||||
{% if cluster.connect_timeout -%}
|
||||
connect_timeout: {{ cluster.connect_timeout }}
|
||||
{% else -%}
|
||||
|
|
@ -518,7 +517,7 @@ static_resources:
|
|||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: {{ cluster.name }}
|
||||
cluster_name: {{ cluster_name }}
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ source venv/bin/activate
|
|||
|
||||
### Step 3: Run the build script
|
||||
```bash
|
||||
pip install archgw==0.1.6
|
||||
pip install archgw==0.1.7
|
||||
```
|
||||
|
||||
## Uninstall Instructions: archgw CLI
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
import yaml
|
||||
|
|
@ -47,32 +48,27 @@ def validate_and_render_schema():
|
|||
config_schema_yaml = yaml.safe_load(arch_config_schema)
|
||||
inferred_clusters = {}
|
||||
|
||||
endpoints = config_yaml.get("endpoints", {})
|
||||
|
||||
# override the inferred clusters with the ones defined in the config
|
||||
for name, endpoint_details in endpoints.items():
|
||||
inferred_clusters[name] = endpoint_details
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
if len(endpoint.split(":")) > 1:
|
||||
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
|
||||
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
|
||||
|
||||
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
|
||||
|
||||
if "prompt_targets" in config_yaml:
|
||||
for prompt_target in config_yaml["prompt_targets"]:
|
||||
name = prompt_target.get("endpoint", {}).get("name", None)
|
||||
if not name:
|
||||
continue
|
||||
if name not in inferred_clusters:
|
||||
inferred_clusters[name] = {
|
||||
"name": name,
|
||||
"port": 80, # default port
|
||||
}
|
||||
|
||||
endpoints = config_yaml.get("endpoints", {})
|
||||
|
||||
# override the inferred clusters with the ones defined in the config
|
||||
for name, endpoint_details in endpoints.items():
|
||||
if name in inferred_clusters:
|
||||
print("updating cluster", endpoint_details)
|
||||
inferred_clusters[name].update(endpoint_details)
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
if len(endpoint.split(":")) > 1:
|
||||
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
|
||||
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
|
||||
else:
|
||||
inferred_clusters[name] = endpoint_details
|
||||
|
||||
print("updated clusters", inferred_clusters)
|
||||
raise Exception(
|
||||
f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file"
|
||||
)
|
||||
|
||||
arch_llm_providers = config_yaml["llm_providers"]
|
||||
arch_tracing = config_yaml.get("tracing", {})
|
||||
|
|
@ -90,6 +86,7 @@ def validate_and_render_schema():
|
|||
|
||||
rendered = template.render(data)
|
||||
print(ENVOY_CONFIG_FILE_RENDERED)
|
||||
print(rendered)
|
||||
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
|
||||
file.write(rendered)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
KATANEMO_DOCKERHUB_REPO = "katanemo/archgw"
|
||||
KATANEMO_LOCAL_MODEL_LIST = [
|
||||
"katanemo/Arch-Guard-cpu",
|
||||
"katanemo/Arch-Guard",
|
||||
"katanemo/bge-large-en-v1.5",
|
||||
]
|
||||
SERVICE_NAME_ARCHGW = "archgw"
|
||||
SERVICE_NAME_MODEL_SERVER = "model_server"
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ def start_archgw_docker(client, arch_config_file, env):
|
|||
},
|
||||
environment={
|
||||
"OTEL_TRACING_HTTP_ENDPOINT": "http://host.docker.internal:4318/v1/traces",
|
||||
"MODEL_SERVER_PORT": os.getenv("MODEL_SERVER_PORT", "51000"),
|
||||
**env,
|
||||
},
|
||||
extra_hosts={"host.docker.internal": "host-gateway"},
|
||||
|
|
@ -78,25 +79,6 @@ def stream_gateway_logs(follow):
|
|||
log.info(f"Failed to stream logs: {str(e)}")
|
||||
|
||||
|
||||
def stream_model_server_logs(follow):
|
||||
"""
|
||||
Get the model server logs, check if the user wants to follow/tail them.
|
||||
"""
|
||||
log_file_expanded = os.path.expanduser(MODEL_SERVER_LOG_FILE)
|
||||
|
||||
stream_command = ["tail"]
|
||||
if follow:
|
||||
stream_command.append("-f")
|
||||
|
||||
stream_command.append(log_file_expanded)
|
||||
subprocess.run(
|
||||
stream_command,
|
||||
check=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
def stream_access_logs(follow):
|
||||
"""
|
||||
Get the archgw access logs
|
||||
|
|
@ -117,7 +99,7 @@ def stream_access_logs(follow):
|
|||
)
|
||||
|
||||
|
||||
def start_arch(arch_config_file, env, log_timeout=120):
|
||||
def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
|
||||
"""
|
||||
Start Docker Compose in detached mode and stream logs until services are healthy.
|
||||
|
||||
|
|
@ -130,6 +112,16 @@ def start_arch(arch_config_file, env, log_timeout=120):
|
|||
try:
|
||||
client = docker.from_env()
|
||||
|
||||
try:
|
||||
container = client.containers.get("archgw")
|
||||
log.info("archgw container found in docker, stopping and removing it")
|
||||
# ensure that previous docker container is stopped and removed
|
||||
container.stop()
|
||||
container.remove()
|
||||
log.info("Stopped and removed archgw container")
|
||||
except docker.errors.NotFound as e:
|
||||
pass
|
||||
|
||||
container = start_archgw_docker(client, arch_config_file, env)
|
||||
|
||||
start_time = time.time()
|
||||
|
|
@ -153,6 +145,13 @@ def start_arch(arch_config_file, env, log_timeout=120):
|
|||
log.info(f"Container health status: {container_status}")
|
||||
time.sleep(1)
|
||||
|
||||
if foreground:
|
||||
for line in container.logs(stream=True):
|
||||
print(line.decode("utf-8").strip("\n"))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
log.info("Keyboard interrupt received, stopping arch gateway service.")
|
||||
stop_arch()
|
||||
except docker.errors.APIError as e:
|
||||
log.info(f"Failed to start Arch: {str(e)}")
|
||||
|
||||
|
|
@ -186,17 +185,23 @@ def download_models_from_hf():
|
|||
snapshot_download(repo_id=model)
|
||||
|
||||
|
||||
def start_arch_modelserver():
|
||||
def start_arch_modelserver(foreground):
|
||||
"""
|
||||
Start the model server. This assumes that the archgw_modelserver package is installed locally
|
||||
|
||||
"""
|
||||
try:
|
||||
log.info("archgw_modelserver restart")
|
||||
subprocess.run(
|
||||
["archgw_modelserver", "restart"], check=True, start_new_session=True
|
||||
)
|
||||
log.info("Successfully ran model_server")
|
||||
if foreground:
|
||||
subprocess.run(
|
||||
["archgw_modelserver", "start", "--foreground"],
|
||||
check=True,
|
||||
)
|
||||
else:
|
||||
subprocess.run(
|
||||
["archgw_modelserver", "start"],
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
|
@ -212,7 +217,6 @@ def stop_arch_modelserver():
|
|||
["archgw_modelserver", "stop"],
|
||||
check=True,
|
||||
)
|
||||
log.info("Successfully stopped the archgw model_server")
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
|
|
@ -16,10 +16,9 @@ from cli.core import (
|
|||
stop_arch_modelserver,
|
||||
start_arch,
|
||||
stop_arch,
|
||||
stream_gateway_logs,
|
||||
stream_model_server_logs,
|
||||
stream_access_logs,
|
||||
download_models_from_hf,
|
||||
stream_access_logs,
|
||||
stream_gateway_logs,
|
||||
)
|
||||
from cli.consts import (
|
||||
KATANEMO_DOCKERHUB_REPO,
|
||||
|
|
@ -138,16 +137,27 @@ def build(service):
|
|||
default=SERVICE_ALL,
|
||||
help="Service to start. Options are model_server, archgw.",
|
||||
)
|
||||
def up(file, path, service):
|
||||
@click.option(
|
||||
"--foreground",
|
||||
default=False,
|
||||
help="Run Arch in the foreground. Default is False",
|
||||
is_flag=True,
|
||||
)
|
||||
def up(file, path, service, foreground):
|
||||
"""Starts Arch."""
|
||||
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
|
||||
log.info(f"Error: Invalid service {service}. Exiting")
|
||||
sys.exit(1)
|
||||
|
||||
if service == SERVICE_ALL and foreground:
|
||||
# foreground can only be specified when starting individual services
|
||||
log.info("foreground flag is only supported for individual services. Exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
if service == SERVICE_NAME_MODEL_SERVER:
|
||||
log.info("Download archgw models from HuggingFace...")
|
||||
download_models_from_hf()
|
||||
start_arch_modelserver()
|
||||
start_arch_modelserver(foreground)
|
||||
return
|
||||
|
||||
if file:
|
||||
|
|
@ -214,12 +224,11 @@ def up(file, path, service):
|
|||
env.update(env_stage)
|
||||
|
||||
if service == SERVICE_NAME_ARCHGW:
|
||||
start_arch(arch_config_file, env)
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
else:
|
||||
# this will used the cached versions of the models, so its safe to use everytime.
|
||||
download_models_from_hf()
|
||||
start_arch_modelserver()
|
||||
start_arch(arch_config_file, env)
|
||||
start_arch_modelserver(foreground)
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
|
||||
|
||||
@click.command()
|
||||
|
|
@ -267,65 +276,37 @@ def generate_prompt_targets(file):
|
|||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--service",
|
||||
default=SERVICE_ALL,
|
||||
help="Service to monitor. By default it will monitor both core gateway and model_server logs.",
|
||||
)
|
||||
@click.option(
|
||||
"--debug",
|
||||
help="For detailed debug logs to trace calls from archgw <> model_server <> api_server, etc",
|
||||
is_flag=True,
|
||||
)
|
||||
@click.option("--follow", help="Follow the logs", is_flag=True)
|
||||
def logs(service, debug, follow):
|
||||
def logs(debug, follow):
|
||||
"""Stream logs from access logs services."""
|
||||
|
||||
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
|
||||
print(f"Error: Invalid service {service}. Exiting")
|
||||
sys.exit(1)
|
||||
|
||||
if debug:
|
||||
try:
|
||||
archgw_process = None
|
||||
if service == SERVICE_NAME_ARCHGW or service == SERVICE_ALL:
|
||||
archgw_process = multiprocessing.Process(
|
||||
target=stream_gateway_logs, args=(follow,)
|
||||
)
|
||||
archgw_process.start()
|
||||
|
||||
model_server_process = None
|
||||
if service == SERVICE_NAME_MODEL_SERVER or service == SERVICE_ALL:
|
||||
model_server_process = multiprocessing.Process(
|
||||
target=stream_model_server_logs, args=(follow,)
|
||||
)
|
||||
model_server_process.start()
|
||||
|
||||
if archgw_process:
|
||||
archgw_process.join()
|
||||
if model_server_process:
|
||||
model_server_process.join()
|
||||
except KeyboardInterrupt:
|
||||
log.info("KeyboardInterrupt detected. Exiting.")
|
||||
if archgw_process and archgw_process.is_alive():
|
||||
archgw_process.terminate()
|
||||
|
||||
if model_server_process and model_server_process.is_alive():
|
||||
model_server_process.terminate()
|
||||
else:
|
||||
try:
|
||||
archgw_access_logs_process = None
|
||||
archgw_access_logs_process = multiprocessing.Process(
|
||||
target=stream_access_logs, args=(follow,)
|
||||
archgw_process = None
|
||||
try:
|
||||
if debug:
|
||||
archgw_process = multiprocessing.Process(
|
||||
target=stream_gateway_logs, args=(follow,)
|
||||
)
|
||||
archgw_access_logs_process.start()
|
||||
archgw_process.start()
|
||||
|
||||
if archgw_access_logs_process:
|
||||
archgw_access_logs_process.join()
|
||||
except KeyboardInterrupt:
|
||||
log.info("KeyboardInterrupt detected. Exiting.")
|
||||
if archgw_access_logs_process.is_alive():
|
||||
archgw_access_logs_process.terminate()
|
||||
archgw_access_logs_process = multiprocessing.Process(
|
||||
target=stream_access_logs, args=(follow,)
|
||||
)
|
||||
archgw_access_logs_process.start()
|
||||
archgw_access_logs_process.join()
|
||||
|
||||
if archgw_process:
|
||||
archgw_process.join()
|
||||
except KeyboardInterrupt:
|
||||
log.info("KeyboardInterrupt detected. Exiting.")
|
||||
if archgw_access_logs_process.is_alive():
|
||||
archgw_access_logs_process.terminate()
|
||||
if archgw_process and archgw_process.is_alive():
|
||||
archgw_process.terminate()
|
||||
|
||||
|
||||
main.add_command(up)
|
||||
|
|
|
|||
3691
arch/tools/poetry.lock
generated
3691
arch/tools/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "archgw"
|
||||
version = "0.1.6"
|
||||
version = "0.1.7"
|
||||
description = "Python-based CLI tool to manage Arch Gateway."
|
||||
authors = ["Katanemo Labs, Inc."]
|
||||
packages = [
|
||||
|
|
@ -9,8 +9,8 @@ packages = [
|
|||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.12"
|
||||
archgw_modelserver = "0.1.6"
|
||||
python = "^3.12"
|
||||
archgw_modelserver = "0.1.7"
|
||||
pyyaml = "^6.0.2"
|
||||
pydantic = "^2.10.1"
|
||||
click = "^8.1.7"
|
||||
|
|
|
|||
|
|
@ -17,8 +17,16 @@
|
|||
"path": "model_server"
|
||||
},
|
||||
{
|
||||
"name": "e2e_tests",
|
||||
"path": "e2e_tests"
|
||||
"name": "tests_e2e",
|
||||
"path": "tests/e2e"
|
||||
},
|
||||
{
|
||||
"name": "tests_archgw",
|
||||
"path": "tests/archgw"
|
||||
},
|
||||
{
|
||||
"name": "tests_modelserver",
|
||||
"path": "tests/modelserver"
|
||||
},
|
||||
{
|
||||
"name": "chatbot_ui",
|
||||
|
|
@ -41,6 +49,7 @@
|
|||
"eamodio.gitlens",
|
||||
"ms-python.black-formatter",
|
||||
"tamasfe.even-better-toml",
|
||||
"esbenp.prettier-vscode",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ pub struct ChatCompletionsRequest {
|
|||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum ToolType {
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
|
|
@ -80,6 +80,8 @@ pub struct FunctionParameter {
|
|||
pub enum_values: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub format: Option<String>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameter {
|
||||
|
|
@ -96,6 +98,9 @@ impl Serialize for FunctionParameter {
|
|||
if let Some(default) = &self.default {
|
||||
map.serialize_entry("default", default)?;
|
||||
}
|
||||
if let Some(format) = &self.format {
|
||||
map.serialize_entry("format", format)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
|
@ -165,8 +170,8 @@ pub struct Message {
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub finish_reason: String,
|
||||
pub index: usize,
|
||||
pub finish_reason: Option<String>,
|
||||
pub index: Option<usize>,
|
||||
pub message: Message,
|
||||
}
|
||||
|
||||
|
|
@ -197,6 +202,18 @@ pub struct ToolCallState {
|
|||
pub enum ArchState {
|
||||
ToolCall(Vec<ToolCallState>),
|
||||
}
|
||||
#[derive(Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ModelServerResponse {
|
||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
ModelServerErrorResponse(ModelServerErrorResponse),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelServerErrorResponse {
|
||||
pub result: String,
|
||||
pub intent_latency: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
|
|
@ -217,8 +234,8 @@ impl ChatCompletionsResponse {
|
|||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
index: 0,
|
||||
finish_reason: "done".to_string(),
|
||||
index: Some(0),
|
||||
finish_reason: Some("done".to_string()),
|
||||
}],
|
||||
usage: None,
|
||||
model: ARCH_FC_MODEL_NAME.to_string(),
|
||||
|
|
@ -408,6 +425,7 @@ mod test {
|
|||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
format: None,
|
||||
},
|
||||
);
|
||||
|
||||
|
|
@ -462,6 +480,7 @@ mod test {
|
|||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
format: None,
|
||||
},
|
||||
)]);
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,10 @@ use serde::{Deserialize, Serialize};
|
|||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
use crate::api::open_ai::{
|
||||
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
|
|
@ -192,6 +196,7 @@ pub struct Parameter {
|
|||
pub enum_values: Option<Vec<String>>,
|
||||
pub default: Option<String>,
|
||||
pub in_path: Option<bool>,
|
||||
pub format: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
|
||||
|
|
@ -231,11 +236,47 @@ pub struct PromptTarget {
|
|||
pub auto_llm_dispatch_on_response: Option<bool>,
|
||||
}
|
||||
|
||||
// convert PromptTarget to ChatCompletionTool
|
||||
impl From<&PromptTarget> for ChatCompletionTool {
|
||||
fn from(val: &PromptTarget) -> Self {
|
||||
let properties: HashMap<String, FunctionParameter> = match val.parameters {
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: ParameterType::from(
|
||||
entity.parameter_type.clone().unwrap_or("str".to_string()),
|
||||
),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
default: entity.default.clone(),
|
||||
format: entity.format.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
|
||||
ChatCompletionTool {
|
||||
tool_type: crate::api::open_ai::ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: val.name.clone(),
|
||||
description: val.description.clone(),
|
||||
parameters: FunctionParameters { properties },
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
|
||||
use crate::configuration::GuardType;
|
||||
use crate::{api::open_ai::ToolType, configuration::GuardType};
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_configuration() {
|
||||
|
|
@ -307,4 +348,76 @@ mod test {
|
|||
let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt);
|
||||
assert_eq!(*mode, super::GatewayMode::Prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let ref_config = fs::read_to_string(
|
||||
"../../docs/source/resources/includes/arch_config_full_reference.yaml",
|
||||
)
|
||||
.expect("reference config file not found");
|
||||
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
|
||||
let prompt_targets = &config.prompt_targets;
|
||||
let prompt_target = prompt_targets
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.find(|p| p.name == "reboot_network_device")
|
||||
.unwrap();
|
||||
let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
|
||||
assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
|
||||
assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
|
||||
assert_eq!(
|
||||
chat_completion_tool.function.description,
|
||||
"Reboot a specific network device"
|
||||
);
|
||||
assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.contains_key("device_id"),
|
||||
true
|
||||
);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("device_id")
|
||||
.unwrap()
|
||||
.parameter_type,
|
||||
crate::api::open_ai::ParameterType::String
|
||||
);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("device_id")
|
||||
.unwrap()
|
||||
.description,
|
||||
"Identifier of the network device to reboot.".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("device_id")
|
||||
.unwrap()
|
||||
.required,
|
||||
Some(true)
|
||||
);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("confirmation")
|
||||
.unwrap()
|
||||
.parameter_type,
|
||||
crate::api::open_ai::ParameterType::Bool
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,3 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5";
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
|
||||
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
|
|
@ -9,11 +5,6 @@ pub const TOOL_ROLE: &str = "tool";
|
|||
pub const ASSISTANT_ROLE: &str = "assistant";
|
||||
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";
|
||||
pub const ARCH_FC_INTERNAL_HOST: &str = "arch_fc";
|
||||
pub const HALLUCINATION_INTERNAL_HOST: &str = "hallucination";
|
||||
pub const EMBEDDINGS_INTERNAL_HOST: &str = "embeddings";
|
||||
pub const GUARD_INTERNAL_HOST: &str = "guard";
|
||||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const MESSAGES_KEY: &str = "messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
|
|
@ -25,7 +16,6 @@ pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
|||
pub const TRACE_PARENT_HEADER: &str = "traceparent";
|
||||
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
|
||||
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
|
||||
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
|
||||
pub const ARCH_MODEL_PREFIX: &str = "Arch";
|
||||
pub const HALLUCINATION_TEMPLATE: &str =
|
||||
"It seems I'm missing some information. Could you provide the following details ";
|
||||
|
|
|
|||
|
|
@ -1,59 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use crate::embeddings;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingRequest {
|
||||
#[serde(rename = "input")]
|
||||
pub input: Box<embeddings::CreateEmbeddingRequestInput>,
|
||||
/// ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.
|
||||
#[serde(rename = "model")]
|
||||
pub model: String,
|
||||
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
|
||||
#[serde(rename = "encoding_format", skip_serializing_if = "Option::is_none")]
|
||||
pub encoding_format: Option<EncodingFormat>,
|
||||
/// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.
|
||||
#[serde(rename = "dimensions", skip_serializing_if = "Option::is_none")]
|
||||
pub dimensions: Option<i32>,
|
||||
/// A unique identifier representing your end-user, which can help to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).
|
||||
#[serde(rename = "user", skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
impl CreateEmbeddingRequest {
|
||||
pub fn new(
|
||||
input: embeddings::CreateEmbeddingRequestInput,
|
||||
model: String,
|
||||
) -> CreateEmbeddingRequest {
|
||||
CreateEmbeddingRequest {
|
||||
input: Box::new(input),
|
||||
model,
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub enum EncodingFormat {
|
||||
#[serde(rename = "float")]
|
||||
Float,
|
||||
#[serde(rename = "base64")]
|
||||
Base64,
|
||||
}
|
||||
|
||||
impl Default for EncodingFormat {
|
||||
fn default() -> EncodingFormat {
|
||||
Self::Float
|
||||
}
|
||||
}
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// CreateEmbeddingRequestInput : Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
|
||||
/// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CreateEmbeddingRequestInput {
|
||||
/// The string that will be turned into an embedding.
|
||||
String(String),
|
||||
/// The array of integers that will be turned into an embedding.
|
||||
Array(Vec<i32>),
|
||||
}
|
||||
|
||||
impl Default for CreateEmbeddingRequestInput {
|
||||
fn default() -> Self {
|
||||
Self::String(Default::default())
|
||||
}
|
||||
}
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use crate::embeddings;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingResponse {
|
||||
/// The list of embeddings generated by the model.
|
||||
#[serde(rename = "data")]
|
||||
pub data: Vec<embeddings::Embedding>,
|
||||
/// The name of the model used to generate the embedding.
|
||||
#[serde(rename = "model")]
|
||||
pub model: String,
|
||||
/// The object type, which is always \"list\".
|
||||
#[serde(rename = "object")]
|
||||
pub object: Object,
|
||||
#[serde(rename = "usage")]
|
||||
pub usage: Box<embeddings::CreateEmbeddingResponseUsage>,
|
||||
}
|
||||
|
||||
impl CreateEmbeddingResponse {
|
||||
pub fn new(
|
||||
data: Vec<embeddings::Embedding>,
|
||||
model: String,
|
||||
object: Object,
|
||||
usage: embeddings::CreateEmbeddingResponseUsage,
|
||||
) -> CreateEmbeddingResponse {
|
||||
CreateEmbeddingResponse {
|
||||
data,
|
||||
model,
|
||||
object,
|
||||
usage: Box::new(usage),
|
||||
}
|
||||
}
|
||||
}
|
||||
/// The object type, which is always \"list\".
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub enum Object {
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
}
|
||||
|
||||
impl Default for Object {
|
||||
fn default() -> Object {
|
||||
Self::List
|
||||
}
|
||||
}
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// CreateEmbeddingResponseUsage : The usage information for the request.
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingResponseUsage {
|
||||
/// The number of tokens used by the prompt.
|
||||
#[serde(rename = "prompt_tokens")]
|
||||
pub prompt_tokens: i32,
|
||||
/// The total number of tokens used by the request.
|
||||
#[serde(rename = "total_tokens")]
|
||||
pub total_tokens: i32,
|
||||
}
|
||||
|
||||
impl CreateEmbeddingResponseUsage {
|
||||
/// The usage information for the request.
|
||||
pub fn new(prompt_tokens: i32, total_tokens: i32) -> CreateEmbeddingResponseUsage {
|
||||
CreateEmbeddingResponseUsage {
|
||||
prompt_tokens,
|
||||
total_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Embedding : Represents an embedding vector returned by embedding endpoint.
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Embedding {
|
||||
/// The index of the embedding in the list of embeddings.
|
||||
#[serde(rename = "index")]
|
||||
pub index: i32,
|
||||
/// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).
|
||||
#[serde(rename = "embedding")]
|
||||
pub embedding: Vec<f64>,
|
||||
/// The object type, which is always \"embedding\"
|
||||
#[serde(rename = "object")]
|
||||
pub object: Object,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
/// Represents an embedding vector returned by embedding endpoint.
|
||||
pub fn new(index: i32, embedding: Vec<f64>, object: Object) -> Embedding {
|
||||
Embedding {
|
||||
index,
|
||||
embedding,
|
||||
object,
|
||||
}
|
||||
}
|
||||
}
|
||||
/// The object type, which is always \"embedding\"
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub enum Object {
|
||||
#[serde(rename = "embedding")]
|
||||
Embedding,
|
||||
}
|
||||
|
||||
impl Default for Object {
|
||||
fn default() -> Object {
|
||||
Self::Embedding
|
||||
}
|
||||
}
|
||||
|
|
@ -1,10 +0,0 @@
|
|||
pub mod create_embedding_request;
|
||||
pub use self::create_embedding_request::CreateEmbeddingRequest;
|
||||
pub mod create_embedding_request_input;
|
||||
pub use self::create_embedding_request_input::CreateEmbeddingRequestInput;
|
||||
pub mod create_embedding_response;
|
||||
pub use self::create_embedding_response::CreateEmbeddingResponse;
|
||||
pub mod create_embedding_response_usage;
|
||||
pub use self::create_embedding_response_usage::CreateEmbeddingResponseUsage;
|
||||
pub mod embedding;
|
||||
pub use self::embedding::Embedding;
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
pub mod api;
|
||||
pub mod configuration;
|
||||
pub mod consts;
|
||||
pub mod embeddings;
|
||||
pub mod errors;
|
||||
pub mod http;
|
||||
pub mod llm_providers;
|
||||
pub mod path;
|
||||
pub mod pii;
|
||||
pub mod ratelimit;
|
||||
pub mod routing;
|
||||
pub mod stats;
|
||||
pub mod tokenizer;
|
||||
pub mod tracing;
|
||||
pub mod path;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
|
||||
pub fn replace_params_in_path(
|
||||
path: &str,
|
||||
params: &HashMap<String, String>,
|
||||
) -> Result<String, String> {
|
||||
let mut result = String::new();
|
||||
let mut in_param = false;
|
||||
let mut current_param = String::new();
|
||||
|
|
@ -17,12 +20,10 @@ pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> R
|
|||
return Err(format!("Missing value for parameter `{}`", param_name));
|
||||
}
|
||||
current_param.clear();
|
||||
} else if in_param {
|
||||
current_param.push(c);
|
||||
} else {
|
||||
if in_param {
|
||||
current_param.push(c);
|
||||
} else {
|
||||
result.push(c);
|
||||
}
|
||||
result.push(c);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use common::errors::ServerError;
|
||||
use common::stats::IncrementingMetric;
|
||||
use http::StatusCode;
|
||||
use log::{debug, warn};
|
||||
use proxy_wasm::traits::Context;
|
||||
|
||||
use crate::stream_context::{ResponseHandlerType, StreamContext};
|
||||
|
|
@ -19,76 +23,34 @@ impl Context for StreamContext {
|
|||
.expect("invalid token_id");
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
/*
|
||||
state transition
|
||||
let body = self
|
||||
.get_http_call_response_body(0, body_size)
|
||||
.unwrap_or(vec![]);
|
||||
|
||||
graph LR
|
||||
|
||||
on_http_request_body --> prompt received
|
||||
prompt received --> get embeddings & arch guard
|
||||
arch guard --> get embeddings
|
||||
get embeddings --> zeroshot intent
|
||||
|
||||
┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
│ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │
|
||||
│ │ │ │ │ │ │ │
|
||||
└──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘
|
||||
│ ▲
|
||||
│ │
|
||||
│ │
|
||||
│ ┌────────┴───────┐
|
||||
│ │ │
|
||||
└───────────►│ arch guard │
|
||||
│ │
|
||||
└────────────────┘
|
||||
|
||||
|
||||
continue from zeroshot intent
|
||||
|
||||
graph LR
|
||||
|
||||
zeroshot intent --> arch_fc
|
||||
zeroshot intent --> default prompt target
|
||||
arch_fc --> developer api call & hallucination check
|
||||
hallucination check --> parameter gathering & developer api call
|
||||
developer api call --> resume request to llm
|
||||
|
||||
|
||||
┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
│ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │
|
||||
│ │ │ │ │ │ │ │
|
||||
└────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘
|
||||
│ │ ▲
|
||||
│ └─────────────┐ │
|
||||
│ │ │
|
||||
│ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
└───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │
|
||||
│ │ │ │ │ │
|
||||
└───────────────────────┘ └─────────────────────┘ └───────────────────────┘
|
||||
|
||||
|
||||
using https://mermaid-ascii.art/
|
||||
*/
|
||||
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
#[cfg_attr(any(), rustfmt::skip)]
|
||||
match callout_context.response_handler_type {
|
||||
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
||||
ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context),
|
||||
ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context),
|
||||
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
|
||||
ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context),
|
||||
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
|
||||
ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
|
||||
}
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(String::from("No response body in inline HTTP request")),
|
||||
None,
|
||||
let http_status = self
|
||||
.get_http_call_response_header(":status")
|
||||
.unwrap_or(StatusCode::OK.as_str().to_string());
|
||||
debug!("http call response code: {}", http_status);
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
let server_error = ServerError::Upstream {
|
||||
host: callout_context.upstream_cluster.unwrap(),
|
||||
path: callout_context.upstream_cluster_path.unwrap(),
|
||||
status: http_status.clone(),
|
||||
body: String::from_utf8(body).unwrap(),
|
||||
};
|
||||
warn!("filter received non 2xx code: {:?}", server_error);
|
||||
return self.send_server_error(
|
||||
server_error,
|
||||
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
|
||||
);
|
||||
}
|
||||
|
||||
debug!("http call response handler type: {:?}", callout_context.response_handler_type);
|
||||
#[cfg_attr(any(), rustfmt::skip)]
|
||||
match callout_context.response_handler_type {
|
||||
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
|
||||
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
|
||||
ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +0,0 @@
|
|||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub enum EmbeddingType {
|
||||
Name,
|
||||
Description,
|
||||
}
|
||||
|
|
@ -1,35 +1,17 @@
|
|||
use crate::embeddings::EmbeddingType;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::stream_context::StreamContext;
|
||||
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
|
||||
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use common::consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST};
|
||||
use common::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use common::http::CallArgs;
|
||||
use common::http::Client;
|
||||
use common::stats::Gauge;
|
||||
use common::stats::IncrementingMetric;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, trace, warn};
|
||||
use log::debug;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
|
||||
pub type EmbeddingsStore = HashMap<String, EmbeddingTypeMap>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterCallContext {
|
||||
pub prompt_target_name: String,
|
||||
pub embedding_type: EmbeddingType,
|
||||
}
|
||||
pub struct FilterCallContext {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterContext {
|
||||
|
|
@ -40,9 +22,6 @@ pub struct FilterContext {
|
|||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
temp_embeddings_store: EmbeddingsStore,
|
||||
active_embedding_calls_count: u32,
|
||||
tracing: Rc<Option<Tracing>>,
|
||||
}
|
||||
|
||||
|
|
@ -55,131 +34,9 @@ impl FilterContext {
|
|||
prompt_targets: Rc::new(HashMap::new()),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
embeddings_store: Some(Rc::new(HashMap::new())),
|
||||
temp_embeddings_store: HashMap::new(),
|
||||
active_embedding_calls_count: 0,
|
||||
tracing: Rc::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
let prompt_target_description: Vec<(String, String)> = self
|
||||
.prompt_targets
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.description.clone()))
|
||||
.collect();
|
||||
|
||||
prompt_target_description
|
||||
.iter()
|
||||
.for_each(|(name, description)| {
|
||||
self.schedule_embeddings_call(name, description, EmbeddingType::Description);
|
||||
});
|
||||
}
|
||||
|
||||
fn schedule_embeddings_call(
|
||||
&mut self,
|
||||
prompt_target_name: &str,
|
||||
input: &str,
|
||||
embedding_type: EmbeddingType,
|
||||
) {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
let json_data = serde_json::to_string(&embeddings_input).unwrap();
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/embeddings",
|
||||
vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", EMBEDDINGS_INTERNAL_HOST),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(60),
|
||||
);
|
||||
|
||||
let call_context = crate::filter_context::FilterCallContext {
|
||||
prompt_target_name: String::from(prompt_target_name),
|
||||
embedding_type,
|
||||
};
|
||||
|
||||
self.active_embedding_calls_count += 1;
|
||||
if let Err(error) = self.http_call(call_args, call_context) {
|
||||
panic!("{error}")
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_response_handler(
|
||||
&mut self,
|
||||
embedding_type: EmbeddingType,
|
||||
prompt_target_name: String,
|
||||
body: Vec<u8>,
|
||||
) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(&prompt_target_name)
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Received embeddings response for unknown prompt target name={}",
|
||||
prompt_target_name
|
||||
)
|
||||
});
|
||||
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error deserializing embedding response. body: {:?}: {:?}",
|
||||
String::from_utf8(body).unwrap(),
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let embeddings = embedding_response.data.remove(0).embedding;
|
||||
debug!(
|
||||
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
|
||||
prompt_target.name,
|
||||
prompt_target.description,
|
||||
embedding_type
|
||||
);
|
||||
|
||||
let entry = self.temp_embeddings_store.entry(prompt_target_name);
|
||||
match entry {
|
||||
Entry::Occupied(_) => {
|
||||
entry.and_modify(|e| {
|
||||
if let Entry::Vacant(e) = e.entry(embedding_type) {
|
||||
e.insert(embeddings);
|
||||
} else {
|
||||
panic!(
|
||||
"Duplicate {:?} for prompt target with name=\"{}\"",
|
||||
&embedding_type, prompt_target.name
|
||||
)
|
||||
}
|
||||
});
|
||||
}
|
||||
Entry::Vacant(_) => {
|
||||
entry.or_insert(HashMap::from([(embedding_type, embeddings)]));
|
||||
}
|
||||
}
|
||||
|
||||
if self.prompt_targets.len() == self.temp_embeddings_store.len() {
|
||||
self.embeddings_store =
|
||||
Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for FilterContext {
|
||||
|
|
@ -194,46 +51,7 @@ impl Client for FilterContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
trace!(
|
||||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
token_id
|
||||
);
|
||||
let callout_data = self
|
||||
.callouts
|
||||
.borrow_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
|
||||
self.active_embedding_calls_count -= 1;
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
let body_bytes = self.get_http_call_response_body(0, body_size).unwrap();
|
||||
|
||||
if let Some(status_code) = self.get_http_call_response_header(":status") {
|
||||
if status_code == StatusCode::OK.as_str() {
|
||||
self.embedding_response_handler(
|
||||
callout_data.embedding_type,
|
||||
callout_data.prompt_target_name,
|
||||
body_bytes,
|
||||
);
|
||||
} else {
|
||||
warn!(
|
||||
"Received non-200 status code: {} for callout with token_id: {}: body_str: {}",
|
||||
status_code,
|
||||
token_id,
|
||||
String::from_utf8(body_bytes).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Context for FilterContext {}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
|
|
@ -271,15 +89,12 @@ impl RootContext for FilterContext {
|
|||
context_id
|
||||
);
|
||||
|
||||
let embedding_store = self.embeddings_store.as_ref().map(Rc::clone);
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.system_prompt),
|
||||
Rc::clone(&self.prompt_targets),
|
||||
Rc::clone(&self.prompt_guards),
|
||||
Rc::clone(&self.overrides),
|
||||
embedding_store,
|
||||
Rc::clone(&self.tracing),
|
||||
)))
|
||||
}
|
||||
|
|
@ -289,25 +104,6 @@ impl RootContext for FilterContext {
|
|||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
if self.embeddings_store.is_some()
|
||||
&& self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len()
|
||||
{
|
||||
info!("embeddings store initialized");
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
} else {
|
||||
if self.active_embedding_calls_count == 0 {
|
||||
info!("retrieving embeddings from embedding server");
|
||||
self.process_prompt_targets();
|
||||
} else {
|
||||
info!("waiting for embeddings store to be initialized");
|
||||
}
|
||||
|
||||
self.set_tick_period(Duration::from_secs(5));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
|
||||
use common::{
|
||||
api::{
|
||||
open_ai::{self, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest},
|
||||
prompt_guard::{PromptGuardRequest, PromptGuardTask},
|
||||
api::open_ai::{
|
||||
self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
},
|
||||
consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST,
|
||||
HEALTHZ_PATH, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
|
||||
MODEL_SERVER_NAME, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
|
||||
},
|
||||
errors::ServerError,
|
||||
http::{CallArgs, Client},
|
||||
|
|
@ -35,11 +34,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let request_path = self.get_http_request_header(":path").unwrap_or_default();
|
||||
if request_path == HEALTHZ_PATH {
|
||||
if self.is_embedding_store_initialized() {
|
||||
self.send_http_response(200, vec![], None);
|
||||
} else {
|
||||
self.send_http_response(503, vec![], None);
|
||||
}
|
||||
self.send_http_response(200, vec![], None);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
|
|
@ -138,43 +133,25 @@ impl HttpContext for StreamContext {
|
|||
|
||||
self.user_prompt = Some(last_user_prompt.clone());
|
||||
|
||||
let user_message_str = self.user_prompt.as_ref().unwrap().content.clone();
|
||||
// convert prompt targets to ChatCompletionTool
|
||||
let tool_calls: Vec<ChatCompletionTool> = self
|
||||
.prompt_targets
|
||||
.iter()
|
||||
.map(|(_, pt)| pt.into())
|
||||
.collect();
|
||||
|
||||
let prompt_guard_jailbreak_task = self
|
||||
.prompt_guards
|
||||
.input_guards
|
||||
.contains_key(&common::configuration::GuardType::Jailbreak);
|
||||
let arch_fc_chat_completion_request = ChatCompletionsRequest {
|
||||
messages: deserialized_body.messages.clone(),
|
||||
metadata: deserialized_body.metadata.clone(),
|
||||
stream: deserialized_body.stream,
|
||||
model: "--".to_string(),
|
||||
stream_options: deserialized_body.stream_options.clone(),
|
||||
tools: Some(tool_calls),
|
||||
};
|
||||
|
||||
self.chat_completions_request = Some(deserialized_body);
|
||||
|
||||
if !prompt_guard_jailbreak_task {
|
||||
debug!("Missing input guard. Making inline call to retrieve embeddings");
|
||||
let callout_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: user_message_str.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
let get_prompt_guards_request = PromptGuardRequest {
|
||||
input: self
|
||||
.user_prompt
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
task: PromptGuardTask::Jailbreak,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
|
||||
let json_data = match serde_json::to_string(&arch_fc_chat_completion_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
self.send_server_error(ServerError::Serialization(error), None);
|
||||
|
|
@ -182,14 +159,14 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
debug!("archgw => archfc: {}", json_data);
|
||||
|
||||
let mut headers = vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
|
||||
(":method", "POST"),
|
||||
(":path", "/guard"),
|
||||
(":authority", GUARD_INTERNAL_HOST),
|
||||
(":path", "/function_calling"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
(":authority", MODEL_SERVER_NAME),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
|
|
@ -202,23 +179,25 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/guard",
|
||||
"/function_calling",
|
||||
headers,
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
let call_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
response_handler_type: ResponseHandlerType::ArchFC,
|
||||
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
|
||||
upstream_cluster_path: Some("/function_calling".to_string()),
|
||||
};
|
||||
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
debug!("http_call failed: {:?}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
|
||||
|
|
@ -337,9 +316,11 @@ impl HttpContext for StreamContext {
|
|||
let mut data = match serde_json::from_str(&body_utf8) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
warn!("could not deserialize response: {}", e);
|
||||
self.send_server_error(ServerError::Deserialization(e), None);
|
||||
return Action::Pause;
|
||||
warn!(
|
||||
"could not deserialize response, sending data as it is: {}",
|
||||
e
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ use proxy_wasm::traits::*;
|
|||
use proxy_wasm::types::*;
|
||||
|
||||
mod context;
|
||||
mod embeddings;
|
||||
mod filter_context;
|
||||
mod http_context;
|
||||
mod metrics;
|
||||
|
|
|
|||
|
|
@ -1,36 +1,20 @@
|
|||
use crate::embeddings::EmbeddingType;
|
||||
use crate::filter_context::EmbeddingsStore;
|
||||
use crate::metrics::Metrics;
|
||||
use acap::cos;
|
||||
use common::api::hallucination::{
|
||||
extract_messages_for_hallucination, HallucinationClassificationRequest,
|
||||
HallucinationClassificationResponse,
|
||||
};
|
||||
use common::api::open_ai::{
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
|
||||
ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
|
||||
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, Message, ModelServerResponse, ToolCall,
|
||||
};
|
||||
use common::api::prompt_guard::PromptGuardResponse;
|
||||
use common::api::zero_shot::{ZeroShotClassificationRequest, ZeroShotClassificationResponse};
|
||||
use common::configuration::{Overrides, PromptGuards, PromptTarget, Tracing};
|
||||
use common::configuration::{Overrides, PromptTarget, Tracing};
|
||||
use common::consts::{
|
||||
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
|
||||
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
|
||||
ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST,
|
||||
HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE,
|
||||
TRACE_PARENT_HEADER, USER_ROLE, ZEROSHOT_INTERNAL_HOST,
|
||||
};
|
||||
use common::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE,
|
||||
TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
|
||||
};
|
||||
use common::errors::ServerError;
|
||||
use common::http::{CallArgs, Client};
|
||||
use common::stats::Gauge;
|
||||
use derivative::Derivative;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, trace, warn};
|
||||
use log::{debug, warn};
|
||||
use proxy_wasm::traits::*;
|
||||
use serde_yaml::Value;
|
||||
use std::cell::RefCell;
|
||||
|
|
@ -41,12 +25,8 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
|||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ResponseHandlerType {
|
||||
Embeddings,
|
||||
ArchFC,
|
||||
FunctionCall,
|
||||
ZeroShotIntent,
|
||||
Hallucination,
|
||||
ArchGuard,
|
||||
DefaultTarget,
|
||||
}
|
||||
|
||||
|
|
@ -66,8 +46,7 @@ pub struct StreamCallContext {
|
|||
pub struct StreamContext {
|
||||
system_prompt: Rc<Option<String>>,
|
||||
pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
pub embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
_overrides: Rc<Option<Overrides>>,
|
||||
pub metrics: Rc<Metrics>,
|
||||
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
pub context_id: u32,
|
||||
|
|
@ -79,12 +58,11 @@ pub struct StreamContext {
|
|||
pub streaming_response: bool,
|
||||
pub is_chat_completions_request: bool,
|
||||
pub chat_completions_request: Option<ChatCompletionsRequest>,
|
||||
pub prompt_guards: Rc<PromptGuards>,
|
||||
pub request_id: Option<String>,
|
||||
pub start_upstream_llm_request_time: u128,
|
||||
pub time_to_first_token: Option<u128>,
|
||||
pub traceparent: Option<String>,
|
||||
pub tracing: Rc<Option<Tracing>>,
|
||||
pub _tracing: Rc<Option<Tracing>>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -94,9 +72,7 @@ impl StreamContext {
|
|||
metrics: Rc<Metrics>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
tracing: Rc<Option<Tracing>>,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
|
|
@ -104,7 +80,6 @@ impl StreamContext {
|
|||
metrics,
|
||||
system_prompt,
|
||||
prompt_targets,
|
||||
embeddings_store,
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
chat_completions_request: None,
|
||||
tool_calls: None,
|
||||
|
|
@ -114,32 +89,15 @@ impl StreamContext {
|
|||
streaming_response: false,
|
||||
user_prompt: None,
|
||||
is_chat_completions_request: false,
|
||||
prompt_guards,
|
||||
overrides,
|
||||
_overrides: overrides,
|
||||
request_id: None,
|
||||
traceparent: None,
|
||||
tracing,
|
||||
_tracing: tracing,
|
||||
start_upstream_llm_request_time: 0,
|
||||
time_to_first_token: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn embeddings_store(&self) -> &EmbeddingsStore {
|
||||
self.embeddings_store.as_ref().unwrap()
|
||||
}
|
||||
|
||||
pub fn is_embedding_store_initialized(&self) -> bool {
|
||||
if self.embeddings_store.as_ref().is_none() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
|
||||
self.send_http_response(
|
||||
override_status_code
|
||||
|
|
@ -151,190 +109,8 @@ impl StreamContext {
|
|||
);
|
||||
}
|
||||
|
||||
pub fn get_embeddings(&mut self, callout_context: StreamCallContext) {
|
||||
let user_message = callout_context.user_message.unwrap();
|
||||
let get_embeddings_input = CreateEmbeddingRequest {
|
||||
// Need to clone into input because user_message is used below.
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let embeddings_request_str: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
warn!("error serializing get embeddings request: {}", error);
|
||||
return self.send_server_error(ServerError::Deserialization(error), None);
|
||||
}
|
||||
};
|
||||
|
||||
let mut headers = vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", EMBEDDINGS_INTERNAL_HOST),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
if self.trace_arch_internal() && self.traceparent.is_some() {
|
||||
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/embeddings",
|
||||
headers,
|
||||
Some(embeddings_request_str.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
let call_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::Embeddings,
|
||||
user_message: Some(user_message),
|
||||
prompt_target_name: None,
|
||||
request_body: callout_context.request_body,
|
||||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
|
||||
debug!(
|
||||
"archgw => get embeddings request: {}",
|
||||
embeddings_request_str
|
||||
);
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
warn!("error dispatching get embeddings request: {}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(e) => {
|
||||
warn!("error deserializing embedding response: {}", e);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
let prompt_embeddings_vector = &embedding_response.data[0].embedding;
|
||||
|
||||
trace!(
|
||||
"embedding model: {}, vector length: {:?}",
|
||||
embedding_response.model,
|
||||
prompt_embeddings_vector.len()
|
||||
);
|
||||
|
||||
let prompt_target_names = self
|
||||
.prompt_targets
|
||||
.iter()
|
||||
// exclude default target
|
||||
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect();
|
||||
|
||||
let similarity_scores: Vec<(String, f64)> = self
|
||||
.prompt_targets
|
||||
.iter()
|
||||
// exclude default prompt target
|
||||
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||
.map(|(prompt_name, _)| {
|
||||
let pte = match self.embeddings_store().get(prompt_name) {
|
||||
Some(embeddings) => embeddings,
|
||||
None => {
|
||||
warn!(
|
||||
"embeddings not found for prompt target name: {}",
|
||||
prompt_name
|
||||
);
|
||||
return (prompt_name.clone(), 0.0);
|
||||
}
|
||||
};
|
||||
|
||||
let description_embeddings = match pte.get(&EmbeddingType::Description) {
|
||||
Some(embeddings) => embeddings,
|
||||
None => {
|
||||
warn!(
|
||||
"description embeddings not found for prompt target name: {}",
|
||||
prompt_name
|
||||
);
|
||||
return (prompt_name.clone(), 0.0);
|
||||
}
|
||||
};
|
||||
let similarity_score_description =
|
||||
cos::cosine_similarity(&prompt_embeddings_vector, &description_embeddings);
|
||||
(prompt_name.clone(), similarity_score_description)
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
"similarity scores based on description embeddings match: {:?}",
|
||||
similarity_scores
|
||||
);
|
||||
|
||||
callout_context.similarity_scores = Some(similarity_scores);
|
||||
|
||||
let zero_shot_classification_request = ZeroShotClassificationRequest {
|
||||
// Need to clone into input because user_message is used below.
|
||||
input: callout_context.user_message.as_ref().unwrap().clone(),
|
||||
model: String::from(DEFAULT_INTENT_MODEL),
|
||||
labels: prompt_target_names,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
debug!(
|
||||
"error serializing zero shot classification request: {}",
|
||||
error
|
||||
);
|
||||
return self.send_server_error(ServerError::Serialization(error), None);
|
||||
}
|
||||
};
|
||||
|
||||
let mut headers = vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, ZEROSHOT_INTERNAL_HOST),
|
||||
(":method", "POST"),
|
||||
(":path", "/zeroshot"),
|
||||
(":authority", ZEROSHOT_INTERNAL_HOST),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
if self.trace_arch_internal() && self.traceparent.is_some() {
|
||||
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/zeroshot",
|
||||
headers,
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching zero shot classification request: {}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
}
|
||||
|
||||
fn trace_arch_internal(&self) -> bool {
|
||||
match self.tracing.as_ref() {
|
||||
fn _trace_arch_internal(&self) -> bool {
|
||||
match self._tracing.as_ref() {
|
||||
Some(tracing) => match tracing.trace_arch_internal.as_ref() {
|
||||
Some(trace_arch_internal) => *trace_arch_internal,
|
||||
None => false,
|
||||
|
|
@ -343,359 +119,6 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn hallucination_classification_resp_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
callout_context: StreamCallContext,
|
||||
) {
|
||||
let body_str = String::from_utf8(body).expect("could not convert body to string");
|
||||
debug!("archgw <= hallucination response: {}", body_str);
|
||||
let hallucination_response: HallucinationClassificationResponse =
|
||||
match serde_json::from_str(body_str.as_str()) {
|
||||
Ok(hallucination_response) => hallucination_response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"error deserializing hallucination response: {}, body: {}",
|
||||
e,
|
||||
body_str.as_str()
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
let mut keys_with_low_score: Vec<String> = Vec::new();
|
||||
for (key, value) in &hallucination_response.params_scores {
|
||||
if *value < DEFAULT_HALLUCINATED_THRESHOLD {
|
||||
debug!(
|
||||
"hallucination detected: score for {} : {} is less than threshold {}",
|
||||
key, value, DEFAULT_HALLUCINATED_THRESHOLD
|
||||
);
|
||||
keys_with_low_score.push(key.clone().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if !keys_with_low_score.is_empty() {
|
||||
let response =
|
||||
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
|
||||
|
||||
let response_str = if self.streaming_response {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(response),
|
||||
None,
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
to_server_events(chunks)
|
||||
} else {
|
||||
let chat_completion_response = ChatCompletionsResponse::new(response);
|
||||
serde_json::to_string(&chat_completion_response).unwrap()
|
||||
};
|
||||
debug!("hallucination response: {:?}", response_str);
|
||||
// make sure on_http_response_body does not attach tool calls and tool response to the response
|
||||
self.tool_calls = None;
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![],
|
||||
Some(response_str.as_bytes()),
|
||||
);
|
||||
} else {
|
||||
// not a hallucination, resume the flow
|
||||
self.schedule_api_call_request(callout_context);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn zero_shot_intent_detection_resp_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: StreamCallContext,
|
||||
) {
|
||||
let zeroshot_intent_response: ZeroShotClassificationResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(zeroshot_response) => zeroshot_response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"error deserializing zero shot classification response: {}",
|
||||
e
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
trace!(
|
||||
"zeroshot intent response: {}",
|
||||
serde_json::to_string(&zeroshot_intent_response).unwrap()
|
||||
);
|
||||
|
||||
let desc_emb_similarity_map: HashMap<String, f64> = callout_context
|
||||
.similarity_scores
|
||||
.clone()
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let pred_class_desc_emb_similarity = desc_emb_similarity_map
|
||||
.get(&zeroshot_intent_response.predicted_class)
|
||||
.unwrap();
|
||||
|
||||
let prompt_target_similarity_score = zeroshot_intent_response.predicted_class_score * 0.7
|
||||
+ pred_class_desc_emb_similarity * 0.3;
|
||||
|
||||
debug!(
|
||||
"similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}, prompt: {}",
|
||||
prompt_target_similarity_score,
|
||||
zeroshot_intent_response.predicted_class_score,
|
||||
pred_class_desc_emb_similarity,
|
||||
callout_context.user_message.as_ref().unwrap()
|
||||
);
|
||||
|
||||
let prompt_target_name = zeroshot_intent_response.predicted_class.clone();
|
||||
|
||||
// Check to see who responded to user message. This will help us identify if control should be passed to Arch FC or not.
|
||||
// If the last message was from Arch FC, then Arch FC is handling the conversation (possibly for parameter collection).
|
||||
let mut arch_assistant = false;
|
||||
let messages = &callout_context.request_body.messages;
|
||||
if messages.len() >= 2 {
|
||||
let latest_assistant_message = &messages[messages.len() - 2];
|
||||
if let Some(model) = latest_assistant_message.model.as_ref() {
|
||||
if model.contains(ARCH_MODEL_PREFIX) {
|
||||
arch_assistant = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("no assistant message found, probably first interaction");
|
||||
}
|
||||
|
||||
// get prompt target similarity thresold from overrides
|
||||
let prompt_target_intent_matching_threshold = match self.overrides.as_ref() {
|
||||
Some(overrides) => match overrides.prompt_target_intent_matching_threshold {
|
||||
Some(threshold) => threshold,
|
||||
None => DEFAULT_PROMPT_TARGET_THRESHOLD,
|
||||
},
|
||||
None => DEFAULT_PROMPT_TARGET_THRESHOLD,
|
||||
};
|
||||
|
||||
// check to ensure that the prompt target similarity score is above the threshold
|
||||
if prompt_target_similarity_score < prompt_target_intent_matching_threshold
|
||||
|| arch_assistant
|
||||
{
|
||||
debug!("intent score is low or arch assistant is handling the conversation");
|
||||
// if arch fc responded to the user message, then we don't need to check the similarity score
|
||||
// it may be that arch fc is handling the conversation for parameter collection
|
||||
if arch_assistant {
|
||||
info!("arch fc is engaged in parameter collection");
|
||||
} else if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
if self.trace_arch_internal() && self.traceparent.is_some() {
|
||||
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
// if no default prompt target is found and similarity score is low send response to upstream llm
|
||||
// removing tool calls and tool response
|
||||
|
||||
let messages = self.filter_out_arch_messages(&callout_context);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"archgw (low similarity score) => llm request: {}",
|
||||
llm_request_str
|
||||
);
|
||||
|
||||
self.set_http_request_body(
|
||||
0,
|
||||
self.request_body_size,
|
||||
&llm_request_str.into_bytes(),
|
||||
);
|
||||
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(&prompt_target_name)
|
||||
.expect("prompt target not found")
|
||||
.clone();
|
||||
|
||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||
for pt in self.prompt_targets.values() {
|
||||
if pt.default.unwrap_or_default() {
|
||||
continue;
|
||||
}
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: ParameterType::from(
|
||||
entity.parameter_type.clone().unwrap_or("str".to_string()),
|
||||
),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
default: entity.default.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = FunctionParameters { properties };
|
||||
|
||||
chat_completion_tools.push({
|
||||
ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
},
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// archfc handler needs state so it can expand tool calls
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::to_string(&self.arch_state).unwrap(),
|
||||
);
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: self
|
||||
.chat_completions_request
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
Ok(msg_body) => msg_body,
|
||||
Err(e) => {
|
||||
warn!("error serializing arch_fc request body: {}", e);
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, ARCH_FC_INTERNAL_HOST),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", ARCH_FC_INTERNAL_HOST),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
if self.trace_arch_internal() && self.traceparent.is_some() {
|
||||
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/v1/chat/completions",
|
||||
headers,
|
||||
Some(msg_body.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::ArchFC;
|
||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
|
||||
debug!("archgw => archfc request: {}", msg_body);
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
debug!("error dispatching arch_fc request: {}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn arch_fc_response_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
|
|
@ -704,14 +127,87 @@ impl StreamContext {
|
|||
let body_str = String::from_utf8(body).unwrap();
|
||||
debug!("archgw <= archfc response: {}", body_str);
|
||||
|
||||
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||
let model_server_response: ModelServerResponse = match serde_json::from_str(&body_str) {
|
||||
Ok(arch_fc_response) => arch_fc_response,
|
||||
Err(e) => {
|
||||
warn!("error deserializing archfc response: {}", e);
|
||||
warn!(
|
||||
"error deserializing archfc response: {}, body: {}",
|
||||
e, body_str
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
let arch_fc_response = match model_server_response {
|
||||
ModelServerResponse::ChatCompletionsResponse(response) => response,
|
||||
ModelServerResponse::ModelServerErrorResponse(response) => {
|
||||
debug!("archgw <= archfc error response: {}", response.result);
|
||||
if response.result == "No intent matched" {
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
// if self.trace_arch_internal() && self.traceparent.is_some() {
|
||||
// headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
// }
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name =
|
||||
Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
return self.send_server_error(
|
||||
ServerError::LogicError(response.result),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
arch_fc_response.choices[0]
|
||||
.message
|
||||
.tool_calls
|
||||
|
|
@ -767,114 +263,7 @@ impl StreamContext {
|
|||
);
|
||||
}
|
||||
|
||||
// TODO CO: pass nli check
|
||||
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(&tools_call_name)
|
||||
.expect("prompt target not found for tool call")
|
||||
.clone();
|
||||
|
||||
debug!(
|
||||
"prompt_target_name: {}, tool_name(s): {:?}",
|
||||
prompt_target.name,
|
||||
self.tool_calls
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|tc| tc.function.name.clone())
|
||||
.collect::<Vec<String>>(),
|
||||
);
|
||||
|
||||
// If hallucination, pass chat template to check parameters
|
||||
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
|
||||
|
||||
let mut tool_params = self.tool_calls.as_ref().unwrap()[0]
|
||||
.function
|
||||
.arguments
|
||||
.clone();
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
debug!(
|
||||
"tool_params (without messages history): {}",
|
||||
tool_params_json_str
|
||||
);
|
||||
tool_params.insert(
|
||||
String::from(MESSAGES_KEY),
|
||||
serde_yaml::to_value(&callout_context.request_body.messages).unwrap(),
|
||||
);
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
|
||||
use serde_json::Value;
|
||||
let v: Value = serde_json::from_str(&tool_params_json_str).unwrap();
|
||||
let tool_params_dict: HashMap<String, String> = match v.as_object() {
|
||||
Some(obj) => obj
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
// Convert each value to a string, regardless of its type
|
||||
(key.clone(), value.to_string())
|
||||
})
|
||||
.collect(),
|
||||
None => HashMap::new(), // Return an empty HashMap if v is not an object
|
||||
};
|
||||
|
||||
let all_user_messages =
|
||||
extract_messages_for_hallucination(&callout_context.request_body.messages);
|
||||
let user_messages_str = all_user_messages.join(", ");
|
||||
debug!("user messages: {}", user_messages_str);
|
||||
|
||||
let hallucination_classification_request = HallucinationClassificationRequest {
|
||||
prompt: user_messages_str,
|
||||
model: String::from(DEFAULT_INTENT_MODEL),
|
||||
parameters: tool_params_dict,
|
||||
};
|
||||
|
||||
let hallucination_request_str: String =
|
||||
match serde_json::to_string(&hallucination_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
debug!(
|
||||
"error serializing hallucination classification request: {}",
|
||||
error
|
||||
);
|
||||
return self.send_server_error(ServerError::Serialization(error), None);
|
||||
}
|
||||
};
|
||||
|
||||
let mut headers = vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, HALLUCINATION_INTERNAL_HOST),
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", HALLUCINATION_INTERNAL_HOST),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
if self.trace_arch_internal() && self.traceparent.is_some() {
|
||||
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/hallucination",
|
||||
headers,
|
||||
Some(hallucination_request_str.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::Hallucination;
|
||||
|
||||
debug!(
|
||||
"archgw => hallucination request: {}",
|
||||
hallucination_request_str
|
||||
);
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
self.schedule_api_call_request(callout_context);
|
||||
}
|
||||
|
||||
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
|
||||
|
|
@ -969,8 +358,9 @@ impl StreamContext {
|
|||
pub fn api_call_response_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
let http_status = self
|
||||
.get_http_call_response_header(":status")
|
||||
.expect("http status code not found");
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
.unwrap_or(StatusCode::OK.as_str().to_string());
|
||||
debug!("api_call_response_handler: http_status: {}", http_status);
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
warn!(
|
||||
"api server responded with non 2xx status code: {}",
|
||||
http_status
|
||||
|
|
@ -1093,56 +483,24 @@ impl StreamContext {
|
|||
messages
|
||||
}
|
||||
|
||||
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||
debug!(
|
||||
"archgw <= archguard response: {:?}",
|
||||
serde_json::to_string(&prompt_guard_resp)
|
||||
);
|
||||
|
||||
if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() {
|
||||
//TODO: handle other scenarios like forward to error target
|
||||
let msg = self
|
||||
.prompt_guards
|
||||
.jailbreak_on_exception_message()
|
||||
.unwrap_or("refrain from discussing jailbreaking.");
|
||||
info!("jailbreak detected: {}", msg);
|
||||
|
||||
let response_str = if self.streaming_response {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(msg.to_string()),
|
||||
None,
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
to_server_events(chunks)
|
||||
} else {
|
||||
let chat_completion_response = ChatCompletionsResponse::new(msg.to_string());
|
||||
serde_json::to_string(&chat_completion_response).unwrap()
|
||||
};
|
||||
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![],
|
||||
Some(response_str.as_bytes()),
|
||||
);
|
||||
|
||||
return self.send_server_error(
|
||||
ServerError::Jailbreak(String::from(msg)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
pub fn generate_toll_call_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
self.get_embeddings(callout_context);
|
||||
pub fn generate_api_response_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: TOOL_ROLE.to_string(),
|
||||
content: self.tool_call_response.clone(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_target_handler(&self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
||||
|
|
@ -1264,26 +622,6 @@ impl StreamContext {
|
|||
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
||||
pub fn generate_toll_call_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_api_response_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: TOOL_ROLE.to_string(),
|
||||
content: self.tool_call_response.clone(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for StreamContext {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,7 @@
|
|||
use common::api::hallucination::HallucinationClassificationResponse;
|
||||
use common::api::open_ai::{
|
||||
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
|
||||
};
|
||||
use common::api::prompt_guard::PromptGuardResponse;
|
||||
use common::api::zero_shot::ZeroShotClassificationResponse;
|
||||
use common::configuration::Configuration;
|
||||
use common::embeddings::{
|
||||
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
|
||||
Embedding,
|
||||
};
|
||||
use http::StatusCode;
|
||||
use proxy_wasm_test_framework::tester::{self, Tester};
|
||||
use proxy_wasm_test_framework::types::{
|
||||
|
|
@ -83,13 +76,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "guard"),
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/guard"),
|
||||
(":authority", "guard"),
|
||||
(":path", "/function_calling"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
(":authority", "model_server"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
|
|
@ -97,139 +88,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
)
|
||||
.returning(Some(1))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
|
||||
let prompt_guard_response = PromptGuardResponse {
|
||||
toxic_prob: None,
|
||||
toxic_verdict: None,
|
||||
jailbreak_prob: None,
|
||||
jailbreak_verdict: None,
|
||||
};
|
||||
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
1,
|
||||
0,
|
||||
prompt_guard_response_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&prompt_guard_response_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "embeddings"),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddings"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let embedding_response = CreateEmbeddingResponse {
|
||||
data: vec![Embedding {
|
||||
index: 0,
|
||||
embedding: vec![],
|
||||
object: embedding::Object::default(),
|
||||
}],
|
||||
model: String::from("test"),
|
||||
object: create_embedding_response::Object::default(),
|
||||
usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)),
|
||||
};
|
||||
let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
2,
|
||||
0,
|
||||
embeddings_response_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embeddings_response_buffer))
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "zeroshot"),
|
||||
(":method", "POST"),
|
||||
(":path", "/zeroshot"),
|
||||
(":authority", "zeroshot"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(3))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let zero_shot_response = ZeroShotClassificationResponse {
|
||||
predicted_class: "weather_forecast".to_string(),
|
||||
predicted_class_score: 0.1,
|
||||
scores: HashMap::new(),
|
||||
model: "test-model".to_string(),
|
||||
};
|
||||
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
3,
|
||||
0,
|
||||
zeroshot_intent_detection_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&zeroshot_intent_detection_buffer))
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
("x-arch-upstream", "arch_fc"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", "arch_fc"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "120000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
||||
|
|
@ -248,69 +111,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
|||
.execute_and_expect(ReturnType::Bool(true))
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_tick(filter_context)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "embeddings"),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddings"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(101))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_set_tick_period_millis(Some(5000))
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let embedding_response = CreateEmbeddingResponse {
|
||||
data: vec![Embedding {
|
||||
embedding: vec![],
|
||||
index: 0,
|
||||
object: embedding::Object::default(),
|
||||
}],
|
||||
model: String::from("test"),
|
||||
object: create_embedding_response::Object::default(),
|
||||
usage: Box::new(CreateEmbeddingResponseUsage {
|
||||
prompt_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}),
|
||||
};
|
||||
let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
filter_context,
|
||||
101,
|
||||
0,
|
||||
embedding_response_str.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_log(
|
||||
Some(LogLevel::Trace),
|
||||
Some(
|
||||
format!(
|
||||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
101
|
||||
)
|
||||
.as_str(),
|
||||
),
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embedding_response_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
filter_context
|
||||
}
|
||||
|
||||
|
|
@ -435,6 +235,7 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
|
|||
.returning(Some(chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(Some("arch_internal"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
|
|
@ -538,8 +339,8 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
|
|
@ -564,7 +365,7 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
|
|
@ -572,47 +373,7 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "hallucination"),
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "hallucination"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// hallucination should return that parameters were not halliucinated
|
||||
// prompt: str
|
||||
// parameters: dict
|
||||
// model: str
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
|
|
@ -628,14 +389,14 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(6))
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
|
|
@ -643,6 +404,10 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
|
||||
.returning(Some("200"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
|
@ -652,8 +417,8 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("hello from fake llm gateway".to_string()),
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ start_demo() {
|
|||
echo "Starting Arch with arch_config.yaml..."
|
||||
archgw up arch_config.yaml
|
||||
|
||||
# Step 4: Start Network Agent
|
||||
# Step 4: Start developer services
|
||||
echo "Starting Network Agent using Docker Compose..."
|
||||
docker compose up -d # Run in detached mode
|
||||
}
|
||||
|
|
|
|||
13
demos/currency_exchange/test_data.yaml
Normal file
13
demos/currency_exchange/test_data.yaml
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
test_cases:
|
||||
- id: "get exchange rate"
|
||||
input:
|
||||
messages:
|
||||
- role: user
|
||||
content: what is exchange rate for gbp
|
||||
expected_tools:
|
||||
- type: function
|
||||
function:
|
||||
name: currency_exchange
|
||||
arguments:
|
||||
currency_symbol: GBP
|
||||
expected_output_contains: gbp
|
||||
|
|
@ -15,3 +15,16 @@ services:
|
|||
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../shared/chatbot_ui
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
|
|
|
|||
|
|
@ -11,14 +11,9 @@ from typing import List, Optional, Tuple
|
|||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from openai import OpenAI
|
||||
from common import create_gradio_app
|
||||
|
||||
app = FastAPI()
|
||||
workforce_data_df = None
|
||||
demo_description = """This demo showcases how the **Arch** can be used to build an
|
||||
HR agent to manage workforce-related inquiries, workforce planning, and communication via Slack.
|
||||
It intelligently routes incoming prompts to the correct targets, providing concise and useful responses
|
||||
tailored for HR and workforce decision-making. """
|
||||
|
||||
with open("workforce_data.json") as file:
|
||||
workforce_data = json.load(file)
|
||||
|
|
@ -95,15 +90,5 @@ def get_workforce(request: WorkforceRequest):
|
|||
return response
|
||||
|
||||
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
client = OpenAI(
|
||||
api_key="--",
|
||||
base_url=CHAT_COMPLETION_ENDPOINT,
|
||||
)
|
||||
|
||||
gr.mount_gradio_app(
|
||||
app, create_gradio_app(demo_description, client), path="/agent/chat"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True)
|
||||
|
|
|
|||
14
demos/hr_agent/test_data.yaml
Normal file
14
demos/hr_agent/test_data.yaml
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
test_cases:
|
||||
- id: get workforce data
|
||||
input:
|
||||
messages:
|
||||
- role: user
|
||||
content: what is workforce data for asia for fte employees
|
||||
expected_tools:
|
||||
- type: function
|
||||
function:
|
||||
name: workforce
|
||||
arguments:
|
||||
staffing_type: fte
|
||||
region: asia
|
||||
expected_output_contains: asia
|
||||
|
|
@ -34,7 +34,7 @@ prompt_targets:
|
|||
default: true
|
||||
|
||||
- name: get_policy_coverage
|
||||
description: Retrieve the coverage details for a given policy type (car, boat, house, motorcycle).
|
||||
description: Retrieve the coverage details for an insurance policy.
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /policy/coverage
|
||||
|
|
@ -42,7 +42,7 @@ prompt_targets:
|
|||
parameters:
|
||||
- name: policy_type
|
||||
type: str
|
||||
description: The type of policy, option - car, boat, house, motorcycle.
|
||||
description: The type of policy
|
||||
default: car
|
||||
required: true
|
||||
|
||||
|
|
@ -51,11 +51,11 @@ prompt_targets:
|
|||
name: app_server
|
||||
path: /policy/initiate
|
||||
http_method: POST
|
||||
description: Start a policy coverage for car, boat, motorcycle or house.
|
||||
description: Start a policy coverage for an insurance policy
|
||||
parameters:
|
||||
- name: policy_type
|
||||
type: str
|
||||
description: The type of policy, option - car, boat, house, motorcycle.
|
||||
description: The type of policy
|
||||
default: car
|
||||
required: true
|
||||
- name: deductible
|
||||
|
|
@ -84,11 +84,11 @@ prompt_targets:
|
|||
name: app_server
|
||||
path: /policy/deductible
|
||||
http_method: POST
|
||||
description: Update the deductible amount for a specific policy coverage.
|
||||
description: Update the deductible amount for a specific insurance policy coverage.
|
||||
parameters:
|
||||
- name: policy_id
|
||||
type: str
|
||||
description: The id of the policy
|
||||
description: The id of the insurance policy
|
||||
required: true
|
||||
- name: deductible
|
||||
type: float
|
||||
|
|
|
|||
19
demos/multi_turn_rag_agent/Dockerfile
Normal file
19
demos/multi_turn_rag_agent/Dockerfile
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
FROM python:3.12 AS base
|
||||
|
||||
FROM base AS builder
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
COPY requirements.txt /src/
|
||||
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
|
||||
|
||||
COPY . /src
|
||||
|
||||
FROM python:3.12-slim AS output
|
||||
|
||||
COPY --from=builder /runtime /usr/local
|
||||
|
||||
COPY . /app
|
||||
WORKDIR /app
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]
|
||||
22
demos/multi_turn_rag_agent/README.md
Normal file
22
demos/multi_turn_rag_agent/README.md
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Multi-Turn Agentic Demo (RAG)
|
||||
|
||||
This demo showcases how the **Arch** can be used to build accurate multi-turn RAG agent by just writing simple APIs.
|
||||
|
||||

|
||||
|
||||
### Energy Source Q/A
|
||||
Provides information about various energy sources and considerations.
|
||||
|
||||
- **Endpoint**: `/agent/energy_source`
|
||||
- **Parameters**:
|
||||
- `energy_source` (`str`, **required**): A source of energy (e.g., `renewable`, `fossil`).
|
||||
- `consideration` (`str`, *optional*): A specific type of consideration for an energy source (e.g., `cost`, `economic`, `technology`).
|
||||
|
||||
# Starting the demo
|
||||
1. Please make sure the [pre-requisites](https://github.com/katanemo/arch/?tab=readme-ov-file#prerequisites) are installed correctly
|
||||
2. Start Arch
|
||||
```sh
|
||||
sh run_demo.sh
|
||||
```
|
||||
3. Navigate to http://localhost:18080
|
||||
4. Ask "give me information about renewable energy sources"
|
||||
59
demos/multi_turn_rag_agent/arch_config.yaml
Normal file
59
demos/multi_turn_rag_agent/arch_config.yaml
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
version: v0.1
|
||||
|
||||
listener:
|
||||
address: 127.0.0.1
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
endpoints:
|
||||
rag_energy_source_agent:
|
||||
endpoint: host.docker.internal:18083
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-4o-mini
|
||||
default: true
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant and can offer information about energy sources.
|
||||
You will get a JSON object with energy_source and consideration fields. Focus on answering the querstion using those fields.
|
||||
Keep your responses to just three main points to make it easy for the reader to digest the information
|
||||
|
||||
prompt_targets:
|
||||
- name: get_info_for_energy_source
|
||||
description: get information about an energy source
|
||||
parameters:
|
||||
- name: energy_source
|
||||
type: str
|
||||
description: a source of energy
|
||||
required: true
|
||||
enum: [renewable, fossil]
|
||||
- name: consideration
|
||||
type: str
|
||||
description: a specific type of consideration for an energy source
|
||||
enum: [cost, economic, technology]
|
||||
endpoint:
|
||||
name: rag_energy_source_agent
|
||||
path: /agent/energy_source_info
|
||||
http_method: POST
|
||||
|
||||
- name: default_target
|
||||
default: true
|
||||
description: This is the default target for all unmatched prompts.
|
||||
endpoint:
|
||||
name: rag_energy_source_agent
|
||||
path: /default_target
|
||||
http_method: POST
|
||||
system_prompt: |
|
||||
You are a helpful assistant! Summarize the user's request and provide a helpful response.
|
||||
# if it is set to false arch will send response that it received from this prompt target to the user
|
||||
# if true arch will forward the response to the default LLM
|
||||
auto_llm_dispatch_on_response: false
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
trace_arch_internal: true
|
||||
24
demos/multi_turn_rag_agent/docker-compose.yaml
Normal file
24
demos/multi_turn_rag_agent/docker-compose.yaml
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
services:
|
||||
rag_energy_source_agent:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18083:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../shared/chatbot_ui
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
42
demos/multi_turn_rag_agent/main.py
Normal file
42
demos/multi_turn_rag_agent/main.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from openai import OpenAI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Define the request model
|
||||
class EnergySourceRequest(BaseModel):
|
||||
energy_source: str
|
||||
consideration: Optional[str] = None
|
||||
|
||||
|
||||
class EnergySourceResponse(BaseModel):
|
||||
energy_source: str
|
||||
consideration: Optional[str] = None
|
||||
|
||||
|
||||
# Post method for device summary
|
||||
@app.post("/agent/energy_source_info")
|
||||
def get_workforce(request: EnergySourceRequest):
|
||||
"""
|
||||
Endpoint to get details about energy source
|
||||
"""
|
||||
considertion = "You don't have any specific consideration. Feel free to talk in a more open ended fashion"
|
||||
|
||||
if request.consideration is not None:
|
||||
considertion = f"Add specific focus on the following consideration when you summarize the content for the energy source: {request.consideration}"
|
||||
|
||||
response = {
|
||||
"energy_source": request.energy_source,
|
||||
"consideration": considertion,
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True)
|
||||
BIN
demos/multi_turn_rag_agent/mutli-turn-example.png
Normal file
BIN
demos/multi_turn_rag_agent/mutli-turn-example.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 852 KiB |
12
demos/multi_turn_rag_agent/requirements.txt
Normal file
12
demos/multi_turn_rag_agent/requirements.txt
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
fastapi
|
||||
uvicorn
|
||||
typing
|
||||
pandas
|
||||
gradio==5.3.0
|
||||
async_timeout==4.0.3
|
||||
loguru==0.7.2
|
||||
asyncio==3.4.3
|
||||
httpx==0.27.0
|
||||
python-dotenv==1.0.1
|
||||
pydantic==2.8.2
|
||||
openai==1.51.0
|
||||
47
demos/multi_turn_rag_agent/run_demo.sh
Normal file
47
demos/multi_turn_rag_agent/run_demo.sh
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Function to start the demo
|
||||
start_demo() {
|
||||
# Step 1: Check if .env file exists
|
||||
if [ -f ".env" ]; then
|
||||
echo ".env file already exists. Skipping creation."
|
||||
else
|
||||
# Step 2: Create `.env` file and set OpenAI key
|
||||
if [ -z "$OPENAI_API_KEY" ]; then
|
||||
echo "Error: OPENAI_API_KEY environment variable is not set for the demo."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Creating .env file..."
|
||||
echo "OPENAI_API_KEY=$OPENAI_API_KEY" > .env
|
||||
echo ".env file created with OPENAI_API_KEY."
|
||||
fi
|
||||
|
||||
# Step 3: Start Arch
|
||||
echo "Starting Arch with arch_config.yaml..."
|
||||
archgw up arch_config.yaml
|
||||
|
||||
# Step 4: Start Network Agent
|
||||
echo "Starting HR Agent using Docker Compose..."
|
||||
docker compose up -d # Run in detached mode
|
||||
}
|
||||
|
||||
# Function to stop the demo
|
||||
stop_demo() {
|
||||
# Step 1: Stop Docker Compose services
|
||||
echo "Stopping HR Agent using Docker Compose..."
|
||||
docker compose down -v
|
||||
|
||||
# Step 2: Stop Arch
|
||||
echo "Stopping Arch..."
|
||||
archgw down
|
||||
}
|
||||
|
||||
# Main script logic
|
||||
if [ "$1" == "down" ]; then
|
||||
stop_demo
|
||||
else
|
||||
# Default action is to bring the demo up
|
||||
start_demo
|
||||
fi
|
||||
|
|
@ -17,36 +17,32 @@ system_prompt: |
|
|||
You are a network assistant that helps operators with a better understanding of network traffic flow and perform actions on networking operations. No advice on manufacturers or purchasing decisions.
|
||||
|
||||
prompt_targets:
|
||||
- name: device_summary
|
||||
description: Retrieve network statistics for specific devices within a time range
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_summary
|
||||
http_method: POST
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs) to retrieve statistics for.
|
||||
required: true # device_ids are required to get device statistics
|
||||
- name: days
|
||||
type: int
|
||||
description: The number of days for which to gather device statistics.
|
||||
default: "7"
|
||||
- name: reboot_devices
|
||||
description: Reboot a list of devices
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_reboot
|
||||
http_method: POST
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs).
|
||||
required: true
|
||||
- name: days
|
||||
type: int
|
||||
description: A list of device identifiers (IDs)
|
||||
default: "7"
|
||||
- name: device_summary
|
||||
description: Retrieve network statistics for specific devices within a time range
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_summary
|
||||
http_method: POST
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs) to retrieve statistics for.
|
||||
required: true # device_ids are required to get device statistics
|
||||
- name: days
|
||||
type: int
|
||||
description: The number of days for which to gather device statistics.
|
||||
default: "7"
|
||||
- name: reboot_devices
|
||||
description: Reboot a list of devices
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_reboot
|
||||
http_method: POST
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs).
|
||||
required: true
|
||||
|
||||
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
|
||||
endpoints:
|
||||
|
|
@ -54,6 +50,6 @@ endpoints:
|
|||
# value could be ip address or a hostname with port
|
||||
# this could also be a list of endpoints for load balancing
|
||||
# for example endpoint: [ ip1:port, ip2:port ]
|
||||
endpoint: host.docker.internal:18080
|
||||
endpoint: host.docker.internal:18083
|
||||
# max time to wait for a connection to be established
|
||||
connect_timeout: 0.005s
|
||||
|
|
|
|||
|
|
@ -2,14 +2,19 @@ services:
|
|||
api_server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18083:80"
|
||||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../shared/chatbot_ui
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
- ../shared/chatbot_ui/common.py:/app/common.py
|
||||
ports:
|
||||
- "18080:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
from openai import OpenAI
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
from common import create_gradio_app
|
||||
import gradio as gr
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
app = FastAPI()
|
||||
demo_description = """This demo illustrates how **Arch** can be used to perform function calling with network-related tasks.
|
||||
In this demo, you act as a **network assistant** that provides factual information, without offering advice on manufacturers or purchasing decisions."""
|
||||
DEMO_DESCRIPTION = """This demo illustrates how **Arch** can be used to perform function calling
|
||||
with network-related tasks. In this demo, you act as a **network assistant** that provides factual
|
||||
information, without offering advice on manufacturers or purchasing decisions."""
|
||||
|
||||
|
||||
# Define the request model
|
||||
|
|
@ -52,7 +51,8 @@ def reboot_network_device(request_data: DeviceRebootRequest):
|
|||
# Access data from the Pydantic model
|
||||
device_ids = request_data.device_ids
|
||||
|
||||
# Validate 'device_ids' (This is already validated by Pydantic, but additional logic can be added if needed)
|
||||
# Validate 'device_ids'
|
||||
# (This is already validated by Pydantic, but additional logic can be added if needed)
|
||||
if not device_ids:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="'device_ids' parameter is required"
|
||||
|
|
@ -87,7 +87,8 @@ def get_device_summary(request: DeviceSummaryRequest):
|
|||
stats = {
|
||||
"device_id": device_id,
|
||||
"time_range": f"Last {time_range} days",
|
||||
"data": f"Device {device_id} over the last {time_range} days experienced {minutes} minutes of downtime.",
|
||||
"data": f"""Device {device_id} over the last {time_range} days experienced {minutes}
|
||||
minutes of downtime.""",
|
||||
}
|
||||
minutes += 1
|
||||
statistics.append(DeviceStatistics(**stats))
|
||||
|
|
@ -100,10 +101,3 @@ client = OpenAI(
|
|||
api_key="--",
|
||||
base_url=CHAT_COMPLETION_ENDPOINT,
|
||||
)
|
||||
|
||||
gr.mount_gradio_app(
|
||||
app, create_gradio_app(demo_description, client), path="/agent/chat"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import pandas as pd
|
||||
import random
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import re
|
||||
import logging
|
||||
from dateparser import parse
|
||||
import random
|
||||
import re
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pandas as pd
|
||||
from dateparser import parse
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
|
|
@ -12,7 +13,7 @@ logging.basicConfig(
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_sql():
|
||||
def loadsql():
|
||||
# Example Usage
|
||||
conn = sqlite3.connect(":memory:")
|
||||
|
||||
|
|
@ -72,7 +73,10 @@ def random_mac():
|
|||
|
||||
# Function to generate random IP addresses
|
||||
def random_ip():
|
||||
return f"{random.randint(1, 255)}.{random.randint(1, 255)}.{random.randint(1, 255)}.{random.randint(1, 255)}"
|
||||
return f"""{random.randint(1, 255)}
|
||||
.{random.randint(1, 255)}
|
||||
.{random.randint(1, 255)}
|
||||
.{random.randint(1, 255)}"""
|
||||
|
||||
|
||||
# Generate synthetic data for the device table
|
||||
|
|
@ -88,7 +92,8 @@ def generate_device_data(
|
|||
"layer": ["L2" if i % 2 == 0 else "L3" for i in range(n)],
|
||||
"region": [random.choice(["US", "EU", "ASIA"]) for _ in range(n)],
|
||||
"uptime": [
|
||||
f"{random.randint(0, 10)} days {random.randint(0, 23)}:{random.randint(0, 59)}:{random.randint(0, 59)}"
|
||||
f"""{random.randint(0, 10)} days {random.randint(0, 23)}
|
||||
:{random.randint(0, 59)}:{random.randint(0, 59)}"""
|
||||
for _ in range(n)
|
||||
],
|
||||
"device_mac_address": [random_mac() for _ in range(n)],
|
||||
|
|
@ -129,7 +134,6 @@ def generate_interface_stats_data(conn, device_df, n=1000):
|
|||
)
|
||||
df = pd.DataFrame(interface_stats_data)
|
||||
df.to_sql("interfacestats", conn, index=False)
|
||||
return
|
||||
|
||||
|
||||
# Generate synthetic data for the ts_flow table
|
||||
|
|
@ -175,14 +179,13 @@ def generate_flow_data(conn, device_df, n=1000):
|
|||
)
|
||||
df = pd.DataFrame(flow_data)
|
||||
df.to_sql("ts_flow", conn, index=False)
|
||||
return
|
||||
|
||||
|
||||
def load_params(req):
|
||||
# Step 1: Convert the from_time natural language string to a timestamp if provided
|
||||
if req.from_time:
|
||||
# Use `dateparser` to parse natural language timeframes
|
||||
logger.info(f"{'* ' * 50}\n\nCaptured from time: {req.from_time}\n\n")
|
||||
logger.info("%s\n\nCaptured from time: %s\n\n", "* " * 50, req.from_time)
|
||||
parsed_time = parse(req.from_time, settings={"RELATIVE_BASE": datetime.now()})
|
||||
if not parsed_time:
|
||||
conv_time = convert_to_ago_format(req.from_time)
|
||||
|
|
@ -192,15 +195,16 @@ def load_params(req):
|
|||
)
|
||||
else:
|
||||
return {
|
||||
"error": "Invalid from_time format. Please provide a valid time description such as 'past 7 days' or 'since last month'."
|
||||
"error": """Invalid from_time format. Please provide a valid time description
|
||||
such as 'past 7 days' or 'since last month'."""
|
||||
}
|
||||
logger.info(f"\n\nConverted from time: {parsed_time}\n\n{'* ' * 50}\n\n")
|
||||
logger.info("\n\nConverted from time: %s\n\n%s\n\n", parsed_time, "* " * 50)
|
||||
from_time = parsed_time
|
||||
logger.info(f"Using parsed from_time: {from_time}")
|
||||
logger.info("Using parsed from_time: %f", from_time)
|
||||
else:
|
||||
# If no from_time is provided, use a default value (e.g., the past 7 days)
|
||||
from_time = datetime.now() - timedelta(days=7)
|
||||
logger.info(f"Using default from_time: {from_time}")
|
||||
logger.info("Using default from_time: %f", from_time)
|
||||
|
||||
# Step 2: Build the dynamic SQL query based on the optional filters
|
||||
filters = []
|
||||
|
|
|
|||
|
|
@ -1,25 +1,8 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
|
||||
PROMPT_GATEWAY_ENDPOINT = os.getenv(
|
||||
"PROMPT_GATEWAY_ENDPOINT", "http://localhost:10000/v1/chat/completions"
|
||||
)
|
||||
LLM_GATEWAY_ENDPOINT = os.getenv(
|
||||
"LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1/chat/completions"
|
||||
)
|
||||
ARCH_STATE_HEADER = "x-arch-state"
|
||||
|
||||
PREFILL_LIST = [
|
||||
"May",
|
||||
"Could",
|
||||
"Sure",
|
||||
"Definitely",
|
||||
"Certainly",
|
||||
"Of course",
|
||||
"Can",
|
||||
]
|
||||
|
||||
|
||||
def get_data_chunks(stream, n=1):
|
||||
chunks = []
|
||||
484
demos/test_runner/poetry.lock
generated
Normal file
484
demos/test_runner/poetry.lock
generated
Normal file
|
|
@ -0,0 +1,484 @@
|
|||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2024.12.14"
|
||||
description = "Python package for providing Mozilla's CA Bundle."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56"},
|
||||
{file = "certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.0"
|
||||
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4f9fc98dad6c2eaa32fc3af1417d95b5e3d08aff968df0cd320066def971f9a6"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0de7b687289d3c1b3e8660d0741874abe7888100efe14bd0f9fd7141bcbda92b"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5ed2e36c3e9b4f21dd9422f6893dec0abf2cca553af509b10cd630f878d3eb99"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d3ff7fc90b98c637bda91c89d51264a3dcf210cade3a2c6f838c7268d7a4ca"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1110e22af8ca26b90bd6364fe4c763329b0ebf1ee213ba32b68c73de5752323d"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:86f4e8cca779080f66ff4f191a685ced73d2f72d50216f7112185dc02b90b9b7"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f683ddc7eedd742e2889d2bfb96d69573fde1d92fcb811979cdb7165bb9c7d3"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27623ba66c183eca01bf9ff833875b459cad267aeeb044477fedac35e19ba907"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f606a1881d2663630ea5b8ce2efe2111740df4b687bd78b34a8131baa007f79b"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0b309d1747110feb25d7ed6b01afdec269c647d382c857ef4663bbe6ad95a912"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:136815f06a3ae311fae551c3df1f998a1ebd01ddd424aa5603a4336997629e95"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:14215b71a762336254351b00ec720a8e85cada43b987da5a042e4ce3e82bd68e"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:79983512b108e4a164b9c8d34de3992f76d48cadc9554c9e60b43f308988aabe"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-win32.whl", hash = "sha256:c94057af19bc953643a33581844649a7fdab902624d2eb739738a30e2b3e60fc"},
|
||||
{file = "charset_normalizer-3.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:55f56e2ebd4e3bc50442fbc0888c9d8c94e4e06a933804e2af3e89e2f9c1c749"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0d99dd8ff461990f12d6e42c7347fd9ab2532fb70e9621ba520f9e8637161d7c"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c57516e58fd17d03ebe67e181a4e4e2ccab1168f8c2976c6a334d4f819fe5944"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dba5d19c4dfab08e58d5b36304b3f92f3bd5d42c1a3fa37b5ba5cdf6dfcbcee"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf4475b82be41b07cc5e5ff94810e6a01f276e37c2d55571e3fe175e467a1a1c"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce031db0408e487fd2775d745ce30a7cd2923667cf3b69d48d219f1d8f5ddeb6"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ff4e7cdfdb1ab5698e675ca622e72d58a6fa2a8aa58195de0c0061288e6e3ea"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3710a9751938947e6327ea9f3ea6332a09bf0ba0c09cae9cb1f250bd1f1549bc"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82357d85de703176b5587dbe6ade8ff67f9f69a41c0733cf2425378b49954de5"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:47334db71978b23ebcf3c0f9f5ee98b8d65992b65c9c4f2d34c2eaf5bcaf0594"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8ce7fd6767a1cc5a92a639b391891bf1c268b03ec7e021c7d6d902285259685c"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f1a2f519ae173b5b6a2c9d5fa3116ce16e48b3462c8b96dfdded11055e3d6365"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:63bc5c4ae26e4bc6be6469943b8253c0fd4e4186c43ad46e713ea61a0ba49129"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bcb4f8ea87d03bc51ad04add8ceaf9b0f085ac045ab4d74e73bbc2dc033f0236"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-win32.whl", hash = "sha256:9ae4ef0b3f6b41bad6366fb0ea4fc1d7ed051528e113a60fa2a65a9abb5b1d99"},
|
||||
{file = "charset_normalizer-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cee4373f4d3ad28f1ab6290684d8e2ebdb9e7a1b74fdc39e4c211995f77bec27"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0713f3adb9d03d49d365b70b84775d0a0d18e4ab08d12bc46baa6132ba78aaf6"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:de7376c29d95d6719048c194a9cf1a1b0393fbe8488a22008610b0361d834ecf"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a51b48f42d9358460b78725283f04bddaf44a9358197b889657deba38f329db"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b295729485b06c1a0683af02a9e42d2caa9db04a373dc38a6a58cdd1e8abddf1"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee803480535c44e7f5ad00788526da7d85525cfefaf8acf8ab9a310000be4b03"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d59d125ffbd6d552765510e3f31ed75ebac2c7470c7274195b9161a32350284"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cda06946eac330cbe6598f77bb54e690b4ca93f593dee1568ad22b04f347c15"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07afec21bbbbf8a5cc3651aa96b980afe2526e7f048fdfb7f1014d84acc8b6d8"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6b40e8d38afe634559e398cc32b1472f376a4099c75fe6299ae607e404c033b2"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b8dcd239c743aa2f9c22ce674a145e0a25cb1566c495928440a181ca1ccf6719"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:84450ba661fb96e9fd67629b93d2941c871ca86fc38d835d19d4225ff946a631"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:44aeb140295a2f0659e113b31cfe92c9061622cadbc9e2a2f7b8ef6b1e29ef4b"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1db4e7fefefd0f548d73e2e2e041f9df5c59e178b4c72fbac4cc6f535cfb1565"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-win32.whl", hash = "sha256:5726cf76c982532c1863fb64d8c6dd0e4c90b6ece9feb06c9f202417a31f7dd7"},
|
||||
{file = "charset_normalizer-3.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b197e7094f232959f8f20541ead1d9862ac5ebea1d58e9849c1bf979255dfac9"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dd4eda173a9fcccb5f2e2bd2a9f423d180194b1bf17cf59e3269899235b2a114"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9e3c4c9e1ed40ea53acf11e2a386383c3304212c965773704e4603d589343ed"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92a7e36b000bf022ef3dbb9c46bfe2d52c047d5e3f3343f43204263c5addc250"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b6a92d009cbe2fb11054ba694bc9e284dad30a26757b1e372a1fdddaf21920"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ffd9493de4c922f2a38c2bf62b831dcec90ac673ed1ca182fe11b4d8e9f2a64"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35c404d74c2926d0287fbd63ed5d27eb911eb9e4a3bb2c6d294f3cfd4a9e0c23"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4796efc4faf6b53a18e3d46343535caed491776a22af773f366534056c4e1fbc"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7fdd52961feb4c96507aa649550ec2a0d527c086d284749b2f582f2d40a2e0d"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:92db3c28b5b2a273346bebb24857fda45601aef6ae1c011c0a997106581e8a88"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ab973df98fc99ab39080bfb0eb3a925181454d7c3ac8a1e695fddfae696d9e90"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4b67fdab07fdd3c10bb21edab3cbfe8cf5696f453afce75d815d9d7223fbe88b"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aa41e526a5d4a9dfcfbab0716c7e8a1b215abd3f3df5a45cf18a12721d31cb5d"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-win32.whl", hash = "sha256:f19c1585933c82098c2a520f8ec1227f20e339e33aca8fa6f956f6691b784e67"},
|
||||
{file = "charset_normalizer-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:707b82d19e65c9bd28b81dde95249b07bf9f5b90ebe1ef17d9b57473f8a64b7b"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:dbe03226baf438ac4fda9e2d0715022fd579cb641c4cf639fa40d53b2fe6f3e2"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd9a8bd8900e65504a305bf8ae6fa9fbc66de94178c420791d0293702fce2df7"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8831399554b92b72af5932cdbbd4ddc55c55f631bb13ff8fe4e6536a06c5c51"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a14969b8691f7998e74663b77b4c36c0337cb1df552da83d5c9004a93afdb574"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcaf7c1524c0542ee2fc82cc8ec337f7a9f7edee2532421ab200d2b920fc97cf"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425c5f215d0eecee9a56cdb703203dda90423247421bf0d67125add85d0c4455"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:d5b054862739d276e09928de37c79ddeec42a6e1bfc55863be96a36ba22926f6"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:f3e73a4255342d4eb26ef6df01e3962e73aa29baa3124a8e824c5d3364a65748"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:2f6c34da58ea9c1a9515621f4d9ac379871a8f21168ba1b5e09d74250de5ad62"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:f09cb5a7bbe1ecae6e87901a2eb23e0256bb524a79ccc53eb0b7629fbe7677c4"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:0099d79bdfcf5c1f0c2c72f91516702ebf8b0b8ddd8905f97a8aecf49712c621"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-win32.whl", hash = "sha256:9c98230f5042f4945f957d006edccc2af1e03ed5e37ce7c373f00a5a4daa6149"},
|
||||
{file = "charset_normalizer-3.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62f60aebecfc7f4b82e3f639a7d1433a20ec32824db2199a11ad4f5e146ef5ee"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:af73657b7a68211996527dbfeffbb0864e043d270580c5aef06dc4b659a4b578"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cab5d0b79d987c67f3b9e9c53f54a61360422a5a0bc075f43cab5621d530c3b6"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9289fd5dddcf57bab41d044f1756550f9e7cf0c8e373b8cdf0ce8773dc4bd417"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b493a043635eb376e50eedf7818f2f322eabbaa974e948bd8bdd29eb7ef2a51"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fa2566ca27d67c86569e8c85297aaf413ffab85a8960500f12ea34ff98e4c41"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8e538f46104c815be19c975572d74afb53f29650ea2025bbfaef359d2de2f7f"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fd30dc99682dc2c603c2b315bded2799019cea829f8bf57dc6b61efde6611c8"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2006769bd1640bdf4d5641c69a3d63b71b81445473cac5ded39740a226fa88ab"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:dc15e99b2d8a656f8e666854404f1ba54765871104e50c8e9813af8a7db07f12"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ab2e5bef076f5a235c3774b4f4028a680432cded7cad37bba0fd90d64b187d19"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:4ec9dd88a5b71abfc74e9df5ebe7921c35cbb3b641181a531ca65cdb5e8e4dea"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:43193c5cda5d612f247172016c4bb71251c784d7a4d9314677186a838ad34858"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:aa693779a8b50cd97570e5a0f343538a8dbd3e496fa5dcb87e29406ad0299654"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-win32.whl", hash = "sha256:7706f5850360ac01d80c89bcef1640683cc12ed87f42579dab6c5d3ed6888613"},
|
||||
{file = "charset_normalizer-3.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:c3e446d253bd88f6377260d07c895816ebf33ffffd56c1c792b13bff9c3e1ade"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:980b4f289d1d90ca5efcf07958d3eb38ed9c0b7676bf2831a54d4f66f9c27dfa"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f28f891ccd15c514a0981f3b9db9aa23d62fe1a99997512b0491d2ed323d229a"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8aacce6e2e1edcb6ac625fb0f8c3a9570ccc7bfba1f63419b3769ccf6a00ed0"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7af3717683bea4c87acd8c0d3d5b44d56120b26fd3f8a692bdd2d5260c620a"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ff2ed8194587faf56555927b3aa10e6fb69d931e33953943bc4f837dfee2242"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e91f541a85298cf35433bf66f3fab2a4a2cff05c127eeca4af174f6d497f0d4b"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309a7de0a0ff3040acaebb35ec45d18db4b28232f21998851cfa709eeff49d62"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:285e96d9d53422efc0d7a17c60e59f37fbf3dfa942073f666db4ac71e8d726d0"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5d447056e2ca60382d460a604b6302d8db69476fd2015c81e7c35417cfabe4cd"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:20587d20f557fe189b7947d8e7ec5afa110ccf72a3128d61a2a387c3313f46be"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:130272c698667a982a5d0e626851ceff662565379baf0ff2cc58067b81d4f11d"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ab22fbd9765e6954bc0bcff24c25ff71dcbfdb185fcdaca49e81bac68fe724d3"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7782afc9b6b42200f7362858f9e73b1f8316afb276d316336c0ec3bd73312742"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-win32.whl", hash = "sha256:2de62e8801ddfff069cd5c504ce3bc9672b23266597d4e4f50eda28846c322f2"},
|
||||
{file = "charset_normalizer-3.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:95c3c157765b031331dd4db3c775e58deaee050a3042fcad72cbc4189d7c8dca"},
|
||||
{file = "charset_normalizer-3.4.0-py3-none-any.whl", hash = "sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079"},
|
||||
{file = "charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
description = "Cross-platform colored terminal text."
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.6.9"
|
||||
description = "Code coverage measurement for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "coverage-7.6.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:85d9636f72e8991a1706b2b55b06c27545448baf9f6dbf51c4004609aacd7dcb"},
|
||||
{file = "coverage-7.6.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:608a7fd78c67bee8936378299a6cb9f5149bb80238c7a566fc3e6717a4e68710"},
|
||||
{file = "coverage-7.6.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96d636c77af18b5cb664ddf12dab9b15a0cfe9c0bde715da38698c8cea748bfa"},
|
||||
{file = "coverage-7.6.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d75cded8a3cff93da9edc31446872d2997e327921d8eed86641efafd350e1df1"},
|
||||
{file = "coverage-7.6.9-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7b15f589593110ae767ce997775d645b47e5cbbf54fd322f8ebea6277466cec"},
|
||||
{file = "coverage-7.6.9-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:44349150f6811b44b25574839b39ae35291f6496eb795b7366fef3bd3cf112d3"},
|
||||
{file = "coverage-7.6.9-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d891c136b5b310d0e702e186d70cd16d1119ea8927347045124cb286b29297e5"},
|
||||
{file = "coverage-7.6.9-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:db1dab894cc139f67822a92910466531de5ea6034ddfd2b11c0d4c6257168073"},
|
||||
{file = "coverage-7.6.9-cp310-cp310-win32.whl", hash = "sha256:41ff7b0da5af71a51b53f501a3bac65fb0ec311ebed1632e58fc6107f03b9198"},
|
||||
{file = "coverage-7.6.9-cp310-cp310-win_amd64.whl", hash = "sha256:35371f8438028fdccfaf3570b31d98e8d9eda8bb1d6ab9473f5a390969e98717"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:932fc826442132dde42ee52cf66d941f581c685a6313feebed358411238f60f9"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:085161be5f3b30fd9b3e7b9a8c301f935c8313dcf928a07b116324abea2c1c2c"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ccc660a77e1c2bf24ddbce969af9447a9474790160cfb23de6be4fa88e3951c7"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c69e42c892c018cd3c8d90da61d845f50a8243062b19d228189b0224150018a9"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0824a28ec542a0be22f60c6ac36d679e0e262e5353203bea81d44ee81fe9c6d4"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4401ae5fc52ad8d26d2a5d8a7428b0f0c72431683f8e63e42e70606374c311a1"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:98caba4476a6c8d59ec1eb00c7dd862ba9beca34085642d46ed503cc2d440d4b"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ee5defd1733fd6ec08b168bd4f5387d5b322f45ca9e0e6c817ea6c4cd36313e3"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-win32.whl", hash = "sha256:f2d1ec60d6d256bdf298cb86b78dd715980828f50c46701abc3b0a2b3f8a0dc0"},
|
||||
{file = "coverage-7.6.9-cp311-cp311-win_amd64.whl", hash = "sha256:0d59fd927b1f04de57a2ba0137166d31c1a6dd9e764ad4af552912d70428c92b"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:99e266ae0b5d15f1ca8d278a668df6f51cc4b854513daab5cae695ed7b721cf8"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9901d36492009a0a9b94b20e52ebfc8453bf49bb2b27bca2c9706f8b4f5a554a"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abd3e72dd5b97e3af4246cdada7738ef0e608168de952b837b8dd7e90341f015"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff74026a461eb0660366fb01c650c1d00f833a086b336bdad7ab00cc952072b3"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65dad5a248823a4996724a88eb51d4b31587aa7aa428562dbe459c684e5787ae"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:22be16571504c9ccea919fcedb459d5ab20d41172056206eb2994e2ff06118a4"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f957943bc718b87144ecaee70762bc2bc3f1a7a53c7b861103546d3a403f0a6"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0ae1387db4aecb1f485fb70a6c0148c6cdaebb6038f1d40089b1fc84a5db556f"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-win32.whl", hash = "sha256:1a330812d9cc7ac2182586f6d41b4d0fadf9be9049f350e0efb275c8ee8eb692"},
|
||||
{file = "coverage-7.6.9-cp312-cp312-win_amd64.whl", hash = "sha256:b12c6b18269ca471eedd41c1b6a1065b2f7827508edb9a7ed5555e9a56dcfc97"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:899b8cd4781c400454f2f64f7776a5d87bbd7b3e7f7bda0cb18f857bb1334664"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:61f70dc68bd36810972e55bbbe83674ea073dd1dcc121040a08cdf3416c5349c"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a289d23d4c46f1a82d5db4abeb40b9b5be91731ee19a379d15790e53031c014"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e216d8044a356fc0337c7a2a0536d6de07888d7bcda76febcb8adc50bdbbd00"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c026eb44f744acaa2bda7493dad903aa5bf5fc4f2554293a798d5606710055d"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e77363e8425325384f9d49272c54045bbed2f478e9dd698dbc65dbc37860eb0a"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:777abfab476cf83b5177b84d7486497e034eb9eaea0d746ce0c1268c71652077"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:447af20e25fdbe16f26e84eb714ba21d98868705cb138252d28bc400381f6ffb"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-win32.whl", hash = "sha256:d872ec5aeb086cbea771c573600d47944eea2dcba8be5f3ee649bfe3cb8dc9ba"},
|
||||
{file = "coverage-7.6.9-cp313-cp313-win_amd64.whl", hash = "sha256:fd1213c86e48dfdc5a0cc676551db467495a95a662d2396ecd58e719191446e1"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:ba9e7484d286cd5a43744e5f47b0b3fb457865baf07bafc6bee91896364e1419"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e5ea1cf0872ee455c03e5674b5bca5e3e68e159379c1af0903e89f5eba9ccc3a"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d10e07aa2b91835d6abec555ec8b2733347956991901eea6ffac295f83a30e4"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:13a9e2d3ee855db3dd6ea1ba5203316a1b1fd8eaeffc37c5b54987e61e4194ae"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c38bf15a40ccf5619fa2fe8f26106c7e8e080d7760aeccb3722664c8656b030"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:d5275455b3e4627c8e7154feaf7ee0743c2e7af82f6e3b561967b1cca755a0be"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:8f8770dfc6e2c6a2d4569f411015c8d751c980d17a14b0530da2d7f27ffdd88e"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8d2dfa71665a29b153a9681edb1c8d9c1ea50dfc2375fb4dac99ea7e21a0bcd9"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-win32.whl", hash = "sha256:5e6b86b5847a016d0fbd31ffe1001b63355ed309651851295315031ea7eb5a9b"},
|
||||
{file = "coverage-7.6.9-cp313-cp313t-win_amd64.whl", hash = "sha256:97ddc94d46088304772d21b060041c97fc16bdda13c6c7f9d8fcd8d5ae0d8611"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:adb697c0bd35100dc690de83154627fbab1f4f3c0386df266dded865fc50a902"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:be57b6d56e49c2739cdf776839a92330e933dd5e5d929966fbbd380c77f060be"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1592791f8204ae9166de22ba7e6705fa4ebd02936c09436a1bb85aabca3e599"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4e12ae8cc979cf83d258acb5e1f1cf2f3f83524d1564a49d20b8bec14b637f08"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb5555cff66c4d3d6213a296b360f9e1a8e323e74e0426b6c10ed7f4d021e464"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b9389a429e0e5142e69d5bf4a435dd688c14478a19bb901735cdf75e57b13845"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:592ac539812e9b46046620341498caf09ca21023c41c893e1eb9dbda00a70cbf"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a27801adef24cc30871da98a105f77995e13a25a505a0161911f6aafbd66e678"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-win32.whl", hash = "sha256:8e3c3e38930cfb729cb8137d7f055e5a473ddaf1217966aa6238c88bd9fd50e6"},
|
||||
{file = "coverage-7.6.9-cp39-cp39-win_amd64.whl", hash = "sha256:e28bf44afa2b187cc9f41749138a64435bf340adfcacb5b2290c070ce99839d4"},
|
||||
{file = "coverage-7.6.9-pp39.pp310-none-any.whl", hash = "sha256:f3ca78518bc6bc92828cd11867b121891d75cae4ea9e908d72030609b996db1b"},
|
||||
{file = "coverage-7.6.9.tar.gz", hash = "sha256:4a8d8977b0c6ef5aeadcb644da9e69ae0dcfe66ec7f368c89c72e058bd71164d"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
toml = ["tomli"]
|
||||
|
||||
[[package]]
|
||||
name = "deepdiff"
|
||||
version = "8.1.1"
|
||||
description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "deepdiff-8.1.1-py3-none-any.whl", hash = "sha256:b0231fa3afb0f7184e82535f2b4a36636442ed21e94a0cf3aaa7982157e7ebca"},
|
||||
{file = "deepdiff-8.1.1.tar.gz", hash = "sha256:dd7bc7d5c8b51b5b90f01b0e2fe23c801fd8b4c6a7ee7e31c5a3c3663fcc7ceb"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
orderly-set = ">=5.2.3,<6"
|
||||
|
||||
[package.extras]
|
||||
cli = ["click (==8.1.7)", "pyyaml (==6.0.2)"]
|
||||
optimize = ["orjson"]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
description = "Internationalized Domain Names in Applications (IDNA)"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"},
|
||||
{file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.0.0"
|
||||
description = "brain-dead simple config-ini parsing"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
|
||||
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "orderly-set"
|
||||
version = "5.2.3"
|
||||
description = "Orderly set"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "orderly_set-5.2.3-py3-none-any.whl", hash = "sha256:d357cedcf67f4ebff0d4cbd5b0997e98eeb65dd24fdf5c990a501ae9e82c7d34"},
|
||||
{file = "orderly_set-5.2.3.tar.gz", hash = "sha256:571ed97c5a5fca7ddeb6b2d26c19aca896b0ed91f334d9c109edd2f265fb3017"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "24.2"
|
||||
description = "Core utilities for Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
|
||||
{file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.5.0"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
|
||||
{file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "tox"]
|
||||
testing = ["pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.3.4"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"},
|
||||
{file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
||||
iniconfig = "*"
|
||||
packaging = "*"
|
||||
pluggy = ">=1.5,<2"
|
||||
|
||||
[package.extras]
|
||||
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "4.1.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"},
|
||||
{file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
coverage = {version = ">=5.2.1", extras = ["toml"]}
|
||||
pytest = ">=4.6"
|
||||
|
||||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-retry"
|
||||
version = "1.6.3"
|
||||
description = "Adds the ability to retry flaky tests in CI environments"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pytest_retry-1.6.3-py3-none-any.whl", hash = "sha256:e96f7df77ee70b0838d1085f9c3b8b5b7d74bf8947a0baf32e2b8c71b27683c8"},
|
||||
{file = "pytest_retry-1.6.3.tar.gz", hash = "sha256:36ccfa11c8c8f9ddad5e20375182146d040c20c4a791745139c5a99ddf1b557d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "flake8", "isort", "mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-sugar"
|
||||
version = "1.0.0"
|
||||
description = "pytest-sugar is a plugin for pytest that changes the default look and feel of pytest (e.g. progressbar, show tests that fail instantly)."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pytest-sugar-1.0.0.tar.gz", hash = "sha256:6422e83258f5b0c04ce7c632176c7732cab5fdb909cb39cca5c9139f81276c0a"},
|
||||
{file = "pytest_sugar-1.0.0-py3-none-any.whl", hash = "sha256:70ebcd8fc5795dc457ff8b69d266a4e2e8a74ae0c3edc749381c64b5246c8dfd"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = ">=21.3"
|
||||
pytest = ">=6.2.0"
|
||||
termcolor = ">=2.1.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "flake8", "pre-commit"]
|
||||
|
||||
[[package]]
|
||||
name = "pyyaml"
|
||||
version = "6.0.2"
|
||||
description = "YAML parser and emitter for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"},
|
||||
{file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"},
|
||||
{file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"},
|
||||
{file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"},
|
||||
{file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"},
|
||||
{file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"},
|
||||
{file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"},
|
||||
{file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"},
|
||||
{file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"},
|
||||
{file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"},
|
||||
{file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"},
|
||||
{file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"},
|
||||
{file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"},
|
||||
{file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"},
|
||||
{file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"},
|
||||
{file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"},
|
||||
{file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"},
|
||||
{file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"},
|
||||
{file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"},
|
||||
{file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"},
|
||||
{file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"},
|
||||
{file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"},
|
||||
{file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"},
|
||||
{file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"},
|
||||
{file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"},
|
||||
{file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"},
|
||||
{file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"},
|
||||
{file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"},
|
||||
{file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"},
|
||||
{file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"},
|
||||
{file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"},
|
||||
{file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"},
|
||||
{file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"},
|
||||
{file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"},
|
||||
{file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"},
|
||||
{file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"},
|
||||
{file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"},
|
||||
{file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"},
|
||||
{file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"},
|
||||
{file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"},
|
||||
{file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"},
|
||||
{file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"},
|
||||
{file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"},
|
||||
{file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"},
|
||||
{file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"},
|
||||
{file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"},
|
||||
{file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"},
|
||||
{file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"},
|
||||
{file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"},
|
||||
{file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"},
|
||||
{file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"},
|
||||
{file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"},
|
||||
{file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.3"
|
||||
description = "Python HTTP for Humans."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
|
||||
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = ">=2017.4.17"
|
||||
charset-normalizer = ">=2,<4"
|
||||
idna = ">=2.5,<4"
|
||||
urllib3 = ">=1.21.1,<3"
|
||||
|
||||
[package.extras]
|
||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "2.5.0"
|
||||
description = "ANSI color formatting for output in terminal"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"},
|
||||
{file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.2.3"
|
||||
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"},
|
||||
{file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
|
||||
h2 = ["h2 (>=4,<5)"]
|
||||
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "be71fbf338d83dd6170e7d9df2df8d1035a07397fe5ffc522c96a92cfe3318bd"
|
||||
27
demos/test_runner/pyproject.toml
Normal file
27
demos/test_runner/pyproject.toml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
[tool.poetry]
|
||||
name = "demo tests"
|
||||
version = "0.0.1"
|
||||
description = "demo tests runner"
|
||||
authors = ["Katanemo Labs, Inc <info@katanemo.com>"]
|
||||
license = "Apache 2.0"
|
||||
readme = "README.md"
|
||||
package-mode = false
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.12"
|
||||
pytest = "^8.3.3"
|
||||
requests = "^2.29.0"
|
||||
pytest-sugar = "^1.0.0"
|
||||
deepdiff = "^8.0.1"
|
||||
pytest-retry = "^1.6.3"
|
||||
pyyaml = "*"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest-cov = "^4.1.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
python_files = ["test*.py"]
|
||||
addopts = ["-v", "-s"]
|
||||
retries = 2
|
||||
retry_delay = 0.5
|
||||
cumulative_timing = false
|
||||
18
demos/test_runner/run_demo_tests.sh
Normal file
18
demos/test_runner/run_demo_tests.sh
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
set -eu
|
||||
|
||||
for demo in currency_exchange hr_agent
|
||||
do
|
||||
echo "******************************************"
|
||||
echo "Running tests for $demo ..."
|
||||
echo "****************************************"
|
||||
cd ../$demo
|
||||
archgw up arch_config.yaml
|
||||
docker compose up -d
|
||||
cd ../test_runner
|
||||
TEST_DATA=../$demo/test_data.yaml poetry run pytest
|
||||
cd ../$demo
|
||||
docker compose down -v
|
||||
archgw down
|
||||
cd ../test_runner
|
||||
done
|
||||
60
demos/test_runner/test_demos.py
Normal file
60
demos/test_runner/test_demos.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
import os
|
||||
from common import get_arch_messages
|
||||
import pytest
|
||||
import requests
|
||||
from deepdiff import DeepDiff
|
||||
import logging
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
ARCHGW_ENDPOINT = os.getenv(
|
||||
"ARCHGW_ENDPOINT", "http://localhost:10000/v1/chat/completions"
|
||||
)
|
||||
|
||||
# Load test data from YAML file
|
||||
with open(os.getenv("TEST_DATA", "test_data.yaml"), "r") as file:
|
||||
test_data_yaml = yaml.safe_load(file)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_data",
|
||||
[
|
||||
pytest.param(test_case, id=test_case["id"])
|
||||
for test_case in test_data_yaml["test_cases"]
|
||||
],
|
||||
)
|
||||
def test_demos(test_data):
|
||||
input = test_data["input"]
|
||||
expected_tools = test_data["expected_tools"]
|
||||
expected_output_contains = test_data["expected_output_contains"]
|
||||
|
||||
response = requests.post(ARCHGW_ENDPOINT, json=input)
|
||||
assert response.status_code == 200
|
||||
# ensure that response is json
|
||||
assert response.headers["content-type"] == "application/json"
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("gpt-4o")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
|
||||
# ensure that model responded according to the expectation
|
||||
assert "role" in choices[0]["message"]
|
||||
assert choices[0]["message"]["role"] == "assistant"
|
||||
assert expected_output_contains.lower() in choices[0]["message"]["content"].lower()
|
||||
|
||||
# now verify arch_messages (tool call and api response) that are sent as response metadata
|
||||
arch_messages = get_arch_messages(response_json)
|
||||
assert len(arch_messages) == 2
|
||||
tool_calls_message = arch_messages[0]
|
||||
tool_calls = tool_calls_message.get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
|
||||
# remove dynamic id from tool_calls
|
||||
for tool_call in tool_calls:
|
||||
tool_call.pop("id", None)
|
||||
diff = DeepDiff(expected_tools, tool_calls, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
|
@ -42,21 +42,18 @@ prompt_guards:
|
|||
message: Looks like you're curious about my abilities, but I can only provide assistance for weather forecasting.
|
||||
|
||||
prompt_targets:
|
||||
- name: weather_forecast
|
||||
description: Check weather information for a given city.
|
||||
- name: get_current_weather
|
||||
description: Get current weather at a location.
|
||||
parameters:
|
||||
- name: city
|
||||
description: the name of the city
|
||||
- name: location
|
||||
description: The location to get the weather for
|
||||
required: true
|
||||
type: str
|
||||
type: string
|
||||
format: City, State
|
||||
- name: days
|
||||
description: the number of days
|
||||
type: int
|
||||
description: the number of days for the request
|
||||
required: true
|
||||
- name: units
|
||||
description: the temperature unit, e.g., Celsius and Fahrenheit
|
||||
type: str
|
||||
default: Fahrenheit
|
||||
type: int
|
||||
endpoint:
|
||||
name: weather_forecast_service
|
||||
path: /weather
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ async def healthz():
|
|||
|
||||
|
||||
class WeatherRequest(BaseModel):
|
||||
city: str
|
||||
location: str
|
||||
days: int = 7
|
||||
units: str = "Farenheit"
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ class WeatherRequest(BaseModel):
|
|||
@app.post("/weather")
|
||||
async def weather(req: WeatherRequest, res: Response):
|
||||
weather_forecast = {
|
||||
"city": req.city,
|
||||
"location": req.location,
|
||||
"temperature": [],
|
||||
"units": req.units,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
name = "api-server"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Adil Hafeez <adil@katanemo.com>"]
|
||||
authors = ["Adil Hafeez <info@katanemo.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
|
|
|
|||
BIN
docs/source/_static/img/input-token-metrics.png
Normal file
BIN
docs/source/_static/img/input-token-metrics.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 169 KiB |
BIN
docs/source/_static/img/llm-request-metrics.png
Normal file
BIN
docs/source/_static/img/llm-request-metrics.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 118 KiB |
BIN
docs/source/_static/img/output-token-metrics.png
Normal file
BIN
docs/source/_static/img/output-token-metrics.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 157 KiB |
|
|
@ -1,18 +1,18 @@
|
|||
.. _arch_agent_guide:
|
||||
|
||||
Agentic Workflow
|
||||
==============================
|
||||
Agentic Apps
|
||||
=============
|
||||
|
||||
Arch helps you easily personalize your applications by calling application-specific (API) functions
|
||||
via user prompts. This involves any predefined functions or APIs you want to expose to users to perform tasks,
|
||||
gather information, or manipulate data. This capability is generally referred to as :ref:`function calling <function_calling>`, where
|
||||
you have the flexibility to support “agentic” apps tailored to specific use cases - from updating insurance
|
||||
claims to creating ad campaigns - via prompts.
|
||||
Arch helps you build personalized agentic applications by calling application-specific (API) functions via user prompts.
|
||||
This involves any predefined functions or APIs you want to expose to users to perform tasks, gather information,
|
||||
or manipulate data. This capability is generally referred to as :ref:`function calling <function_calling>`, where
|
||||
you can support “agentic” apps tailored to specific use cases - from updating insurance claims to creating ad campaigns - via prompts.
|
||||
|
||||
Arch analyzes prompts, extracts critical information from prompts, engages in lightweight conversation with
|
||||
the user to gather any missing parameters and makes API calls so that you can focus on writing business logic.
|
||||
Arch does this via its purpose-built `Arch-Function <https://huggingface.co/collections/katanemo/arch-function-66f209a693ea8df14317ad68>`_ - the fastest (200ms p90 - 10x faser than GPT-4o)
|
||||
and cheapest (100x than GPT-4o) function calling LLM that matches performance with frontier models.
|
||||
Arch analyzes prompts, extracts critical information from prompts, engages in lightweight conversation with the user to
|
||||
gather any missing parameters and makes API calls so that you can focus on writing business logic. Arch does this via its
|
||||
purpose-built `Arch-Function <https://huggingface.co/collections/katanemo/arch-function-66f209a693ea8df14317ad68>`_ -
|
||||
the fastest (200ms p50 - 12x faser than GPT-4o) and cheapest (44x than GPT-4o) function calling LLM that matches or outperforms
|
||||
frontier LLMs.
|
||||
|
||||
.. image:: includes/agent/function-calling-flow.jpg
|
||||
:width: 100%
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from openai import OpenAI
|
||||
from common import create_gradio_app
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Define the request model
|
||||
class EnergySourceRequest(BaseModel):
|
||||
energy_source: str
|
||||
consideration: Optional[str] = None
|
||||
|
||||
|
||||
class EnergySourceResponse(BaseModel):
|
||||
energy_source: str
|
||||
consideration: Optional[str] = None
|
||||
|
||||
|
||||
# Post method for device summary
|
||||
@app.post("/agent/energy_source_info")
|
||||
def get_workforce(request: EnergySourceRequest):
|
||||
"""
|
||||
Endpoint to get details about energy source
|
||||
"""
|
||||
considertion = "You don't have any specific consideration. Feel free to talk in a more open ended fashion"
|
||||
|
||||
if request.consideration is not None:
|
||||
considertion = f"Add specific focus on the following consideration when you summarize the content for the energy source: {request.consideration}"
|
||||
|
||||
response = {
|
||||
"energy_source": request.energy_source,
|
||||
"consideration": considertion,
|
||||
}
|
||||
return response
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 852 KiB |
|
|
@ -0,0 +1,35 @@
|
|||
version: v0.1
|
||||
listener:
|
||||
address: 127.0.0.1
|
||||
port: 8080 #If you configure port 443, you'll need to update the listener with tls_certificates
|
||||
message_format: huggingface
|
||||
|
||||
# Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
|
||||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-3.5-turbo
|
||||
default: true
|
||||
|
||||
# default system prompt used by all prompt targets
|
||||
system_prompt: |
|
||||
You are a helpful assistant and can offer information about energy sources. You will get a JSON object with energy_source and consideration fields. Focus on answering using those fields
|
||||
|
||||
prompt_targets:
|
||||
- name: get_info_for_energy_source
|
||||
description: get information about an energy source
|
||||
parameters:
|
||||
- name: energy_source
|
||||
type: str
|
||||
description: a source of energy
|
||||
required: true
|
||||
enum: [renewable, fossil]
|
||||
- name: consideration
|
||||
type: str
|
||||
description: a specific type of consideration for an energy source
|
||||
enum: [cost, economic, technology]
|
||||
endpoint:
|
||||
name: rag_energy_source_agent
|
||||
path: /agent/energy_source_info
|
||||
http_method: POST
|
||||
|
|
@ -1,162 +0,0 @@
|
|||
from flask import Flask, request, jsonify
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
from langchain import OpenAI
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# Global dictionary to keep track of user memories
|
||||
user_memories = {}
|
||||
|
||||
|
||||
def get_user_conversation(user_id):
|
||||
"""
|
||||
Retrieve the user's conversation memory using LangChain.
|
||||
If the user does not exist, initialize their conversation memory.
|
||||
"""
|
||||
if user_id not in user_memories:
|
||||
user_memories[user_id] = ConversationBufferMemory(return_messages=True)
|
||||
return user_memories[user_id]
|
||||
|
||||
|
||||
def update_user_conversation(user_id, client_messages, intent_changed):
|
||||
"""
|
||||
Update the user's conversation memory with new messages using LangChain.
|
||||
Each message is augmented with a UUID, timestamp, and intent change marker.
|
||||
Only new messages are added to avoid duplication.
|
||||
"""
|
||||
memory = get_user_conversation(user_id)
|
||||
stored_messages = memory.chat_memory.messages
|
||||
|
||||
# Determine the number of stored messages
|
||||
num_stored_messages = len(stored_messages)
|
||||
new_messages = client_messages[num_stored_messages:]
|
||||
|
||||
# Process each new message
|
||||
for index, message in enumerate(new_messages):
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
metadata = {
|
||||
"uuid": str(uuid.uuid4()),
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"intent_changed": False, # Default value
|
||||
}
|
||||
|
||||
# Mark the intent change on the last message if detected
|
||||
if intent_changed and index == len(new_messages) - 1:
|
||||
metadata["intent_changed"] = True
|
||||
|
||||
# Create a new message with metadata
|
||||
if role == "user":
|
||||
memory.chat_memory.add_message(
|
||||
HumanMessage(content=content, additional_kwargs={"metadata": metadata})
|
||||
)
|
||||
elif role == "assistant":
|
||||
memory.chat_memory.add_message(
|
||||
AIMessage(content=content, additional_kwargs={"metadata": metadata})
|
||||
)
|
||||
else:
|
||||
# Handle other roles if necessary
|
||||
pass
|
||||
|
||||
return memory
|
||||
|
||||
|
||||
def get_messages_since_last_intent(messages):
|
||||
"""
|
||||
Retrieve messages from the last intent change onwards using LangChain.
|
||||
"""
|
||||
messages_since_intent = []
|
||||
for message in reversed(messages):
|
||||
# Insert message at the beginning to maintain correct order
|
||||
messages_since_intent.insert(0, message)
|
||||
metadata = message.additional_kwargs.get("metadata", {})
|
||||
# Break if intent_changed is True
|
||||
if metadata.get("intent_changed", False) == True:
|
||||
break
|
||||
|
||||
return messages_since_intent
|
||||
|
||||
|
||||
def forward_to_llm(messages):
|
||||
"""
|
||||
Forward messages to an upstream LLM using LangChain.
|
||||
"""
|
||||
# Convert messages to a conversation string
|
||||
conversation = ""
|
||||
for message in messages:
|
||||
role = "User" if isinstance(message, HumanMessage) else "Assistant"
|
||||
content = message.content
|
||||
conversation += f"{role}: {content}\n"
|
||||
# Use LangChain's LLM to get a response. This call is proxied through Arch for end-to-end observability and traffic management
|
||||
llm = OpenAI()
|
||||
# Create a prompt that includes the conversation
|
||||
prompt = f"{conversation}Assistant:"
|
||||
response = llm(prompt)
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/process_rag", methods=["POST"])
|
||||
def process_rag():
|
||||
# Extract JSON data from the request
|
||||
data = request.get_json()
|
||||
|
||||
user_id = data.get("user_id")
|
||||
if not user_id:
|
||||
return jsonify({"error": "User ID is required"}), 400
|
||||
|
||||
client_messages = data.get("messages")
|
||||
if not client_messages or not isinstance(client_messages, list):
|
||||
return jsonify({"error": "Messages array is required"}), 400
|
||||
|
||||
# Extract the intent change marker from Arch's headers if present for the current prompt
|
||||
intent_changed_header = request.headers.get("x-arch-intent-marker", "").lower()
|
||||
if intent_changed_header in ["", "false"]:
|
||||
intent_changed = False
|
||||
elif intent_changed_header == "true":
|
||||
intent_changed = True
|
||||
else:
|
||||
# Invalid value provided
|
||||
return (
|
||||
jsonify({"error": "Invalid value for x-arch-prompt-intent-change header"}),
|
||||
400,
|
||||
)
|
||||
|
||||
# Update user conversation based on intent change
|
||||
memory = update_user_conversation(user_id, client_messages, intent_changed)
|
||||
|
||||
# Retrieve messages since last intent change for LLM
|
||||
messages_for_llm = get_messages_since_last_intent(memory.chat_memory.messages)
|
||||
|
||||
# Forward messages to upstream LLM
|
||||
llm_response = forward_to_llm(messages_for_llm)
|
||||
|
||||
# Prepare the messages to return
|
||||
messages_to_return = []
|
||||
for message in memory.chat_memory.messages:
|
||||
role = "user" if isinstance(message, HumanMessage) else "assistant"
|
||||
content = message.content
|
||||
metadata = message.additional_kwargs.get("metadata", {})
|
||||
message_entry = {
|
||||
"uuid": metadata.get("uuid"),
|
||||
"timestamp": metadata.get("timestamp"),
|
||||
"role": role,
|
||||
"content": content,
|
||||
"intent_changed": metadata.get("intent_changed", False),
|
||||
}
|
||||
messages_to_return.append(message_entry)
|
||||
|
||||
# Prepare the response
|
||||
response = {
|
||||
"user_id": user_id,
|
||||
"messages": messages_to_return,
|
||||
"llm_response": llm_response,
|
||||
}
|
||||
|
||||
return jsonify(response), 200
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True)
|
||||
90
docs/source/build_with_arch/multi_turn.rst
Normal file
90
docs/source/build_with_arch/multi_turn.rst
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
.. _arch_multi_turn_guide:
|
||||
|
||||
Multi-Turn
|
||||
==========
|
||||
Developers often `struggle <https://www.reddit.com/r/LocalLLaMA/comments/18mqwg6/best_practice_for_rag_with_followup_chat/>`_ to efficiently handle
|
||||
``follow-up`` or ``clarification`` questions. Specifically, when users ask for changes or additions to previous responses, it requires developers to
|
||||
re-write prompts using LLMs with precise prompt engineering techniques. This process is slow, manual, error prone and adds latency and token cost for
|
||||
common scenarios that can be managed more efficiently.
|
||||
|
||||
Arch is highly capable of accurately detecting and processing prompts in multi-turn scenarios so that you can buil fast and accurate agents in minutes.
|
||||
Below are some cnversational examples that you can build via Arch. Each example is enriched with annotations (via ** [Arch] ** ) that illustrates how Arch
|
||||
processess conversational messages on your behalf.
|
||||
|
||||
.. Note::
|
||||
The following section assumes that you have some knowledge about the core concepts of Arch, such as :ref:`prompt_targets <arch_overview_prompt_handling>`.
|
||||
If you haven't familizaried yourself with Arch's concepts, we recommend you first read the :ref:`tech overview <tech_overview>` section firtst.
|
||||
Additionally, the conversation examples below assume the usage of the following :ref:`arch_config.yaml <multi_turn_subsection_prompt_target>` file.
|
||||
|
||||
Example 1: Adjusting Retrieval
|
||||
------------------------------
|
||||
.. code-block:: text
|
||||
|
||||
User: What are the benefits of renewable energy?
|
||||
**[Arch]**: Check if there is an available <prompt_target> that can handle this user query.
|
||||
**[Arch]**: Found "get_info_for_energy_source" prompt_target in arch_config.yaml. Forward prompt to the endpoint configured in "get_info_for_energy_source"
|
||||
...
|
||||
Assistant: Renewable energy reduces greenhouse gas emissions, lowers air pollution, and provides sustainable power sources like solar and wind.
|
||||
|
||||
User: Include cost considerations in the response.
|
||||
**[Arch]**: Follow-up detected. Forward prompt history to the "get_info_for_energy_source" prompt_target and post the following parameters consideration="cost"
|
||||
...
|
||||
Assistant: Renewable energy reduces greenhouse gas emissions, lowers air pollution, and provides sustainable power sources like solar and wind. While the initial setup costs can be high, long-term savings from reduced fuel expenses and government incentives make it cost-effective.
|
||||
|
||||
|
||||
Example 2: Switching Intent
|
||||
---------------------------
|
||||
.. code-block:: text
|
||||
|
||||
User: What are the symptoms of diabetes?
|
||||
**[Arch]**: Check if there is an available <prompt_target> that can handle this user query.
|
||||
**[Arch]**: Found "diseases_symptoms" prompt_target in arch_config.yaml. Forward disease=diabeteres to "diseases_symptoms" prompt target
|
||||
...
|
||||
Assistant: Common symptoms include frequent urination, excessive thirst, fatigue, and blurry vision.
|
||||
|
||||
User: How is it diagnosed?
|
||||
**[Arch]**: New intent detected.
|
||||
**[Arch]**: Found "disease_diagnoses" prompt_target in arch_config.yaml. Forward disease=diabeteres to "disease_diagnoses" prompt target
|
||||
...
|
||||
Assistant: Diabetes is diagnosed through blood tests like fasting blood sugar, A1C, or an oral glucose tolerance test.
|
||||
|
||||
|
||||
Build Multi-Turn RAG Apps
|
||||
--------------------------
|
||||
The following section describes how you can easilly add support for multi-turn scenarios via Arch. You process and manage multi-turn prompts
|
||||
just like you manage single-turn ones. Arch handles the conpleixity of detecting the correct intent based on the last user prompt and
|
||||
the covnersational history, extracts relevant parameters needed by downstream APIs, and dipatches calls to any upstream LLMs to summarize the
|
||||
response from your APIs.
|
||||
|
||||
|
||||
.. _multi_turn_subsection_prompt_target:
|
||||
|
||||
Step 1: Define Arch Config
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. literalinclude:: includes/multi_turn/prompt_targets_multi_turn.yaml
|
||||
:language: yaml
|
||||
:caption: Arch Config
|
||||
:linenos:
|
||||
|
||||
Step 2: Process Request in Flask
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Once the prompt targets are configured as above, handle parameters across multi-turn as if its a single-turn request
|
||||
|
||||
.. literalinclude:: includes/multi_turn/multi_turn_rag.py
|
||||
:language: python
|
||||
:caption: Parameter handling with Flask
|
||||
:linenos:
|
||||
|
||||
Demo App
|
||||
~~~~~~~~
|
||||
|
||||
For your convenience, we've built a `demo app <https://github.com/katanemo/archgw/main/demos/multi_turn_rag_agent>`_
|
||||
that you can test and modify locally for multi-turn RAG scenarios.
|
||||
|
||||
.. figure:: includes/multi_turn/mutli-turn-example.png
|
||||
:width: 100%
|
||||
:align: center
|
||||
|
||||
Example multi-turn user conversation showing adjusting retrieval
|
||||
|
|
@ -1,10 +1,18 @@
|
|||
.. _arch_rag_guide:
|
||||
|
||||
RAG Application
|
||||
===============
|
||||
RAG Apps
|
||||
========
|
||||
|
||||
The following section describes how Arch can help you build faster, smarter and more accurate
|
||||
Retrieval-Augmented Generation (RAG) applications.
|
||||
Retrieval-Augmented Generation (RAG) applications, including fast and accurate RAG in multi-turn
|
||||
converational scenarios.
|
||||
|
||||
What is Retrieval-Augmented Generation (RAG)?
|
||||
---------------------------------------------
|
||||
RAG applications combine retrieval-based methods with generative AI models to provide more accurate,
|
||||
contextually relevant, and reliable outputs. These applications leverage external data sources to augment
|
||||
the capabilities of Large Language Models (LLMs), enabling them to retrieve and integrate specific information
|
||||
rather than relying solely on the LLM's internal knowledge.
|
||||
|
||||
Parameter Extraction for RAG
|
||||
----------------------------
|
||||
|
|
@ -33,60 +41,12 @@ Once the prompt targets are configured as above, handling those parameters is
|
|||
:caption: Parameter handling with Flask
|
||||
:linenos:
|
||||
|
||||
[Coming Soon] `Drift Detection via Arch Intent-Markers <https://github.com/orgs/katanemo/projects/1/views/1?pane=issue&itemId=82697909>`_
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------
|
||||
Developers struggle to efficiently handle ``follow-up`` or ``clarification`` questions. Specifically, when users ask for
|
||||
changes or additions to previous responses their AI applications often generate entirely new responses instead of adjusting
|
||||
previous ones. Arch offers ``intent tracking`` as a feature so that developers can know when the user has shifted away from a
|
||||
previous intent so that they can dramatically improve retrieval accuracy, lower overall token cost and improve the speed of
|
||||
their responses back to users.
|
||||
Multi-Turn RAG (Follow-up Questions)
|
||||
-------------------------------------
|
||||
Developers often `struggle <https://www.reddit.com/r/LocalLLaMA/comments/18mqwg6/best_practice_for_rag_with_followup_chat/>`_ to efficiently handle
|
||||
``follow-up`` or ``clarification`` questions. Specifically, when users ask for changes or additions to previous responses, it requires developers to
|
||||
re-write prompts using LLMs with precise prompt engineering techniques. This process is slow, manual, error prone and adds signifcant latency to the
|
||||
user experience. Arch
|
||||
|
||||
Arch uses its built-in lightweight NLI and embedding models to know if the user has steered away from an active intent.
|
||||
Arch's intent-drift detection mechanism is based on its :ref:`prompt target <prompt_target>` primtive. Arch tries to match an incoming
|
||||
prompt to one of the prompt_targets configured in the gateway. Once it detects that the user has moved away from an active
|
||||
active intent, Arch adds the ``x-arch-intent-marker`` headers to the request before sending it your application servers.
|
||||
|
||||
.. literalinclude:: includes/rag/intent_detection.py
|
||||
:language: python
|
||||
:linenos:
|
||||
:lines: 101-157
|
||||
:emphasize-lines: 14-25
|
||||
:caption: Intent Detection Example
|
||||
|
||||
|
||||
.. Note::
|
||||
|
||||
Arch is (mostly) stateless so that it can scale in an embarrassingly parrallel fashion. So, while Arch offers
|
||||
intent-drift detetction, you still have to maintain converational state with intent drift as metadata. The
|
||||
following code snippets show how easily you can build and enrich conversational history with Langchain (in Python),
|
||||
so that you can use the most relevant prompts for your retrieval and for prompting upstream LLMs.
|
||||
|
||||
|
||||
Step 1: Define ConversationBufferMemory
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. literalinclude:: includes/rag/intent_detection.py
|
||||
:language: python
|
||||
:linenos:
|
||||
:lines: 1-21
|
||||
|
||||
Step 2: Update ConversationBufferMemory with Intents
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. literalinclude:: includes/rag/intent_detection.py
|
||||
:language: python
|
||||
:linenos:
|
||||
:lines: 24-64
|
||||
|
||||
Step 3: Get Messages based on latest drift
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. literalinclude:: includes/rag/intent_detection.py
|
||||
:language: python
|
||||
:linenos:
|
||||
:lines: 67-80
|
||||
|
||||
|
||||
You can used the last set of messages that match to an intent to prompt an LLM, use it with an vector-DB for
|
||||
improved retrieval, etc. With Arch and a few lines of code, you can improve the retrieval accuracy, lower overall
|
||||
token cost and dramatically improve the speed of their responses back to users.
|
||||
Arch is highly capable of accurately detecting and processing prompts in a multi-turn scenarios so that you can buil fast and accurate RAG apps in
|
||||
minutes. For additional details on how to build multi-turn RAG applications please refer to our :ref:`multi-turn <arch_multi_turn_guide>` docs.
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ For more details on how you can build agentic applications using Arch, see our f
|
|||
.. Note::
|
||||
`Arch-Function <https://huggingface.co/collections/katanemo/arch-function-66f209a693ea8df14317ad68>`_ is a collection of dedicated agentic models engineered in Arch to extract information from a (set of) prompts and executes necessary backend API calls.
|
||||
This allows for efficient handling of agentic tasks, such as scheduling data retrieval, by dynamically interacting with backend services.
|
||||
Arch-Function achieves state-of-the-art performance, comparable with frontier models like Claude Sonnet 3.5 ang GPT-4, while being 100x cheaper ($0.05M/token hosted) and 10x faster (p50 latencies of 200ms).
|
||||
Arch-Function achieves state-of-the-art performance, comparable with frontier models like Claude Sonnet 3.5 ang GPT-4, while being 44x cheaper ($0.10M/token hosted) and 10x faster (p50 latencies of 200ms).
|
||||
|
||||
Prompting LLMs
|
||||
--------------
|
||||
|
|
|
|||
|
|
@ -3,8 +3,16 @@
|
|||
Terminology
|
||||
============
|
||||
|
||||
A few definitions before we dive into the main architecture documentation. Arch borrows from Envoy's terminology
|
||||
to keep things consistent in logs, traces and in code.
|
||||
A few definitions before we dive into the main architecture documentation. Also note, Arch borrows from Envoy's terminology
|
||||
to keep things consistent in logs and traces, and introduces and clarifies concepts are is relates to LLM applications.
|
||||
|
||||
**Agent**: An application that uses LLMs to handle wide-ranging tasks from users via prompts. This could be as simple
|
||||
as retrieving or summarizing data from an API, or being able to trigger compleix actions like adjusting ad campaigns, or
|
||||
changing travel plans via prompts.
|
||||
|
||||
**Arch Config**: Arch operates based on a configuration that controls the behavior of a single instance of the Arch gateway.
|
||||
This where you enable capabilities like LLM routing, fast function calling (via prompt_targets), applying guardrails, and enabling critical
|
||||
features like metrics and tracing. For the full configuration reference of `arch_config.yaml` see :ref:`here <configuration_refernce>`.
|
||||
|
||||
**Downstream(Ingress)**: An downstream client (web application, etc.) connects to Arch, sends prompts, and receives responses.
|
||||
|
||||
|
|
@ -14,9 +22,10 @@ to keep things consistent in logs, traces and in code.
|
|||
:width: 100%
|
||||
:align: center
|
||||
|
||||
**Listener**: A :ref:`listener <arch_overview_listeners>` is a named network location (e.g., port, address, path etc.) that Arch listens on to process prompts
|
||||
before forwarding them to your application server endpoints. rch enables you to configure one listener for downstream connections
|
||||
(like port 80, 443) and creates a separate internal listener for calls that initiate from your application code to LLMs.
|
||||
**Listener**: A :ref:`listener <arch_overview_listeners>` is a named network location (e.g., port, address, path etc.) that Arch
|
||||
listens on to process prompts before forwarding them to your application server endpoints. rch enables you to configure one listener
|
||||
for downstream connections (like port 80, 443) and creates a separate internal listener for calls that initiate from your application
|
||||
code to LLMs.
|
||||
|
||||
.. Note::
|
||||
|
||||
|
|
@ -24,23 +33,18 @@ before forwarding them to your application server endpoints. rch enables you to
|
|||
that you can use (``127.0.0.1:12000``) to proxy egress calls originating from your application to LLMs (API-based or hosted).
|
||||
For more details, check out :ref:`LLM provider <llm_provider>`.
|
||||
|
||||
**Instance**: An instance of the Arch gateway. When you start Arch it creates at most two processes. One to handle Layer 7
|
||||
networking operations (auth, tls, observability, etc) and the second process to serve models that enable it to make smart
|
||||
decisions on how to accept, handle and forward prompts. The second process is optional, as the model serving sevice could be
|
||||
hosted on a different network (an API call). But these two processes are considered a single instance of Arch.
|
||||
**Prompt Target**: Arch offers a primitive called :ref:`prompt target <prompt_target>` to help separate business logic from
|
||||
undifferentiated work in building generative AI apps. Prompt targets are endpoints that receive prompts that are processed by Arch.
|
||||
For example, Arch enriches incoming prompts with metadata like knowing when a request is a follow-up or clarifying prompt so that you
|
||||
can build faster, more accurate retrieval (RAG) apps. To support agentic apps, like scheduling travel plans or sharing comments on a
|
||||
document - via prompts, Arch uses its function calling abilities to extract critical information fromthe incoming prompt (or a set of
|
||||
prompts) needed by a downstream backend API or function call before calling it directly.
|
||||
|
||||
**Prompt Target**: Arch offers a primitive called :ref:`prompt target <prompt_target>` to help separate business logic from undifferentiated
|
||||
work in building generative AI apps. Prompt targets are endpoints that receive prompts that are processed by Arch.
|
||||
For example, Arch enriches incoming prompts with metadata like knowing when a request is a follow-up or clarifying prompt
|
||||
so that you can build faster, more accurate retrieval (RAG) apps. To support agentic apps, like scheduling travel plans or
|
||||
sharing comments on a document - via prompts, Arch uses its function calling abilities to extract critical information from
|
||||
the incoming prompt (or a set of prompts) needed by a downstream backend API or function call before calling it directly.
|
||||
**Model Serving**: Arch is a set of `two` self-contained processes that are designed to run alongside your application servers
|
||||
(or on a separate hostconnected via a network).The :ref:`model serving <model_serving>` process helps Arch make intelligent decisions
|
||||
about the incoming prompts. The model server is designed to call the (fast) purpose-built LLMs in Arch.
|
||||
|
||||
**Error Target**: :ref:`Error targets <error_target>` are those endpoints that receive forwarded errors from Arch when issues arise,
|
||||
such as failing to properly call a function/API, detecting violations of guardrails, or encountering other processing errors.
|
||||
These errors are communicated to the application via headers ``X-Arch-[ERROR-TYPE]``, allowing it to handle the errors gracefully
|
||||
and take appropriate actions.
|
||||
|
||||
**Model Serving**: Arch is a set of `two` self-contained processes that are designed to run alongside your application servers
|
||||
(or on a separate hostconnected via a network).The :ref:`model serving <model_serving>` process helps Arch make intelligent decisions about the
|
||||
incoming prompts. The model server is designed to call the (fast) purpose-built LLMs in Arch.
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from sphinxawesome_theme.postprocess import Icons
|
|||
project = "Arch Docs"
|
||||
copyright = "2024, Katanemo Labs, Inc"
|
||||
author = "Katanemo Labs, Inc"
|
||||
release = " v0.1.5"
|
||||
release = " v0.1.7"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ Arch's CLI allows you to manage and interact with the Arch gateway efficiently.
|
|||
|
||||
$ python -m venv venv
|
||||
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
|
||||
$ pip install archgw==0.1.6
|
||||
$ pip install archgw==0.1.7
|
||||
|
||||
|
||||
Build AI Agent with Arch Gateway
|
||||
|
|
|
|||
|
|
@ -3,7 +3,75 @@
|
|||
Monitoring
|
||||
==========
|
||||
|
||||
Arch offers several monitoring metrics that help you understand three critical aspects of your application:
|
||||
latency, token usage, and error rates by an upstream LLM provider. Latency measures the speed at which your
|
||||
application is responding to users, which includes metrics like time to first token (TFT), time per output
|
||||
token (TOT) metrics, and the total latency as perceived by users.
|
||||
`OpenTelemetry <https://opentelemetry.io/>`_ is an open-source observability framework providing APIs
|
||||
and instrumentation for generating, collecting, processing, and exporting telemetry data, such as traces,
|
||||
metrics, and logs. Its flexible design supports a wide range of backends and seamlessly integrates with
|
||||
modern application tools.
|
||||
|
||||
Arch acts a *source* for several monitoring metrics related to **prompts** and **LLMs** natively integrated
|
||||
via `OpenTelemetry <https://opentelemetry.io/>`_ to help you understand three critical aspects of your application:
|
||||
latency, token usage, and error rates by an upstream LLM provider. Latency measures the speed at which your application
|
||||
is responding to users, which includes metrics like time to first token (TFT), time per output token (TOT) metrics, and
|
||||
the total latency as perceived by users. Below are some screenshots how Arch integrates natively with tools like
|
||||
`Grafana <https://grafana.com/grafana/dashboards/>`_ via `Promethus <https://prometheus.io/>`_
|
||||
|
||||
|
||||
Metrics Dashboard (via Grafana)
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. image:: /_static/img/llm-request-metrics.png
|
||||
:width: 100%
|
||||
:align: center
|
||||
|
||||
.. image:: /_static/img/input-token-metrics.png
|
||||
:width: 100%
|
||||
:align: center
|
||||
|
||||
.. image:: /_static/img/output-token-metrics.png
|
||||
:width: 100%
|
||||
:align: center
|
||||
|
||||
Configure Monitoring
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
Arch gateway publishes stats endpoint at http://localhost:19901/stats. As noted above, Arch is a source for metrics. To view and manipulate dashbaords, you will
|
||||
need to configiure `Promethus <https://prometheus.io/>`_ (as a metrics store) and `Grafana <https://grafana.com/grafana/dashboards/>`_ for dashboards. Below
|
||||
are some sample configuration files for both, respectively.
|
||||
|
||||
.. code-block:: yaml
|
||||
:caption: Sample prometheus.yaml config file
|
||||
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
scrape_timeout: 10s
|
||||
evaluation_interval: 15s
|
||||
alerting:
|
||||
alertmanagers:
|
||||
- static_configs:
|
||||
- targets: []
|
||||
scheme: http
|
||||
timeout: 10s
|
||||
api_version: v2
|
||||
scrape_configs:
|
||||
- job_name: archgw
|
||||
honor_timestamps: true
|
||||
scrape_interval: 15s
|
||||
scrape_timeout: 10s
|
||||
metrics_path: /stats
|
||||
scheme: http
|
||||
static_configs:
|
||||
- targets:
|
||||
- host.docker.internal:19901
|
||||
params:
|
||||
format: ["prometheus"]
|
||||
|
||||
|
||||
.. code-block:: yaml
|
||||
:caption: Sample grafana datasource.yaml config file
|
||||
|
||||
apiVersion: 1
|
||||
datasources:
|
||||
- name: Prometheus
|
||||
type: prometheus
|
||||
url: http://prometheus:9090
|
||||
isDefault: true
|
||||
access: proxy
|
||||
editable: true
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ Welcome to Arch!
|
|||
|
||||
<a href="https://www.producthunt.com/posts/arch-3?embed=true&utm_source=badge-top-post-badge&utm_medium=badge&utm_souce=badge-arch-3" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/top-post-badge.svg?post_id=565761&theme=light&period=daily" alt="Arch - Build fast, hyper-personalized agents with intelligent infra | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
|
||||
`Arch <https://github.com/katanemo/arch>`_ is an intelligent infrastructure primitive for GenAI (built by the contributors of `Envoy <https://www.envoyproxy.io/>`_ ) that born out of the belief that:
|
||||
`Arch <https://github.com/katanemo/arch>`_ is an intelligent gateway for agents - an infrastructure primitive for GenAI (built by the contributors of `Envoy <https://www.envoyproxy.io/>`_ ). The project was born out of the belief that:
|
||||
|
||||
*Prompts are nuanced and opaque user requests, which require the same capabilities as traditional HTTP requests including secure handling, intelligent routing, robust observability, and integration with backend (API) systems for personalization - all outside business logic.*
|
||||
|
||||
|
|
@ -62,6 +62,7 @@ Welcome to Arch!
|
|||
|
||||
build_with_arch/agent
|
||||
build_with_arch/rag
|
||||
build_with_arch/multi_turn
|
||||
|
||||
.. tab-item:: Resources
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
Configuration Reference
|
||||
============================
|
||||
.. _configuration_refernce:
|
||||
|
||||
The following is a complete reference of the ``prompt-conifg.yml`` that controls the behavior of a single instance of
|
||||
the Arch gateway. We've kept things simple (less than 80 lines) and held off on exposing additional functionality (for
|
||||
e.g. suppporting push observability stats, managing prompt-endpoints as virtual cluster, exposing more load balancing
|
||||
options, etc). Our belief that the simple things, should be simple. So we offert good defaults for developers, so
|
||||
that they can spend more of their time in building features unique to their AI experience.
|
||||
Configuration Reference
|
||||
=======================
|
||||
|
||||
The following is a complete reference of the ``arch_conifg.yml`` that controls the behavior of a single instance of
|
||||
the Arch gateway. This where you enable capabilities like routing to upstream LLm providers, defining prompt_targets
|
||||
where prompts get routed to, apply guardrails, and enable critical agent observability features.
|
||||
|
||||
.. literalinclude:: includes/arch_config_full_reference.yaml
|
||||
:language: yaml
|
||||
|
|
|
|||
15
e2e_tests/.vscode/launch.json
vendored
15
e2e_tests/.vscode/launch.json
vendored
|
|
@ -1,15 +0,0 @@
|
|||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
}
|
||||
7
e2e_tests/.vscode/settings.json
vendored
7
e2e_tests/.vscode/settings.json
vendored
|
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"python.testing.pytestArgs": [
|
||||
"."
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
}
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
@model_server_endpoint = http://localhost:51000
|
||||
@archfc_endpoint = https://api.fc.archgw.com
|
||||
|
||||
### talk to model_server for completion
|
||||
POST {{model_server_endpoint}}/v1/chat/completions HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in seattle for next 10 days"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"id": "weather-112",
|
||||
"tool_type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "str", "days": "int"}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
### talk to arch_fc directly for completion
|
||||
POST {{archfc_endpoint}}/v1/chat/completions HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "Arch-Function",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"id\": \"weather-112\", \"tool_type\": \"function\", \"function\": {\"name\": \"weather_forecast\", \"arguments\": {\"city\": \"str\", \"days\": \"int\"}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
|
||||
},
|
||||
{ "role": "user", "content": "how is the weather in seattle?" },
|
||||
{ "role": "assistant", "content": "Of course! " }
|
||||
],
|
||||
"continue_final_message": true,
|
||||
"add_generation_prompt": false
|
||||
}
|
||||
2
model_server/.vscode/launch.json
vendored
2
model_server/.vscode/launch.json
vendored
|
|
@ -9,7 +9,7 @@
|
|||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": ["app.main:app","--reload", "--port", "51000"]
|
||||
"args": ["src.main:app","--reload", "--port", "51000"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ WORKDIR /src
|
|||
# specify list of models that will go into the image as a comma separated list
|
||||
# following models have been tested to work with this image
|
||||
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
|
||||
ENV MODELS="katanemo/bge-large-en-v1.5-onnx"
|
||||
ENV MODELS=""
|
||||
|
||||
COPY ./app ./app
|
||||
COPY ./app/guard_model_config.yaml .
|
||||
|
|
@ -28,4 +28,4 @@ COPY ./app/openai_params.yaml .
|
|||
# RUN python install.py && \
|
||||
# find /root/.cache/torch/sentence_transformers/ -name onnx -exec rm -rf {} +
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
|
||||
CMD ["uvicorn", "src.app.main:app", "--host", "0.0.0.0", "--port", "80"]
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ RUN if command -v nvcc >/dev/null 2>&1; then \
|
|||
COPY . /src
|
||||
|
||||
# Specify list of models that will go into the image as a comma separated list
|
||||
ENV MODELS="katanemo/bge-large-en-v1.5-onnx"
|
||||
ENV MODELS=""
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
COPY /app /app
|
||||
|
|
|
|||
|
|
@ -1,178 +0,0 @@
|
|||
import importlib
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import psutil
|
||||
import tempfile
|
||||
import subprocess
|
||||
import logging
|
||||
|
||||
|
||||
def get_version():
|
||||
try:
|
||||
version = importlib.metadata.version("archgw_modelserver")
|
||||
return version
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
return "version not found"
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
log = logging.getLogger("model_server.cli")
|
||||
log.setLevel(logging.INFO)
|
||||
log.info(f"model server version: {get_version()}")
|
||||
|
||||
|
||||
def run_server(port=51000):
|
||||
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
|
||||
if len(sys.argv) > 1:
|
||||
action = sys.argv[1]
|
||||
else:
|
||||
action = "start"
|
||||
|
||||
if action == "start":
|
||||
start_server(port)
|
||||
elif action == "stop":
|
||||
stop_server(port)
|
||||
elif action == "restart":
|
||||
restart_server(port)
|
||||
else:
|
||||
log.info(f"Unknown action: {action}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def start_server(port=51000):
|
||||
"""Start the Uvicorn server"""
|
||||
log.info(
|
||||
"starting model server - loading some awesomeness, this may take some time :)"
|
||||
)
|
||||
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"python",
|
||||
"-m",
|
||||
"uvicorn",
|
||||
"app.main:app",
|
||||
"--host",
|
||||
"0.0.0.0",
|
||||
"--port",
|
||||
f"{port}",
|
||||
],
|
||||
start_new_session=True,
|
||||
bufsize=1,
|
||||
universal_newlines=True,
|
||||
stdout=subprocess.PIPE, # Suppress standard output. There is a logger that model_server prints to
|
||||
stderr=subprocess.PIPE, # Suppress standard error. There is a logger that model_server prints to
|
||||
)
|
||||
|
||||
if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"):
|
||||
log.info(f"Model server started with PID {process.pid}")
|
||||
else:
|
||||
# Add model_server boot-up logs
|
||||
log.info("model server - didn't start in time, shutting down")
|
||||
process.terminate()
|
||||
|
||||
|
||||
def wait_for_health_check(url, timeout=300):
|
||||
"""Wait for the Uvicorn server to respond to health-check requests."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except requests.ConnectionError:
|
||||
time.sleep(1)
|
||||
print("Timed out waiting for model server to respond.")
|
||||
return False
|
||||
|
||||
|
||||
def check_and_install_lsof():
|
||||
"""Check if lsof is installed, and if not, install it using apt-get."""
|
||||
try:
|
||||
# Check if lsof is installed by running "lsof -v"
|
||||
subprocess.run(
|
||||
["lsof", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
print("lsof is already installed.")
|
||||
except subprocess.CalledProcessError:
|
||||
print("lsof not found, installing...")
|
||||
try:
|
||||
# Update package list and install lsof
|
||||
subprocess.run(["sudo", "apt-get", "update"], check=True)
|
||||
subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True)
|
||||
print("lsof installed successfully.")
|
||||
except subprocess.CalledProcessError as install_error:
|
||||
print(f"Failed to install lsof: {install_error}")
|
||||
|
||||
|
||||
def kill_process(port=51000, wait=True, timeout=10):
|
||||
"""Stop the running Uvicorn server."""
|
||||
log.info("Stopping model server")
|
||||
try:
|
||||
# Run the function to check and install lsof if necessary
|
||||
# Step 1: Run lsof command to get the process using the port
|
||||
lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
|
||||
result = subprocess.run(
|
||||
lsof_command, shell=True, capture_output=True, text=True
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"No process found listening on port {port}.")
|
||||
return
|
||||
|
||||
# Step 2: Parse the process IDs from the output
|
||||
process_ids = [line.split()[1] for line in result.stdout.splitlines()]
|
||||
|
||||
if not process_ids:
|
||||
print(f"No process found listening on port {port}.")
|
||||
return
|
||||
|
||||
# Step 3: Kill each process using its PID
|
||||
for pid in process_ids:
|
||||
print(f"Killing model server process with PID {pid}")
|
||||
subprocess.run(f"kill {pid}", shell=True)
|
||||
|
||||
if wait:
|
||||
# Step 4: Wait for the process to be killed by checking if it's still running
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
check_process = subprocess.run(
|
||||
f"ps -p {pid}", shell=True, capture_output=True, text=True
|
||||
)
|
||||
if check_process.returncode != 0:
|
||||
print(f"Process {pid} has been killed.")
|
||||
break
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time > timeout:
|
||||
print(
|
||||
f"Process {pid} did not terminate within {timeout} seconds."
|
||||
)
|
||||
print(f"Attempting to force kill process {pid}...")
|
||||
subprocess.run(f"kill -9 {pid}", shell=True) # SIGKILL
|
||||
break
|
||||
|
||||
print(
|
||||
f"Waiting for process {pid} to be killed... ({elapsed_time:.2f} seconds)"
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def stop_server(port=51000, wait=True, timeout=10):
|
||||
check_and_install_lsof()
|
||||
kill_process(port, wait, timeout)
|
||||
|
||||
|
||||
def restart_server(port=51000):
|
||||
"""Restart the Uvicorn server."""
|
||||
stop_server(port)
|
||||
start_server(port)
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
import app.commons.globals as glb
|
||||
import app.commons.utilities as utils
|
||||
import app.loader as loader
|
||||
|
||||
from app.function_calling.model_handler import ArchFunctionHandler
|
||||
from app.prompt_guard.model_handler import ArchGuardHanlder
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
arch_function_hanlder = ArchFunctionHandler()
|
||||
PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
|
||||
PREFILL_ENABLED = True
|
||||
TOOL_CALL_TOKEN = "<tool_call>"
|
||||
arch_function_endpoint = "https://api.fc.archgw.com/v1"
|
||||
arch_function_client = utils.get_client(arch_function_endpoint)
|
||||
arch_function_generation_params = {
|
||||
"temperature": 0.2,
|
||||
"top_p": 1.0,
|
||||
"top_k": 50,
|
||||
"max_tokens": 512,
|
||||
"stop_token_ids": [151645],
|
||||
# "top_logprobs": 10,
|
||||
}
|
||||
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
# Model definition
|
||||
embedding_model = loader.get_embedding_model()
|
||||
zero_shot_model = loader.get_zero_shot_model()
|
||||
|
||||
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
||||
# Patterns for function name and parameter parsing
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
import app.commons.utilities as utils
|
||||
|
||||
|
||||
DEVICE = utils.get_device()
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
import os
|
||||
import yaml
|
||||
import torch
|
||||
import string
|
||||
import logging
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
logger_instance = None
|
||||
|
||||
|
||||
def get_device():
|
||||
available_device = {
|
||||
"cpu": True,
|
||||
"cuda": torch.cuda.is_available(),
|
||||
"mps": (
|
||||
torch.backends.mps.is_available()
|
||||
if hasattr(torch.backends, "mps")
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
||||
if available_device["cuda"]:
|
||||
device = "cuda"
|
||||
elif available_device["mps"]:
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def get_client(endpoint):
|
||||
client = OpenAI(base_url=endpoint, api_key="EMPTY")
|
||||
return client
|
||||
|
||||
|
||||
def get_model_server_logger():
|
||||
global logger_instance
|
||||
|
||||
if logger_instance is not None:
|
||||
# If the logger is already initialized, return the existing instance
|
||||
return logger_instance
|
||||
|
||||
# Define log file path outside current directory (e.g., ~/archgw_logs)
|
||||
log_dir = os.path.expanduser("~/archgw_logs")
|
||||
log_file = "modelserver.log"
|
||||
log_file_path = os.path.join(log_dir, log_file)
|
||||
|
||||
# Ensure the log directory exists, create it if necessary, handle permissions errors
|
||||
try:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True) # Create directory if it doesn't exist
|
||||
|
||||
# Check if the script has write permission in the log directory
|
||||
if not os.access(log_dir, os.W_OK):
|
||||
raise PermissionError(f"No write permission for the directory: {log_dir}")
|
||||
# Configure logging to file and console using basicConfig
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in file
|
||||
],
|
||||
)
|
||||
except (PermissionError, OSError):
|
||||
# Dont' fallback to console logging if there are issues writing to the log file
|
||||
raise RuntimeError(f"No write permission for the directory: {log_dir}")
|
||||
|
||||
# Initialize the logger instance after configuring handlers
|
||||
logger_instance = logging.getLogger("model_server_logger")
|
||||
return logger_instance
|
||||
|
||||
|
||||
def remove_punctuations(s):
|
||||
s = s.translate(str.maketrans(string.punctuation, " " * len(string.punctuation)))
|
||||
return " ".join(s.split()).lower()
|
||||
|
||||
|
||||
def get_label_map(labels):
|
||||
return {remove_punctuations(label): label for label in labels}
|
||||
|
|
@ -1,137 +0,0 @@
|
|||
import json
|
||||
import random
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_TOOL_PROMPT = """
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_FORMAT_PROMPT = """
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
""".strip()
|
||||
|
||||
|
||||
class ArchFunctionHandler:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
||||
def convert_tools(tools):
|
||||
return "\n".join([json.dumps(tool) for tool in tools])
|
||||
|
||||
tool_text = convert_tools(tools)
|
||||
|
||||
system_prompt = (
|
||||
ARCH_FUNCTION_CALLING_TASK_PROMPT
|
||||
+ "\n\n"
|
||||
+ ARCH_FUNCTION_CALLING_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ ARCH_FUNCTION_CALLING_FORMAT_PROMPT
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
def _add_execution_results_prompting(
|
||||
self,
|
||||
messages: list[dict],
|
||||
execution_results: list,
|
||||
) -> dict:
|
||||
content = []
|
||||
for result in execution_results:
|
||||
content.append(f"<tool_response>\n{json.dumps(result)}\n</tool_response>")
|
||||
|
||||
content = "\n".join(content)
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
return messages
|
||||
|
||||
def extract_tool_calls(self, content: str):
|
||||
tool_calls = []
|
||||
|
||||
flag = False
|
||||
for line in content.split("\n"):
|
||||
if "<tool_call>" == line:
|
||||
flag = True
|
||||
elif "</tool_call>" == line:
|
||||
flag = False
|
||||
else:
|
||||
if flag:
|
||||
try:
|
||||
tool_content = json.loads(line)
|
||||
except Exception:
|
||||
fixed_content = self.fix_json_string(line)
|
||||
try:
|
||||
tool_content = json.loads(fixed_content)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_content["name"],
|
||||
"arguments": tool_content["arguments"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
flag = False
|
||||
|
||||
return tool_calls
|
||||
|
||||
def fix_json_string(self, json_str: str):
|
||||
# Remove any leading or trailing whitespace or newline characters
|
||||
json_str = json_str.strip()
|
||||
|
||||
# Stack to keep track of brackets
|
||||
stack = []
|
||||
|
||||
# Clean string to collect valid characters
|
||||
fixed_str = ""
|
||||
|
||||
# Dictionary for matching brackets
|
||||
matching_bracket = {")": "(", "}": "{", "]": "["}
|
||||
|
||||
# Dictionary for the opposite of matching_bracket
|
||||
opening_bracket = {v: k for k, v in matching_bracket.items()}
|
||||
|
||||
for char in json_str:
|
||||
if char in "{[(":
|
||||
stack.append(char)
|
||||
fixed_str += char
|
||||
elif char in "}])":
|
||||
if stack and stack[-1] == matching_bracket[char]:
|
||||
stack.pop()
|
||||
fixed_str += char
|
||||
else:
|
||||
# Ignore the unmatched closing brackets
|
||||
continue
|
||||
else:
|
||||
fixed_str += char
|
||||
|
||||
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
|
||||
while stack:
|
||||
unmatched_opening = stack.pop()
|
||||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str.replace("'", '"')
|
||||
|
|
@ -1,157 +0,0 @@
|
|||
import json
|
||||
import hashlib
|
||||
import app.commons.constants as const
|
||||
import random
|
||||
from fastapi import Response
|
||||
from pydantic import BaseModel
|
||||
from app.commons.utilities import get_model_server_logger
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
logger = get_model_server_logger()
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Optional[str] = ""
|
||||
content: Optional[str] = ""
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = []
|
||||
tool_call_id: Optional[str] = ""
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
message: Message
|
||||
finish_reason: Optional[str] = "stop"
|
||||
index: Optional[int] = 0
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
choices: List[Choice]
|
||||
model: Optional[str] = "Arch-Function"
|
||||
created: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
object: Optional[str] = "chat_completion"
|
||||
|
||||
|
||||
def process_messages(history: list[Message]):
|
||||
updated_history = []
|
||||
for hist in history:
|
||||
if hist.tool_calls:
|
||||
if len(hist.tool_calls) > 1:
|
||||
error_msg = f"Only one tool call is supported, tools counts: {len(hist.tool_calls)}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
tool_call_str = json.dumps(hist.tool_calls[0]["function"])
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
|
||||
}
|
||||
)
|
||||
elif hist.role == "tool":
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>\n{hist.content}\n</tool_response>",
|
||||
}
|
||||
)
|
||||
else:
|
||||
updated_history.append({"role": hist.role, "content": hist.content})
|
||||
return updated_history
|
||||
|
||||
|
||||
async def chat_completion(req: ChatMessage, res: Response):
|
||||
logger.info("starting request")
|
||||
|
||||
tools_encoded = const.arch_function_hanlder._format_system(req.tools)
|
||||
|
||||
messages = [{"role": "system", "content": tools_encoded}]
|
||||
|
||||
updated_history = process_messages(req.messages)
|
||||
for message in updated_history:
|
||||
messages.append({"role": message["role"], "content": message["content"]})
|
||||
|
||||
client_model_name = const.arch_function_client.models.list().data[0].id
|
||||
|
||||
logger.info(
|
||||
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
|
||||
)
|
||||
|
||||
# Retrieve the first token, handling the Stream object carefully
|
||||
|
||||
try:
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=const.PREFILL_ENABLED,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"model_server <= arch_function: error: {e}")
|
||||
raise
|
||||
|
||||
if const.PREFILL_ENABLED:
|
||||
first_token_content = ""
|
||||
for token in resp:
|
||||
first_token_content = token.choices[
|
||||
0
|
||||
].delta.content.strip() # Clean up the content
|
||||
if first_token_content: # Break if it's non-empty
|
||||
break
|
||||
|
||||
# Check if the first token requires tool call handling
|
||||
if first_token_content != const.TOOL_CALL_TOKEN:
|
||||
# Engage pre-filling response if no tool call is indicated
|
||||
resp.close()
|
||||
logger.info("Tool call is not found! Engage pre filling")
|
||||
prefill_content = random.choice(const.PREFILL_LIST)
|
||||
messages.append({"role": "assistant", "content": prefill_content})
|
||||
|
||||
# Send a new completion request with the updated messages
|
||||
# the model will continue the final message in the chat instead of starting a new one
|
||||
# disable add_generation_prompt which tells the template to add tokens that indicate the start of a bot response.
|
||||
extra_body = {
|
||||
**const.arch_function_generation_params,
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
}
|
||||
pre_fill_resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
full_response = pre_fill_resp.choices[0].message.content
|
||||
else:
|
||||
# Initialize full response and iterate over tokens to gather the full response
|
||||
full_response = first_token_content
|
||||
for token in resp:
|
||||
if hasattr(token.choices[0].delta, "content"):
|
||||
full_response += token.choices[0].delta.content
|
||||
else:
|
||||
logger.info("Stream is disabled, not engaging pre-filling")
|
||||
full_response = resp.choices[0].message.content
|
||||
|
||||
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
|
||||
|
||||
if tool_calls:
|
||||
message = Message(content="", tool_calls=tool_calls)
|
||||
else:
|
||||
message = Message(content=full_response, tool_calls=[])
|
||||
choice = Choice(message=message)
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[choice], model=client_model_name
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
|
||||
)
|
||||
logger.info(
|
||||
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
|
||||
)
|
||||
|
||||
return chat_completion_response
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
import os
|
||||
import app.commons.globals as glb
|
||||
|
||||
from transformers import AutoTokenizer, AutoModel, pipeline
|
||||
from optimum.onnxruntime import (
|
||||
ORTModelForFeatureExtraction,
|
||||
ORTModelForSequenceClassification,
|
||||
)
|
||||
import app.commons.utilities as utils
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
|
||||
):
|
||||
logger.info("Loading Embedding Model...")
|
||||
|
||||
if glb.DEVICE != "cuda":
|
||||
model = ORTModelForFeatureExtraction.from_pretrained(
|
||||
model_name, file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_name, device_map=glb.DEVICE)
|
||||
|
||||
embedding_model = {
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model,
|
||||
}
|
||||
|
||||
return embedding_model
|
||||
|
||||
|
||||
def get_zero_shot_model(
|
||||
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
|
||||
):
|
||||
logger.info("Loading Zero-shot Model...")
|
||||
|
||||
if glb.DEVICE != "cuda":
|
||||
model = ORTModelForSequenceClassification.from_pretrained(
|
||||
model_name, file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
model = model_name
|
||||
|
||||
zero_shot_model = {
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name),
|
||||
"model": model,
|
||||
}
|
||||
|
||||
zero_shot_model["pipeline"] = pipeline(
|
||||
"zero-shot-classification",
|
||||
model=zero_shot_model["model"],
|
||||
tokenizer=zero_shot_model["tokenizer"],
|
||||
device=glb.DEVICE,
|
||||
)
|
||||
|
||||
return zero_shot_model
|
||||
|
||||
|
||||
def get_prompt_guard(model_name):
|
||||
logger.info("Loading Guard Model...")
|
||||
|
||||
if glb.DEVICE == "cpu":
|
||||
model_class = OVModelForSequenceClassification
|
||||
else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
|
||||
prompt_guard = {
|
||||
"device": glb.DEVICE,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
return prompt_guard
|
||||
|
|
@ -1,261 +0,0 @@
|
|||
import os
|
||||
import time
|
||||
import torch
|
||||
import app.commons.utilities as utils
|
||||
import app.commons.globals as glb
|
||||
import app.prompt_guard.model_utils as guard_utils
|
||||
|
||||
from typing import List, Dict
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, Response, HTTPException, Request
|
||||
from app.function_calling.model_utils import ChatMessage
|
||||
|
||||
from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler
|
||||
from app.function_calling.model_utils import (
|
||||
chat_completion as arch_function_chat_completion,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
"service.name": "model-server",
|
||||
}
|
||||
)
|
||||
|
||||
# Initialize the tracer provider
|
||||
trace.set_tracer_provider(TracerProvider(resource=resource))
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
logger.info(f"Ready to serve traffic. available device: {glb.DEVICE}")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
FastAPIInstrumentor().instrument_app(app)
|
||||
|
||||
# DEFAULT_OTLP_HOST = "http://localhost:4317"
|
||||
DEFAULT_OTLP_HOST = "none"
|
||||
|
||||
# Configure the OTLP exporter (Jaeger, Zipkin, etc.)
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=os.getenv("OTLP_HOST", DEFAULT_OTLP_HOST) # noqa: F821
|
||||
)
|
||||
|
||||
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
input: str
|
||||
model: str
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
input: str
|
||||
task: str
|
||||
|
||||
|
||||
class ZeroShotRequest(BaseModel):
|
||||
input: str
|
||||
labels: List[str]
|
||||
model: str
|
||||
|
||||
|
||||
class HallucinationRequest(BaseModel):
|
||||
prompt: str
|
||||
parameters: Dict
|
||||
model: str
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/models")
|
||||
async def models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [{"id": embedding_model["model_name"], "object": "model"}],
|
||||
}
|
||||
|
||||
|
||||
@app.post("/embeddings")
|
||||
async def embedding(req: EmbeddingRequest, res: Response):
|
||||
logger.info(f"Embedding req: {req}")
|
||||
|
||||
if req.model != embedding_model["model_name"]:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
encoded_input = embedding_model["tokenizer"](
|
||||
req.input, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(glb.DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
embeddings = embedding_model["model"](**encoded_input)
|
||||
embeddings = embeddings[0][:, 0]
|
||||
embeddings = (
|
||||
torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
|
||||
)
|
||||
|
||||
logger.info(f"Embedding Call Complete Time: {time.perf_counter()-start_time}")
|
||||
|
||||
data = [
|
||||
{"object": "embedding", "embedding": embedding, "index": index + 1}
|
||||
for index, embedding in enumerate(embeddings.tolist())
|
||||
]
|
||||
|
||||
usage = {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
|
||||
return {"data": data, "model": req.model, "object": "list", "usage": usage}
|
||||
|
||||
|
||||
@app.post("/guard")
|
||||
async def guard(req: GuardRequest, res: Response, max_num_words=300):
|
||||
"""
|
||||
Take input as text and return the prediction of toxic and jailbreak
|
||||
"""
|
||||
|
||||
if req.task in ["both", "toxic", "jailbreak"]:
|
||||
arch_guard_handler.task = req.task
|
||||
else:
|
||||
raise NotImplementedError(f"{req.task} is not supported!")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if len(req.input.split()) < max_num_words:
|
||||
guard_result = arch_guard_handler.guard_predict(req.input)
|
||||
else:
|
||||
# text is long, split into chunks
|
||||
chunks = guard_utils.split_text_into_chunks(req.input)
|
||||
|
||||
guard_result = {
|
||||
"jailbreak_prob": [],
|
||||
"time": 0,
|
||||
"jailbreak_verdict": False,
|
||||
"toxic_sentence": [],
|
||||
"jailbreak_sentence": [],
|
||||
}
|
||||
|
||||
for chunk in chunks:
|
||||
chunk_result = arch_guard_handler.guard_predict(chunk)
|
||||
guard_result["time"] += chunk_result["time"]
|
||||
if chunk_result[f"{arch_guard_handler.task}_verdict"]:
|
||||
guard_result[f"{arch_guard_handler.task}_verdict"] = True
|
||||
guard_result[f"{arch_guard_handler.task}_sentence"].append(
|
||||
chunk_result[f"{arch_guard_handler.task}_sentence"]
|
||||
)
|
||||
guard_result[f"{arch_guard_handler.task}_prob"].append(
|
||||
chunk_result[f"{arch_guard_handler.task}_prob"].item()
|
||||
)
|
||||
|
||||
logger.info(f"Time taken for Guard: {time.perf_counter() - start_time}")
|
||||
|
||||
return guard_result
|
||||
|
||||
|
||||
@app.post("/zeroshot")
|
||||
async def zeroshot(req: ZeroShotRequest, res: Response):
|
||||
logger.info(f"zero-shot request: {req}")
|
||||
|
||||
if req.model != zero_shot_model["model_name"]:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
classifier = zero_shot_model["pipeline"]
|
||||
|
||||
label_map = utils.get_label_map(req.labels)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
predictions = classifier(
|
||||
req.input, candidate_labels=list(label_map.keys()), multi_label=True
|
||||
)
|
||||
|
||||
logger.info(f"zero-shot taking {time.perf_counter() - start_time} seconds")
|
||||
|
||||
predicted_class = label_map[predictions["labels"][0]]
|
||||
predicted_score = predictions["scores"][0]
|
||||
|
||||
scores = {
|
||||
label_map[label]: score
|
||||
for label, score in zip(predictions["labels"], predictions["scores"])
|
||||
}
|
||||
|
||||
predicted_class = label_map[predictions["labels"][0]]
|
||||
|
||||
return {
|
||||
"predicted_class": predicted_class,
|
||||
"predicted_class_score": predicted_score,
|
||||
"scores": scores,
|
||||
"model": req.model,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/hallucination")
|
||||
@patch("app.loader.glb.DEVICE", "cpu") # Mock the device to 'cpu'
|
||||
async def hallucination(req: HallucinationRequest, res: Response):
|
||||
"""
|
||||
Take input as text and return the prediction of hallucination for each parameter
|
||||
"""
|
||||
logger.info(f"hallucination request: {req}")
|
||||
if req.model != zero_shot_model["model_name"]:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
classifier = zero_shot_model["pipeline"]
|
||||
|
||||
if "messages" in req.parameters:
|
||||
req.parameters.pop("messages")
|
||||
|
||||
if not req.parameters or len(req.parameters) == 0:
|
||||
return {
|
||||
"params_scores": {},
|
||||
"model": req.model,
|
||||
}
|
||||
|
||||
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
|
||||
|
||||
predictions = classifier(
|
||||
req.prompt,
|
||||
candidate_labels=list(candidate_labels.keys()),
|
||||
hypothesis_template="{}",
|
||||
multi_label=True,
|
||||
)
|
||||
|
||||
params_scores = {
|
||||
candidate_labels[label]: score
|
||||
for label, score in zip(predictions["labels"], predictions["scores"])
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"hallucination time cost: {params_scores}, taking {time.perf_counter() - start_time} seconds"
|
||||
)
|
||||
|
||||
return {
|
||||
"params_scores": params_scores,
|
||||
"model": req.model,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(req: ChatMessage, res: Response, request: Request):
|
||||
try:
|
||||
result = await arch_function_chat_completion(req, res)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat_completion: {e}")
|
||||
res.status_code = 500
|
||||
return {"error": "Internal server error"}
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
import time
|
||||
import torch
|
||||
import app.prompt_guard.model_utils as model_utils
|
||||
|
||||
|
||||
class ArchGuardHanlder:
|
||||
def __init__(self, model_dict, threshold=0.5):
|
||||
self.task = "jailbreak"
|
||||
self.positive_class = 2
|
||||
|
||||
self.model = model_dict["model"]
|
||||
self.tokenizer = model_dict["tokenizer"]
|
||||
self.device = model_dict["device"]
|
||||
|
||||
self.threshold = threshold
|
||||
|
||||
def guard_predict(self, input_text, max_length=512):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
inputs = self.tokenizer(
|
||||
input_text, truncation=True, max_length=max_length, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
|
||||
prob = model_utils.softmax(logits)[self.positive_class]
|
||||
|
||||
if prob > self.threshold:
|
||||
verdict = True
|
||||
sentence = input_text
|
||||
else:
|
||||
verdict = False
|
||||
sentence = None
|
||||
|
||||
result_dict = {
|
||||
f"{self.task}_prob": prob.item(),
|
||||
f"{self.task}_verdict": verdict,
|
||||
f"{self.task}_sentence": sentence,
|
||||
"time": time.perf_counter() - start_time,
|
||||
}
|
||||
|
||||
return result_dict
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def split_text_into_chunks(text, max_words=300):
|
||||
"""
|
||||
Max number of tokens for tokenizer is 512
|
||||
Split the text into chunks of 300 words (as approximation for tokens)
|
||||
"""
|
||||
words = text.split() # Split text into words
|
||||
# Estimate token count based on word count (1 word ≈ 1 token)
|
||||
chunk_size = max_words # Use the word count as an approximation for tokens
|
||||
chunks = [
|
||||
" ".join(words[i : i + chunk_size]) for i in range(0, len(words), chunk_size)
|
||||
]
|
||||
return chunks
|
||||
|
||||
|
||||
def softmax(x):
|
||||
return np.exp(x) / np.exp(x).sum(axis=0)
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
import pytest
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app # Assuming your FastAPI app is in main.py
|
||||
from unittest.mock import patch
|
||||
import app.commons.globals as glb
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
logger.info(f"Model will be loaded on device: {glb.DEVICE}")
|
||||
|
||||
|
||||
# Unit tests for the health check endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_healthz():
|
||||
response = client.get("/healthz")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
# Unit test for the models endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_models():
|
||||
response = client.get("/models")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["object"] == "list"
|
||||
assert len(response.json()["data"]) > 0
|
||||
|
||||
|
||||
# Unit test for embeddings endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_embedding():
|
||||
request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
|
||||
response = client.post("/embeddings", json=request_data)
|
||||
if request_data["model"] == "katanemo/bge-large-en-v1.5":
|
||||
assert response.status_code == 200
|
||||
assert response.json()["object"] == "list"
|
||||
assert "data" in response.json()
|
||||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# Unit test for the guard endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_guard():
|
||||
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
|
||||
response = client.post("/guard", json=request_data)
|
||||
assert response.status_code == 200
|
||||
assert "jailbreak_verdict" in response.json()
|
||||
|
||||
|
||||
# Unit test for the zero-shot endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_zeroshot():
|
||||
request_data = {
|
||||
"input": "Test input",
|
||||
"labels": ["label1", "label2"],
|
||||
"model": "katanemo/bart-large-mnli",
|
||||
}
|
||||
response = client.post("/zeroshot", json=request_data)
|
||||
if request_data["model"] == "katanemo/bart-large-mnli":
|
||||
assert response.status_code == 200
|
||||
assert "predicted_class" in response.json()
|
||||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# Unit test for the hallucination endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_hallucination():
|
||||
request_data = {
|
||||
"prompt": "Test hallucination",
|
||||
"parameters": {"param1": "value1"},
|
||||
"model": "katanemo/bart-large-mnli",
|
||||
}
|
||||
response = client.post("/hallucination", json=request_data)
|
||||
if request_data["model"] == "katanemo/bart-large-mnli":
|
||||
assert response.status_code == 200
|
||||
assert "params_scores" in response.json()
|
||||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# Unit test for the chat completion endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_chat_completion():
|
||||
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
|
||||
request_data = {
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"model": "Arch-Function-1.5B",
|
||||
"tools": [], # Assuming tools is part of the req as per the function
|
||||
"metadata": {"x-arch-state": "[]"}, # Assuming metadata is needed
|
||||
}
|
||||
response = await client.post("/v1/chat/completions", json=request_data)
|
||||
assert response.status_code == 200
|
||||
assert "choices" in response.json()
|
||||
|
|
@ -1,794 +0,0 @@
|
|||
[{
|
||||
"case": "tool_call_halluciation",
|
||||
"tokens" : ["<tool_call>"],
|
||||
"expect": 1,
|
||||
"logprobs": [[-0.3333307206630707,
|
||||
-1.5310522317886353,
|
||||
-3.5098977088928223,
|
||||
-3.9004578590393066,
|
||||
-5.775152683258057,
|
||||
-5.814209461212158,
|
||||
-5.9574151039123535,
|
||||
-6.0094895362854,
|
||||
-6.0094895362854,
|
||||
-6.673445224761963]]
|
||||
},
|
||||
{
|
||||
"case" : "parameter_value_hallucination",
|
||||
"expect" : 0,
|
||||
"tokens" : ["<tool_call>",
|
||||
"\n",
|
||||
"{'",
|
||||
"name",
|
||||
"':",
|
||||
" '",
|
||||
"get",
|
||||
"_current",
|
||||
"_weather",
|
||||
"',",
|
||||
" '",
|
||||
"arguments",
|
||||
"':",
|
||||
" {'",
|
||||
"location",
|
||||
"':",
|
||||
" '",
|
||||
"Sea",
|
||||
",",
|
||||
" Australia",
|
||||
"',",
|
||||
" '",
|
||||
"unit",
|
||||
"':",
|
||||
" '",
|
||||
"c",
|
||||
"elsius",
|
||||
"',",
|
||||
" '",
|
||||
"days",
|
||||
"':",
|
||||
" '",
|
||||
"1",
|
||||
"'}}\n",
|
||||
"</tool_call>"],
|
||||
"logprobs": [[-0.008103232830762863,
|
||||
-5.085402488708496,
|
||||
-6.777836799621582,
|
||||
-7.558959007263184,
|
||||
-9.850253105163574,
|
||||
-10.266852378845215,
|
||||
-10.540244102478027,
|
||||
-10.722506523132324,
|
||||
-10.800618171691895,
|
||||
-10.917786598205566],
|
||||
[0.0,
|
||||
-23.25142478942871,
|
||||
-25.139137268066406,
|
||||
-26.2847843170166,
|
||||
-28.992677688598633,
|
||||
-29.070789337158203,
|
||||
-29.55248260498047,
|
||||
-29.91700553894043,
|
||||
-30.20341682434082,
|
||||
-30.307567596435547],
|
||||
[0.0,
|
||||
-21.66313934326172,
|
||||
-23.06916046142578,
|
||||
-23.32953453063965,
|
||||
-25.65988540649414,
|
||||
-25.985353469848633,
|
||||
-26.519121170043945,
|
||||
-27.07892417907715,
|
||||
-27.977216720581055,
|
||||
-28.458908081054688],
|
||||
[0.0,
|
||||
-28.094383239746094,
|
||||
-28.56305694580078,
|
||||
-29.109844207763672,
|
||||
-29.44832992553711,
|
||||
-31.79170036315918,
|
||||
-32.0,
|
||||
-32.05207443237305,
|
||||
-32.31244659423828,
|
||||
-32.364524841308594],
|
||||
[0.0,
|
||||
-30.489830017089844,
|
||||
-31.140766143798828,
|
||||
-31.81774139404297,
|
||||
-34.525634765625,
|
||||
-35.8275032043457,
|
||||
-36.504478454589844,
|
||||
-39.05614471435547,
|
||||
-40.123680114746094,
|
||||
-40.696502685546875],
|
||||
[0.0,
|
||||
-25.646865844726562,
|
||||
-26.66232681274414,
|
||||
-27.781936645507812,
|
||||
-28.979660034179688,
|
||||
-31.140764236450195,
|
||||
-31.92188835144043,
|
||||
-31.973962783813477,
|
||||
-33.04149627685547,
|
||||
-33.58828353881836],
|
||||
[0.0,
|
||||
-23.511798858642578,
|
||||
-24.136695861816406,
|
||||
-25.230268478393555,
|
||||
-25.777053833007812,
|
||||
-25.80309295654297,
|
||||
-26.45402717590332,
|
||||
-26.636289596557617,
|
||||
-26.740440368652344,
|
||||
-26.896663665771484],
|
||||
[0.0,
|
||||
-22.366153717041016,
|
||||
-24.683483123779297,
|
||||
-26.610252380371094,
|
||||
-26.610252380371094,
|
||||
-27.313264846801758,
|
||||
-27.67778778076172,
|
||||
-28.510986328125,
|
||||
-28.615135192871094,
|
||||
-29.13588523864746],
|
||||
[0.0,
|
||||
-22.52237319946289,
|
||||
-24.292919158935547,
|
||||
-24.344993591308594,
|
||||
-24.39706802368164,
|
||||
-24.73555564880371,
|
||||
-29.943042755126953,
|
||||
-29.969079971313477,
|
||||
-30.021154403686523,
|
||||
-30.0341739654541],
|
||||
[0.0,
|
||||
-30.17738151550293,
|
||||
-30.411718368530273,
|
||||
-30.88039207458496,
|
||||
-30.984540939331055,
|
||||
-31.270952224731445,
|
||||
-31.895851135253906,
|
||||
-32.46867370605469,
|
||||
-32.624900817871094,
|
||||
-33.484134674072266],
|
||||
[0.0,
|
||||
-28.146459579467773,
|
||||
-29.396255493164062,
|
||||
-30.099267959594727,
|
||||
-31.127744674682617,
|
||||
-31.179821014404297,
|
||||
-32.807159423828125,
|
||||
-33.7445068359375,
|
||||
-33.770545959472656,
|
||||
-34.069976806640625],
|
||||
[0.0,
|
||||
-26.323841094970703,
|
||||
-26.558177947998047,
|
||||
-30.515867233276367,
|
||||
-30.932466506958008,
|
||||
-31.37510108947754,
|
||||
-31.531326293945312,
|
||||
-31.70056915283203,
|
||||
-32.065093994140625,
|
||||
-32.364524841308594],
|
||||
[0.0,
|
||||
-26.922698974609375,
|
||||
-30.28152847290039,
|
||||
-31.505287170410156,
|
||||
-33.30187225341797,
|
||||
-33.73148727416992,
|
||||
-34.27827453613281,
|
||||
-34.33034896850586,
|
||||
-34.460533142089844,
|
||||
-34.720909118652344],
|
||||
[0.0,
|
||||
-21.532955169677734,
|
||||
-26.94873809814453,
|
||||
-29.109848022460938,
|
||||
-30.80228042602539,
|
||||
-31.55736541748047,
|
||||
-33.484134674072266,
|
||||
-34.681854248046875,
|
||||
-35.384864807128906,
|
||||
-35.853538513183594],
|
||||
[0.0,
|
||||
-19.502033233642578,
|
||||
-20.46541976928711,
|
||||
-24.110658645629883,
|
||||
-24.501218795776367,
|
||||
-25.256305694580078,
|
||||
-25.82912826538086,
|
||||
-25.881202697753906,
|
||||
-26.063465118408203,
|
||||
-26.063465118408203],
|
||||
[0.0,
|
||||
-24.37103271484375,
|
||||
-25.256305694580078,
|
||||
-25.933277130126953,
|
||||
-26.714401245117188,
|
||||
-28.2506103515625,
|
||||
-31.010576248168945,
|
||||
-32.07810974121094,
|
||||
-34.62977981567383,
|
||||
-35.241661071777344],
|
||||
[-1.1920922133867862e-06,
|
||||
-14.398697853088379,
|
||||
-14.424736976623535,
|
||||
-17.158666610717773,
|
||||
-17.41904067993164,
|
||||
-18.200162887573242,
|
||||
-18.434499740600586,
|
||||
-18.66883659362793,
|
||||
-19.71033477783203,
|
||||
-19.71033477783203],
|
||||
[-0.0001445904199499637,
|
||||
-8.98305892944336,
|
||||
-11.35246467590332,
|
||||
-13.1490478515625,
|
||||
-13.669795989990234,
|
||||
-14.073375701904297,
|
||||
-14.516012191772461,
|
||||
-14.555068969726562,
|
||||
-15.622602462768555,
|
||||
-15.635622024536133],
|
||||
[-0.44747352600097656,
|
||||
-1.0202960968017578,
|
||||
-8.467000961303711,
|
||||
-10.914518356323242,
|
||||
-11.25300407409668,
|
||||
-11.435266494750977,
|
||||
-12.346576690673828,
|
||||
-13.075624465942383,
|
||||
-13.12769889831543,
|
||||
-13.231849670410156],
|
||||
[-3.123767137527466,
|
||||
-1.1188862323760986,
|
||||
-1.639634370803833,
|
||||
-2.0562336444854736,
|
||||
-2.8633930683135986,
|
||||
-2.9675419330596924,
|
||||
-3.4882919788360596,
|
||||
-3.69659161567688,
|
||||
-4.217339515686035,
|
||||
-4.243376731872559],
|
||||
[-7.199982064776123e-05,
|
||||
-9.76410961151123,
|
||||
-11.144091606140137,
|
||||
-16.507802963256836,
|
||||
-17.132701873779297,
|
||||
-17.44515037536621,
|
||||
-17.9138240814209,
|
||||
-18.33042335510254,
|
||||
-18.9162654876709,
|
||||
-19.39795684814453],
|
||||
[0.0,
|
||||
-22.991050720214844,
|
||||
-23.824249267578125,
|
||||
-24.969894409179688,
|
||||
-25.46460723876953,
|
||||
-25.829130172729492,
|
||||
-26.480066299438477,
|
||||
-26.909683227539062,
|
||||
-27.33930206298828,
|
||||
-27.391376495361328],
|
||||
[-0.21928852796554565,
|
||||
-1.625309705734253,
|
||||
-9.775025367736816,
|
||||
-12.977627754211426,
|
||||
-16.388530731201172,
|
||||
-17.091541290283203,
|
||||
-19.044347763061523,
|
||||
-19.38283348083496,
|
||||
-19.460947036743164,
|
||||
-19.59113311767578],
|
||||
[0.0,
|
||||
-24.006507873535156,
|
||||
-27.443450927734375,
|
||||
-27.729862213134766,
|
||||
-28.12042236328125,
|
||||
-28.276647567749023,
|
||||
-28.927583694458008,
|
||||
-30.099267959594727,
|
||||
-31.479251861572266,
|
||||
-32.07810974121094],
|
||||
[0.0,
|
||||
-18.17412567138672,
|
||||
-18.772987365722656,
|
||||
-21.689178466796875,
|
||||
-21.92351531982422,
|
||||
-23.7200984954834,
|
||||
-23.79821014404297,
|
||||
-23.79821014404297,
|
||||
-24.032546997070312,
|
||||
-25.308382034301758],
|
||||
[-0.12947827577590942,
|
||||
-2.1083219051361084,
|
||||
-12.419143676757812,
|
||||
-15.23118782043457,
|
||||
-15.595710754394531,
|
||||
-15.830047607421875,
|
||||
-17.001731872558594,
|
||||
-17.60059356689453,
|
||||
-18.121341705322266,
|
||||
-18.251529693603516],
|
||||
[0.0,
|
||||
-19.449962615966797,
|
||||
-24.371034622192383,
|
||||
-24.917821884155273,
|
||||
-25.529701232910156,
|
||||
-25.85516929626465,
|
||||
-26.037429809570312,
|
||||
-26.115543365478516,
|
||||
-26.623271942138672,
|
||||
-26.649309158325195],
|
||||
[-0.03332124650478363,
|
||||
-3.4181859493255615,
|
||||
-15.759925842285156,
|
||||
-15.812002182006836,
|
||||
-16.593124389648438,
|
||||
-17.894996643066406,
|
||||
-18.09027671813965,
|
||||
-18.79328727722168,
|
||||
-19.144792556762695,
|
||||
-20.147233963012695],
|
||||
[0.0,
|
||||
-21.142393112182617,
|
||||
-22.157852172851562,
|
||||
-23.511798858642578,
|
||||
-24.657445907592773,
|
||||
-25.021968841552734,
|
||||
-25.5427188873291,
|
||||
-25.59479331970215,
|
||||
-25.75101661682129,
|
||||
-25.95931625366211],
|
||||
[0.0,
|
||||
-23.04312515258789,
|
||||
-24.94385528564453,
|
||||
-26.323841094970703,
|
||||
-27.54759979248047,
|
||||
-28.563060760498047,
|
||||
-29.786819458007812,
|
||||
-30.620018005371094,
|
||||
-30.69812774658203,
|
||||
-31.08869171142578],
|
||||
[0.0,
|
||||
-26.167617797851562,
|
||||
-28.771360397338867,
|
||||
-29.55248260498047,
|
||||
-30.906429290771484,
|
||||
-31.114728927612305,
|
||||
-31.414159774780273,
|
||||
-31.622459411621094,
|
||||
-31.713590621948242,
|
||||
-31.726608276367188],
|
||||
[-0.05012698099017143,
|
||||
-3.018392562866211,
|
||||
-11.740934371948242,
|
||||
-13.146955490112305,
|
||||
-13.797887802124023,
|
||||
-14.943536758422852,
|
||||
-16.037107467651367,
|
||||
-16.375595092773438,
|
||||
-16.714080810546875,
|
||||
-17.36501693725586],
|
||||
[-0.9704352021217346,
|
||||
-0.7360983490943909,
|
||||
-2.1941938400268555,
|
||||
-4.225115776062012,
|
||||
-5.0062360763549805,
|
||||
-5.2666120529174805,
|
||||
-5.839434623718262,
|
||||
-7.2714948654174805,
|
||||
-8.33902645111084,
|
||||
-8.495253562927246],
|
||||
[-0.014467108063399792,
|
||||
-4.258565902709961,
|
||||
-8.789079666137695,
|
||||
-10.429437637329102,
|
||||
-10.793962478637695,
|
||||
-11.835458755493164,
|
||||
-11.939607620239258,
|
||||
-13.31959342956543,
|
||||
-13.866378784179688,
|
||||
-15.038063049316406],
|
||||
[0.0,
|
||||
-20.08787727355957,
|
||||
-21.350692749023438,
|
||||
-21.415786743164062,
|
||||
-21.50691795349121,
|
||||
-21.50691795349121,
|
||||
-22.7176570892334,
|
||||
-24.13669776916504,
|
||||
-24.188772201538086,
|
||||
-24.34499740600586]]
|
||||
},
|
||||
{
|
||||
"case": "fail_case",
|
||||
"expect" : 0,
|
||||
"tokens" : ["<tool_call>",
|
||||
"\n",
|
||||
"{'",
|
||||
"name",
|
||||
"':",
|
||||
" '",
|
||||
"get",
|
||||
"_current",
|
||||
"_weather",
|
||||
"',",
|
||||
" '",
|
||||
"arguments",
|
||||
"':",
|
||||
" {'",
|
||||
"location",
|
||||
"':",
|
||||
" '",
|
||||
"Seattle",
|
||||
",",
|
||||
" WA",
|
||||
"',",
|
||||
" '",
|
||||
"unit",
|
||||
"':",
|
||||
" '",
|
||||
"c",
|
||||
"elsius",
|
||||
"',",
|
||||
" '",
|
||||
"days",
|
||||
"':",
|
||||
" '",
|
||||
"7",
|
||||
"'}}\n",
|
||||
"</tool_call>"],
|
||||
"logprobs":[[-0.00013815402053296566,
|
||||
-9.113236427307129,
|
||||
-10.571331977844238,
|
||||
-14.099404335021973,
|
||||
-14.28166675567627,
|
||||
-15.583537101745605,
|
||||
-15.81787395477295,
|
||||
-16.143341064453125,
|
||||
-16.143341064453125,
|
||||
-16.260509490966797],
|
||||
[0.0,
|
||||
-26.896663665771484,
|
||||
-27.32628059387207,
|
||||
-27.41741180419922,
|
||||
-32.07810974121094,
|
||||
-32.07810974121094,
|
||||
-32.28641128540039,
|
||||
-32.29943084716797,
|
||||
-32.44263458251953,
|
||||
-32.520748138427734],
|
||||
[0.0,
|
||||
-22.444263458251953,
|
||||
-24.527257919311523,
|
||||
-27.15703773498535,
|
||||
-28.016273498535156,
|
||||
-28.2506103515625,
|
||||
-28.693246841430664,
|
||||
-29.070789337158203,
|
||||
-29.565500259399414,
|
||||
-29.812854766845703],
|
||||
[0.0,
|
||||
-27.860050201416016,
|
||||
-28.641170501708984,
|
||||
-29.448333740234375,
|
||||
-30.932466506958008,
|
||||
-31.63547706604004,
|
||||
-32.33848571777344,
|
||||
-32.85923767089844,
|
||||
-33.17168426513672,
|
||||
-33.45809555053711],
|
||||
[0.0,
|
||||
-31.81774139404297,
|
||||
-31.895854949951172,
|
||||
-32.05207824707031,
|
||||
-35.43694305419922,
|
||||
-36.3482551574707,
|
||||
-38.61351013183594,
|
||||
-39.26444625854492,
|
||||
-40.61839294433594,
|
||||
-41.71196365356445],
|
||||
[0.0,
|
||||
-27.33930206298828,
|
||||
-27.834014892578125,
|
||||
-28.849472045898438,
|
||||
-30.567943572998047,
|
||||
-32.98942565917969,
|
||||
-33.067535400390625,
|
||||
-33.067535400390625,
|
||||
-35.67127990722656,
|
||||
-35.69731903076172],
|
||||
[0.0,
|
||||
-25.33441925048828,
|
||||
-26.063465118408203,
|
||||
-26.219690322875977,
|
||||
-26.2457275390625,
|
||||
-26.53213882446289,
|
||||
-27.365337371826172,
|
||||
-28.354759216308594,
|
||||
-28.667207717895508,
|
||||
-28.74532127380371],
|
||||
[0.0,
|
||||
-24.423107147216797,
|
||||
-24.579330444335938,
|
||||
-26.81855010986328,
|
||||
-28.12042236328125,
|
||||
-28.32872200012207,
|
||||
-28.61513328552246,
|
||||
-29.16191864013672,
|
||||
-29.187957763671875,
|
||||
-29.240032196044922],
|
||||
[0.0,
|
||||
-22.027664184570312,
|
||||
-23.850284576416016,
|
||||
-23.980472564697266,
|
||||
-24.292922973632812,
|
||||
-24.787633895874023,
|
||||
-29.279088973999023,
|
||||
-29.55248260498047,
|
||||
-29.903987884521484,
|
||||
-30.190399169921875],
|
||||
[0.0,
|
||||
-31.609439849853516,
|
||||
-31.817739486694336,
|
||||
-32.54678726196289,
|
||||
-32.676971435546875,
|
||||
-32.781124114990234,
|
||||
-32.98942565917969,
|
||||
-33.106590270996094,
|
||||
-33.57526397705078,
|
||||
-34.369407653808594],
|
||||
[0.0,
|
||||
-29.34418296813965,
|
||||
-29.63059425354004,
|
||||
-30.021156311035156,
|
||||
-30.984540939331055,
|
||||
-33.21073913574219,
|
||||
-34.30431365966797,
|
||||
-34.56468963623047,
|
||||
-34.70789337158203,
|
||||
-34.79902648925781],
|
||||
[0.0,
|
||||
-25.438566207885742,
|
||||
-25.69894027709961,
|
||||
-30.190397262573242,
|
||||
-30.802276611328125,
|
||||
-31.58340072631836,
|
||||
-31.609437942504883,
|
||||
-31.64849281311035,
|
||||
-31.973960876464844,
|
||||
-32.29943084716797],
|
||||
[0.0,
|
||||
-27.157039642333984,
|
||||
-32.104148864746094,
|
||||
-32.33848571777344,
|
||||
-34.04393768310547,
|
||||
-34.12205505371094,
|
||||
-34.40846252441406,
|
||||
-34.42148208618164,
|
||||
-34.772987365722656,
|
||||
-34.87713623046875],
|
||||
[0.0,
|
||||
-24.813671112060547,
|
||||
-26.974777221679688,
|
||||
-31.010578155517578,
|
||||
-31.08869171142578,
|
||||
-32.1822624206543,
|
||||
-35.33279037475586,
|
||||
-35.489013671875,
|
||||
-36.999183654785156,
|
||||
-37.88446044921875],
|
||||
[0.0,
|
||||
-20.46541976928711,
|
||||
-20.647682189941406,
|
||||
-23.069164276123047,
|
||||
-24.136699676513672,
|
||||
-25.438570022583008,
|
||||
-25.646869659423828,
|
||||
-26.193655014038086,
|
||||
-26.297805786132812,
|
||||
-26.506103515625],
|
||||
[0.0,
|
||||
-27.18307113647461,
|
||||
-28.30268096923828,
|
||||
-28.56305694580078,
|
||||
-29.526439666748047,
|
||||
-32.416595458984375,
|
||||
-35.202598571777344,
|
||||
-36.426361083984375,
|
||||
-39.31651306152344,
|
||||
-39.38160705566406],
|
||||
[0.0,
|
||||
-18.7469482421875,
|
||||
-20.100894927978516,
|
||||
-21.402767181396484,
|
||||
-21.428804397583008,
|
||||
-22.20992660522461,
|
||||
-22.34011459350586,
|
||||
-22.730674743652344,
|
||||
-23.069162368774414,
|
||||
-23.980472564697266],
|
||||
[-3.576278118089249e-07,
|
||||
-15.2579345703125,
|
||||
-16.481693267822266,
|
||||
-17.991863250732422,
|
||||
-19.215621948242188,
|
||||
-20.25712013244629,
|
||||
-21.350692749023438,
|
||||
-22.314077377319336,
|
||||
-22.496337890625,
|
||||
-22.938974380493164],
|
||||
[-0.08506780862808228,
|
||||
-2.506549835205078,
|
||||
-14.848289489746094,
|
||||
-15.473188400268555,
|
||||
-16.33242416381836,
|
||||
-16.358461380004883,
|
||||
-16.566761016845703,
|
||||
-17.03543472290039,
|
||||
-17.686370849609375,
|
||||
-17.816556930541992],
|
||||
[-0.0194891095161438,
|
||||
-4.445854187011719,
|
||||
-5.591499328613281,
|
||||
-5.956024169921875,
|
||||
-6.685070037841797,
|
||||
-13.142353057861328,
|
||||
-13.558952331542969,
|
||||
-15.173273086547852,
|
||||
-15.303461074829102,
|
||||
-15.85024642944336],
|
||||
[-0.0005990855861455202,
|
||||
-7.4212646484375,
|
||||
-15.675132751464844,
|
||||
-15.72720718383789,
|
||||
-16.76870346069336,
|
||||
-16.76870346069336,
|
||||
-17.706050872802734,
|
||||
-18.669435501098633,
|
||||
-19.398483276367188,
|
||||
-19.658857345581055],
|
||||
[0.0,
|
||||
-24.110658645629883,
|
||||
-25.829130172729492,
|
||||
-26.011390686035156,
|
||||
-26.011390686035156,
|
||||
-26.532140731811523,
|
||||
-26.58421516418457,
|
||||
-27.651750564575195,
|
||||
-27.75589942932129,
|
||||
-28.055330276489258],
|
||||
[-1.1408883333206177,
|
||||
-0.38580334186553955,
|
||||
-7.494022369384766,
|
||||
-12.519245147705078,
|
||||
-14.576202392578125,
|
||||
-16.034297943115234,
|
||||
-16.945608139038086,
|
||||
-17.908992767333984,
|
||||
-18.664077758789062,
|
||||
-19.34105110168457],
|
||||
[0.0,
|
||||
-26.688365936279297,
|
||||
-29.83889389038086,
|
||||
-30.177383422851562,
|
||||
-30.64605712890625,
|
||||
-31.244916915893555,
|
||||
-31.270954132080078,
|
||||
-32.83319854736328,
|
||||
-34.655818939208984,
|
||||
-34.89015579223633],
|
||||
[0.0,
|
||||
-18.929210662841797,
|
||||
-19.16354751586914,
|
||||
-23.589908599853516,
|
||||
-24.683481216430664,
|
||||
-24.995929718017578,
|
||||
-25.516677856445312,
|
||||
-25.542715072631836,
|
||||
-25.77705192565918,
|
||||
-26.063465118408203],
|
||||
[-0.2519786059856415,
|
||||
-1.5017764568328857,
|
||||
-12.437495231628418,
|
||||
-15.457839012145996,
|
||||
-15.744250297546387,
|
||||
-16.837820053100586,
|
||||
-17.41064453125,
|
||||
-17.56686782836914,
|
||||
-17.61894416809082,
|
||||
-18.035541534423828],
|
||||
[0.0,
|
||||
-20.517494201660156,
|
||||
-24.683483123779297,
|
||||
-25.67290496826172,
|
||||
-26.58421516418457,
|
||||
-27.651750564575195,
|
||||
-27.781936645507812,
|
||||
-27.912124633789062,
|
||||
-28.09438705444336,
|
||||
-28.445892333984375],
|
||||
[-3.40932747349143e-05,
|
||||
-10.284820556640625,
|
||||
-18.252273559570312,
|
||||
-20.17904281616211,
|
||||
-21.663175582885742,
|
||||
-22.027700424194336,
|
||||
-22.288074493408203,
|
||||
-22.704673767089844,
|
||||
-23.12127113342285,
|
||||
-23.277496337890625],
|
||||
[0.0,
|
||||
-22.60049057006836,
|
||||
-25.46460723876953,
|
||||
-25.829130172729492,
|
||||
-26.063467025756836,
|
||||
-27.287227630615234,
|
||||
-27.391376495361328,
|
||||
-27.4694881439209,
|
||||
-27.67778778076172,
|
||||
-28.055330276489258],
|
||||
[0.0,
|
||||
-23.902362823486328,
|
||||
-28.823436737060547,
|
||||
-29.240036010742188,
|
||||
-29.31814956665039,
|
||||
-29.917007446289062,
|
||||
-30.021160125732422,
|
||||
-31.21887969970703,
|
||||
-32.416603088378906,
|
||||
-32.416603088378906],
|
||||
[0.0,
|
||||
-28.641170501708984,
|
||||
-31.947925567626953,
|
||||
-32.59886169433594,
|
||||
-33.848655700683594,
|
||||
-34.109031677246094,
|
||||
-34.73393249511719,
|
||||
-35.02033996582031,
|
||||
-35.02033996582031,
|
||||
-36.074859619140625],
|
||||
[-0.013183215633034706,
|
||||
-4.335395336151123,
|
||||
-19.619365692138672,
|
||||
-20.035964965820312,
|
||||
-20.244266510009766,
|
||||
-21.311800003051758,
|
||||
-21.441987991333008,
|
||||
-22.561595916748047,
|
||||
-23.108383178710938,
|
||||
-23.264606475830078],
|
||||
[-8.344646857949556e-07,
|
||||
-14.190400123596191,
|
||||
-15.9088716506958,
|
||||
-18.17412567138672,
|
||||
-18.46053695678711,
|
||||
-18.46053695678711,
|
||||
-18.512611389160156,
|
||||
-18.90317153930664,
|
||||
-19.059398651123047,
|
||||
-19.085433959960938],
|
||||
[0.0,
|
||||
-17.70545196533203,
|
||||
-18.903175354003906,
|
||||
-20.829944610595703,
|
||||
-22.574451446533203,
|
||||
-22.860862731933594,
|
||||
-23.069162368774414,
|
||||
-23.32953643798828,
|
||||
-23.694061279296875,
|
||||
-24.188772201538086],
|
||||
[0.0,
|
||||
-20.022781372070312,
|
||||
-21.038240432739258,
|
||||
-21.220502853393555,
|
||||
-22.496337890625,
|
||||
-22.769729614257812,
|
||||
-23.589908599853516,
|
||||
-23.65500259399414,
|
||||
-23.94141387939453,
|
||||
-24.266881942749023]]
|
||||
}
|
||||
]
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import subprocess
|
||||
import time
|
||||
from app.cli import kill_process
|
||||
|
||||
|
||||
class TestStopServer(unittest.TestCase):
|
||||
@patch("subprocess.run")
|
||||
def test_stop_server_no_process(self, mock_run):
|
||||
# Mock subprocess.run to simulate no process listening on the port
|
||||
mock_run.return_value.returncode = 1
|
||||
with patch("builtins.print") as mock_print:
|
||||
kill_process(port=51000)
|
||||
mock_print.assert_called_with("No process found listening on port 51000.")
|
||||
|
||||
@patch("subprocess.run")
|
||||
def test_stop_server_process_killed(self, mock_run):
|
||||
# Simulate lsof returning a process id
|
||||
mock_run.side_effect = [
|
||||
MagicMock(returncode=0, stdout="uvicorn 1234 user LISTEN\n"),
|
||||
MagicMock(returncode=0), # for killing the process
|
||||
MagicMock(returncode=1), # for checking the process after it is killed
|
||||
]
|
||||
with patch("builtins.print") as mock_print:
|
||||
kill_process(port=51000, wait=True, timeout=5)
|
||||
mock_print.assert_any_call("Killing model server process with PID 1234")
|
||||
mock_print.assert_any_call("Process 1234 has been killed.")
|
||||
|
||||
@patch("subprocess.run")
|
||||
def test_stop_server_multiple_pids(self, mock_run):
|
||||
# Simulate lsof returning multiple process ids (e.g., 1234 and 5678)
|
||||
mock_run.side_effect = [
|
||||
MagicMock(
|
||||
returncode=0,
|
||||
stdout="uvicorn 1234 user LISTEN\nuvicorn 5678 user LISTEN\n",
|
||||
), # lsof output
|
||||
MagicMock(returncode=0), # first kill command for PID 1234
|
||||
MagicMock(returncode=1), # PID 1234 is successfully terminated
|
||||
MagicMock(returncode=0), # second kill command for PID 5678
|
||||
MagicMock(returncode=1), # PID 5678 is successfully terminated
|
||||
]
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
kill_process(port=51000, wait=True, timeout=5)
|
||||
|
||||
# Assert that the function tried to kill both PIDs
|
||||
mock_print.assert_any_call("Killing model server process with PID 1234")
|
||||
mock_print.assert_any_call("Process 1234 has been killed.")
|
||||
mock_print.assert_any_call("Killing model server process with PID 5678")
|
||||
mock_print.assert_any_call("Process 5678 has been killed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import app.commons.constants as const
|
||||
from fastapi import Response
|
||||
from app.function_calling.model_utils import (
|
||||
process_messages,
|
||||
chat_completion,
|
||||
Message,
|
||||
ChatMessage,
|
||||
Choice,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
|
||||
|
||||
def sample_messages():
|
||||
# Ensure fields are explicitly set with valid data or empty values
|
||||
return [
|
||||
Message(role="user", content="Hello!", tool_calls=[], tool_call_id=""),
|
||||
Message(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[{"function": {"name": "sample_tool"}}],
|
||||
tool_call_id="sample_id",
|
||||
),
|
||||
Message(
|
||||
role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def sample_request(sample_messages):
|
||||
return ChatMessage(
|
||||
messages=sample_messages,
|
||||
tools=[{"name": "sample_tool", "description": "A sample tool"}],
|
||||
)
|
||||
|
||||
|
||||
@patch("app.commons.constants.arch_function_hanlder")
|
||||
def test_process_messages(mock_hanlder):
|
||||
messages = sample_messages()
|
||||
processed = process_messages(messages)
|
||||
|
||||
assert len(processed) == 3
|
||||
assert processed[0] == {"role": "user", "content": "Hello!"}
|
||||
assert processed[1] == {
|
||||
"role": "assistant",
|
||||
"content": '<tool_call>\n{"name": "sample_tool"}\n</tool_call>',
|
||||
}
|
||||
assert processed[2] == {
|
||||
"role": "user",
|
||||
"content": "<tool_response>\nResponse from tool\n</tool_response>",
|
||||
}
|
||||
|
||||
|
||||
@patch("app.commons.constants.arch_function_client")
|
||||
@patch("app.commons.constants.arch_function_hanlder")
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion(mock_hanlder, mock_client):
|
||||
# Mock the model list return for client
|
||||
mock_client.models.list.return_value = MagicMock(
|
||||
data=[MagicMock(id="sample_model")]
|
||||
)
|
||||
request = sample_request(sample_messages())
|
||||
# Simulate stream response as list of tokens
|
||||
mock_response = AsyncMock()
|
||||
mock_response.__aiter__.return_value = [
|
||||
MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
|
||||
MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
|
||||
]
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
# Mock the tool formatter
|
||||
mock_hanlder._format_system.return_value = "<formatted_tools>"
|
||||
|
||||
response = Response()
|
||||
chat_response = await chat_completion(request, response)
|
||||
|
||||
assert isinstance(chat_response, ChatCompletionResponse)
|
||||
assert chat_response.choices[0].message.content is not None
|
||||
|
||||
first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
|
||||
assert first_call_args["stream"] == True
|
||||
assert "model" in first_call_args
|
||||
assert first_call_args["messages"][0]["content"] == "<formatted_tools>"
|
||||
|
||||
# Check that the arguments for the second call to 'create' include the pre-fill completion
|
||||
second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
|
||||
assert second_call_args["stream"] == False
|
||||
assert "model" in second_call_args
|
||||
assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST
|
||||
|
|
@ -1,148 +0,0 @@
|
|||
import json
|
||||
from app.function_calling.hallucination_handler import HallucinationStateHandler
|
||||
import pytest
|
||||
import os
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(__file__)
|
||||
|
||||
# Construct the full path to the JSON file
|
||||
json_file_path = os.path.join(current_dir, "test_cases.json")
|
||||
|
||||
with open(json_file_path) as f:
|
||||
test_cases = json.load(f)
|
||||
|
||||
get_weather_api = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get current weather at a location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "str",
|
||||
"description": "The location to get the weather for",
|
||||
"format": "City, State",
|
||||
},
|
||||
"unit": {
|
||||
"type": "str",
|
||||
"description": "The unit to return the weather in.",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"default": "celsius",
|
||||
},
|
||||
"days": {
|
||||
"type": "str",
|
||||
"description": "the number of days for the request.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "days"],
|
||||
},
|
||||
},
|
||||
}
|
||||
function_description = get_weather_api["function"]
|
||||
if type(function_description) != list:
|
||||
function_description = [get_weather_api["function"]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", test_cases)
|
||||
def test_hallucination(case):
|
||||
state = HallucinationStateHandler(
|
||||
response_iterator=None, function=function_description
|
||||
)
|
||||
for token, logprob in zip(case["tokens"], case["logprobs"]):
|
||||
if token != "</tool_call>":
|
||||
state.append_and_check_token_hallucination(token, logprob)
|
||||
if state.hallucination:
|
||||
break
|
||||
assert state.hallucination == case["expect"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
|
||||
def test_hallucination_prompt(is_hallucinate_sample):
|
||||
TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
|
||||
TOOL_PROMPT = """
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
|
||||
FORMAT_PROMPT = """
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
""".strip()
|
||||
|
||||
def convert_tools(tools):
|
||||
return "\n".join([json.dumps(tool) for tool in tools])
|
||||
|
||||
def format_prompt(tools):
|
||||
tool_text = convert_tools(tools)
|
||||
|
||||
return (
|
||||
TASK_PROMPT
|
||||
+ "\n\n"
|
||||
+ TOOL_PROMPT.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ FORMAT_PROMPT
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
openai_format_tools = [get_weather_api]
|
||||
|
||||
system_prompt = format_prompt(openai_format_tools)
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
|
||||
|
||||
# List models API
|
||||
model = client.models.list().data[0].id
|
||||
assert model == "Arch-Function"
|
||||
if not is_hallucinate_sample:
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
# {"role": "user", "content": "can you help me check weather?"},
|
||||
{"role": "user", "content": "How is the weather in Seattle in 7 days?"},
|
||||
# {"role": "assistant", "content": "Of course!"},
|
||||
# {"role": "user", "content": "Seattle please"}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
# {"role": "user", "content": "can you help me check weather?"},
|
||||
{"role": "user", "content": "How is the weather in Seattle in days?"},
|
||||
# {"role": "assistant", "content": "Of course!"},
|
||||
# {"role": "user", "content": "Seattle please"}
|
||||
]
|
||||
|
||||
extra_body = {
|
||||
"temperature": 0.6,
|
||||
"top_p": 1.0,
|
||||
"top_k": 50,
|
||||
# "continue_final_message": True,
|
||||
# "add_generation_prompt": False,
|
||||
"logprobs": True,
|
||||
"top_logprobs": 10,
|
||||
}
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model="Arch-Function", messages=messages, extra_body=extra_body, stream=True
|
||||
)
|
||||
|
||||
hallu = HallucinationStateHandler(
|
||||
response_iterator=resp, function=function_description
|
||||
)
|
||||
|
||||
for token in hallu:
|
||||
assert len(hallu.tokens) >= 0
|
||||
assert hallu.hallucination == is_hallucinate_sample
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue