mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Merge branch 'main' into adil/signoz_tracing
This commit is contained in:
commit
c18dc04a7d
51 changed files with 1593 additions and 695 deletions
6
.github/workflows/e2e_tests.yml
vendored
6
.github/workflows/e2e_tests.yml
vendored
|
|
@ -8,7 +8,8 @@ on:
|
|||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-latest-m
|
||||
# runs-on: gh-large-150gb-ssd
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
|
@ -29,4 +30,5 @@ jobs:
|
|||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
||||
run: |
|
||||
cd e2e_tests && bash run_e2e_tests.sh
|
||||
python -mvenv venv
|
||||
source venv/bin/activate && cd e2e_tests && bash run_e2e_tests.sh
|
||||
|
|
|
|||
81
README.md
81
README.md
|
|
@ -1,6 +1,9 @@
|
|||
<p>
|
||||
<img src="docs/source/_static/img/arch-logo.png" alt="Arch Gateway Logo" title="Arch Gateway Logo">
|
||||
</p>
|
||||

|
||||
|
||||
[](https://github.com/katanemo/arch/actions/workflows/pre-commit.yml)
|
||||
[](https://github.com/katanemo/arch/actions/workflows/rust_tests.yml)
|
||||
[](https://github.com/katanemo/arch/actions/workflows/e2e_tests.yml)
|
||||
[](https://github.com/katanemo/arch/actions/workflows/static.yml)
|
||||
|
||||
## Build fast, robust, and personalized AI agents.
|
||||
|
||||
|
|
@ -68,28 +71,43 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-3.5-turbo
|
||||
default: true
|
||||
|
||||
# default system prompt used by all prompt targets
|
||||
system_prompt: |
|
||||
You are a network assistant that just offers facts; not advice on manufacturers or purchasing decisions.
|
||||
You are a network assistant that helps operators with a better understanding of network traffic flow and perform actions on networking operations. No advice on manufacturers or purchasing decisions.
|
||||
|
||||
prompt_targets:
|
||||
- name: reboot_devices
|
||||
description: Reboot specific devices or device groups
|
||||
|
||||
path: /agent/device_reboot
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs) to reboot.
|
||||
required: false
|
||||
- name: device_group
|
||||
type: str
|
||||
description: The name of the device group to reboot
|
||||
required: false
|
||||
- name: device_summary
|
||||
description: Retrieve network statistics for specific devices within a time range
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_summary
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs) to retrieve statistics for.
|
||||
required: true # device_ids are required to get device statistics
|
||||
- name: days
|
||||
type: int
|
||||
description: The number of days for which to gather device statistics.
|
||||
default: "7"
|
||||
- name: reboot_devices
|
||||
description: Reboot a list of devices
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_reboot
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs).
|
||||
required: true
|
||||
- name: days
|
||||
type: int
|
||||
description: A list of device identifiers (IDs)
|
||||
default: "7"
|
||||
|
||||
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
|
||||
endpoints:
|
||||
|
|
@ -97,7 +115,7 @@ endpoints:
|
|||
# value could be ip address or a hostname with port
|
||||
# this could also be a list of endpoints for load balancing
|
||||
# for example endpoint: [ ip1:port, ip2:port ]
|
||||
endpoint: 127.0.0.1:80
|
||||
endpoint: host.docker.internal:18083
|
||||
# max time to wait for a connection to be established
|
||||
connect_timeout: 0.005s
|
||||
```
|
||||
|
|
@ -106,20 +124,23 @@ endpoints:
|
|||
Make outbound calls via Arch
|
||||
|
||||
```python
|
||||
import openai
|
||||
|
||||
# Set the OpenAI API base URL to the Arch gateway endpoint
|
||||
openai.api_base = "http://127.0.0.1:12000/v1"
|
||||
# No need to set a specific openai.api_key since it's configured in Arch's gateway
|
||||
openai.api_key = "null"
|
||||
from openai import OpenAI
|
||||
|
||||
# Use the OpenAI client as usual
|
||||
response = openai.Completion.create(
|
||||
model="text-davinci-003",
|
||||
prompt="What is the capital of France?"
|
||||
client = OpenAI(
|
||||
# No need to set a specific openai.api_key since it's configured in Arch's gateway
|
||||
api_key = '--',
|
||||
# Set the OpenAI API base URL to the Arch gateway endpoint
|
||||
base_url = "http://127.0.0.1:12000/v1"
|
||||
)
|
||||
|
||||
print("OpenAI Response:", response.choices[0].text.strip())
|
||||
response = client.chat.completions.create(
|
||||
# we select model from arch_config file
|
||||
model="--",
|
||||
messages=[{"role": "user", "content": "What is the capital of France?"}],
|
||||
)
|
||||
|
||||
print("OpenAI Response:", response.choices[0].message.content)
|
||||
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ FROM envoyproxy/envoy:v1.31-latest as envoy
|
|||
#Build config generator, so that we have a single build image for both Rust and Python
|
||||
FROM python:3-slim as arch
|
||||
|
||||
RUN apt-get update && apt-get install -y gettext-base && apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
RUN apt-get update && apt-get install -y gettext-base curl && apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=builder /arch/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
|
||||
COPY --from=builder /arch/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm
|
||||
|
|
|
|||
|
|
@ -11,6 +11,11 @@ services:
|
|||
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
|
||||
- ~/archgw_logs:/var/log/
|
||||
env_file:
|
||||
- stage.env
|
||||
- env.list
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:10000/healthz"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
# Define paths
|
||||
source_schema="../arch_config_schema.yaml"
|
||||
source_compose="../docker-compose.yaml"
|
||||
source_stage_env="../stage.env"
|
||||
destination_dir="config"
|
||||
|
||||
# Ensure the destination directory exists only if it doesn't already
|
||||
|
|
@ -15,7 +14,7 @@ fi
|
|||
# Copy the files
|
||||
cp "$source_schema" "$destination_dir/arch_config_schema.yaml"
|
||||
cp "$source_compose" "$destination_dir/docker-compose.yaml"
|
||||
cp "$source_stage_env" "$destination_dir/stage.env"
|
||||
touch "$destination_dir/env.list"
|
||||
|
||||
# Print success message
|
||||
echo "Files copied successfully!"
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def validate_and_render_schema():
|
|||
try:
|
||||
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(str(e))
|
||||
exit(1) # validate_prompt_config failed. Exit
|
||||
|
||||
with open(ARCH_CONFIG_FILE, "r") as file:
|
||||
|
|
@ -73,7 +73,6 @@ def validate_and_render_schema():
|
|||
|
||||
print("updated clusters", inferred_clusters)
|
||||
|
||||
config_yaml = add_secret_key_to_llm_providers(config_yaml)
|
||||
arch_llm_providers = config_yaml["llm_providers"]
|
||||
arch_tracing = config_yaml.get("tracing", {})
|
||||
arch_config_string = yaml.dump(config_yaml)
|
||||
|
|
|
|||
|
|
@ -58,6 +58,8 @@ def main(ctx, version):
|
|||
click.echo(f"archgw cli version: {get_version()}")
|
||||
ctx.exit()
|
||||
|
||||
log.info(f"Starting archgw cli version: {get_version()}")
|
||||
|
||||
if ctx.invoked_subcommand is None:
|
||||
click.echo("""Arch (The Intelligent Prompt Gateway) CLI""")
|
||||
click.echo(logo)
|
||||
|
|
@ -68,7 +70,7 @@ def main(ctx, version):
|
|||
@click.option(
|
||||
"--service",
|
||||
default=SERVICE_ALL,
|
||||
help="Optioanl parameter to specify which service to build. Options are model_server, archgw",
|
||||
help="Optional parameter to specify which service to build. Options are model_server, archgw",
|
||||
)
|
||||
def build(service):
|
||||
"""Build Arch from source. Must be in root of cloned repo."""
|
||||
|
|
@ -168,7 +170,7 @@ def up(file, path, service):
|
|||
arch_config_schema_file=arch_schema_config,
|
||||
)
|
||||
except Exception as e:
|
||||
log.info(f"Exiting archgw up: {e}")
|
||||
log.info(f"Exiting archgw up: validation failed")
|
||||
sys.exit(1)
|
||||
|
||||
log.info("Starging arch model server and arch gateway")
|
||||
|
|
@ -178,6 +180,12 @@ def up(file, path, service):
|
|||
env = os.environ.copy()
|
||||
# check if access_keys are preesnt in the config file
|
||||
access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file)
|
||||
|
||||
# remove duplicates
|
||||
access_keys = set(access_keys)
|
||||
# remove the $ from the access_keys
|
||||
access_keys = [item[1:] if item.startswith("$") else item for item in access_keys]
|
||||
|
||||
if access_keys:
|
||||
if file:
|
||||
app_env_file = os.path.join(
|
||||
|
|
@ -186,6 +194,7 @@ def up(file, path, service):
|
|||
else:
|
||||
app_env_file = os.path.abspath(os.path.join(path, ".env"))
|
||||
|
||||
print(f"app_env_file: {app_env_file}")
|
||||
if not os.path.exists(
|
||||
app_env_file
|
||||
): # check to see if the environment variables in the current environment or not
|
||||
|
|
@ -205,7 +214,7 @@ def up(file, path, service):
|
|||
env_stage[access_key] = env_file_dict[access_key]
|
||||
|
||||
with open(
|
||||
pkg_resources.resource_filename(__name__, "../config/stage.env"), "w"
|
||||
pkg_resources.resource_filename(__name__, "../config/env.list"), "w"
|
||||
) as file:
|
||||
for key, value in env_stage.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
|
|
|
|||
28
arch/tools/poetry.lock
generated
28
arch/tools/poetry.lock
generated
|
|
@ -13,7 +13,7 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "archgw_modelserver"
|
||||
version = "0.0.4"
|
||||
version = "0.1.1"
|
||||
description = "A model server for serving models"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
|
@ -250,13 +250,13 @@ tqdm = ["tqdm"]
|
|||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.26.1"
|
||||
version = "0.26.2"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.26.1-py3-none-any.whl", hash = "sha256:5927a8fc64ae68859cd954b7cc29d1c8390a5e15caba6d3d349c973be8fdacf3"},
|
||||
{file = "huggingface_hub-0.26.1.tar.gz", hash = "sha256:414c0d9b769eecc86c70f9d939d0f48bb28e8461dd1130021542eff0212db890"},
|
||||
{file = "huggingface_hub-0.26.2-py3-none-any.whl", hash = "sha256:98c2a5a8e786c7b2cb6fdeb2740893cba4d53e312572ed3d8afafda65b128c46"},
|
||||
{file = "huggingface_hub-0.26.2.tar.gz", hash = "sha256:b100d853465d965733964d123939ba287da60a547087783ddff8a323f340332b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -765,33 +765,33 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "75.2.0"
|
||||
version = "75.3.0"
|
||||
description = "Easily download, build, install, upgrade, and uninstall Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"},
|
||||
{file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"},
|
||||
{file = "setuptools-75.3.0-py3-none-any.whl", hash = "sha256:f2504966861356aa38616760c0f66568e535562374995367b4e69c7143cf6bcd"},
|
||||
{file = "setuptools-75.3.0.tar.gz", hash = "sha256:fba5dd4d766e97be1b1681d98712680ae8f2f26d7881245f2ce9e40714f1a686"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"]
|
||||
core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
|
||||
core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
|
||||
cover = ["pytest-cov"]
|
||||
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
|
||||
enabler = ["pytest-enabler (>=2.2)"]
|
||||
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
|
||||
type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"]
|
||||
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
|
||||
type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.12.*)", "pytest-mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.66.5"
|
||||
version = "4.66.6"
|
||||
description = "Fast, Extensible Progress Meter"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"},
|
||||
{file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"},
|
||||
{file = "tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"},
|
||||
{file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -834,4 +834,4 @@ zstd = ["zstandard (>=0.18.0)"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "e51783523cbe087cb1db94e874a28564b43af03a1689523d3738b212e288f64b"
|
||||
content-hash = "c6d4df2015f02a8105934690d43d2133771c358f2f393a7a8a5a1e0df1c0a55b"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "archgw"
|
||||
version = "0.0.5"
|
||||
version = "0.1.0"
|
||||
description = "Python-based CLI tool to manage Arch Gateway."
|
||||
authors = ["Katanemo Labs, Inc."]
|
||||
packages = [
|
||||
|
|
@ -12,7 +12,6 @@ include = [
|
|||
# Include package data (docker-compose.yaml and other files)[
|
||||
"config/docker-compose.yaml",
|
||||
"config/arch_config_schema.yaml",
|
||||
"config/stage.env"
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
|
|
@ -22,8 +21,8 @@ pydantic = "^2.9.2"
|
|||
click = "^8.1.7"
|
||||
jinja2 = "^3.1.4"
|
||||
jsonschema = "^4.23.0"
|
||||
setuptools = "75.2.0"
|
||||
archgw_modelserver = "0.0.4"
|
||||
setuptools = "75.3.0"
|
||||
archgw_modelserver = "0.1.1"
|
||||
huggingface_hub = "^0.26.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
|
|
|||
|
|
@ -16,10 +16,6 @@
|
|||
"name": "model_server",
|
||||
"path": "model_server"
|
||||
},
|
||||
{
|
||||
"name": "chatbot_ui",
|
||||
"path": "chatbot_ui"
|
||||
},
|
||||
{
|
||||
"name": "e2e_tests",
|
||||
"path": "e2e_tests"
|
||||
|
|
|
|||
|
|
@ -1,77 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import yaml
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def process_stream_chunk(chunk, history):
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.role and delta.role != history[-1]["role"]:
|
||||
# create new history item if role changes
|
||||
# this is likely due to arch tool call and api response
|
||||
history.append({"role": delta.role})
|
||||
|
||||
history[-1]["model"] = chunk.model
|
||||
# append tool calls to history if there are any in the chunk
|
||||
if delta.tool_calls:
|
||||
history[-1]["tool_calls"] = delta.tool_calls
|
||||
|
||||
if delta.content:
|
||||
# append content to the last history item
|
||||
history[-1]["content"] = history[-1].get("content", "") + delta.content
|
||||
# yield content if it is from assistant
|
||||
if history[-1]["role"] == "assistant":
|
||||
return delta.content
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def convert_prompt_target_to_openai_format(target):
|
||||
tool = {
|
||||
"description": target["description"],
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
|
||||
if "parameters" in target:
|
||||
for param_info in target["parameters"]:
|
||||
parameter = {
|
||||
"type": param_info["type"],
|
||||
"description": param_info["description"],
|
||||
}
|
||||
|
||||
for key in ["default", "format", "enum", "items", "minimum", "maximum"]:
|
||||
if key in param_info:
|
||||
parameter[key] = param_info[key]
|
||||
|
||||
tool["parameters"]["properties"][param_info["name"]] = parameter
|
||||
|
||||
required = param_info.get("required", False)
|
||||
if required:
|
||||
tool["parameters"]["required"].append(param_info["name"])
|
||||
|
||||
return {"name": target["name"], "info": tool}
|
||||
|
||||
|
||||
def get_prompt_targets():
|
||||
try:
|
||||
with open(os.getenv("ARCH_CONFIG", "arch_config.yaml"), "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
available_tools = []
|
||||
for target in config["prompt_targets"]:
|
||||
if not target.get("default", False):
|
||||
available_tools.append(
|
||||
convert_prompt_target_to_openai_format(target)
|
||||
)
|
||||
|
||||
return {tool["name"]: tool["info"] for tool in available_tools}
|
||||
except Exception as e:
|
||||
log.info(e)
|
||||
return None
|
||||
|
|
@ -319,7 +319,14 @@ impl HttpContext for StreamContext {
|
|||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
let mut data = serde_json::from_str(&body_utf8).unwrap();
|
||||
let mut data = match serde_json::from_str(&body_utf8) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
warn!("could not deserialize response: {}", e);
|
||||
self.send_server_error(ServerError::Deserialization(e), None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
if let Value::Object(ref mut map) = data {
|
||||
// serialize arch state and add to metadata
|
||||
|
|
|
|||
|
|
@ -458,65 +458,93 @@ impl StreamContext {
|
|||
// it may be that arch fc is handling the conversation for parameter collection
|
||||
if arch_assistant {
|
||||
info!("arch fc is engaged in parameter collection");
|
||||
} else {
|
||||
if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!(
|
||||
"default prompt target found, forwarding request to default prompt target"
|
||||
);
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
} else if let Some(default_prompt_target) = self
|
||||
.prompt_targets
|
||||
.values()
|
||||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
let upstream_endpoint = endpoint.name;
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
MESSAGES_KEY.to_string(),
|
||||
callout_context.request_body.messages.clone(),
|
||||
);
|
||||
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
|
||||
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
let mut headers = vec![
|
||||
(":method", "POST"),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
|
||||
(":path", &upstream_path),
|
||||
(":authority", &upstream_endpoint),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
if self.traceparent.is_some() {
|
||||
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
if self.request_id.is_some() {
|
||||
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
if self.traceparent.is_some() {
|
||||
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&upstream_path,
|
||||
headers,
|
||||
Some(arch_messages_json.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||
|
||||
if let Err(e) = self.http_call(call_args, callout_context) {
|
||||
warn!("error dispatching default prompt target request: {}", e);
|
||||
return self.send_server_error(
|
||||
ServerError::HttpDispatch(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
// if no default prompt target is found and similarity score is low send response to upstream llm
|
||||
// removing tool calls and tool response
|
||||
|
||||
let messages = self.filter_out_arch_messages(&callout_context);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"archgw (low similarity score) => llm request: {}",
|
||||
llm_request_str
|
||||
);
|
||||
|
||||
self.set_http_request_body(
|
||||
0,
|
||||
self.request_body_size,
|
||||
&llm_request_str.into_bytes(),
|
||||
);
|
||||
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
|
|
@ -862,7 +890,7 @@ impl StreamContext {
|
|||
);
|
||||
|
||||
debug!(
|
||||
"archgw => api call, endpoint: {}/{}, body: {}",
|
||||
"archgw => api call, endpoint: {}{}, body: {}",
|
||||
endpoint.name.as_str(),
|
||||
path,
|
||||
tool_params_json_str
|
||||
|
|
@ -901,42 +929,8 @@ impl StreamContext {
|
|||
"archgw <= api call response: {}",
|
||||
self.tool_call_response.as_ref().unwrap()
|
||||
);
|
||||
let prompt_target_name = callout_context.prompt_target_name.unwrap();
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(&prompt_target_name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
|
||||
// add system prompt
|
||||
let system_prompt = match prompt_target.system_prompt.as_ref() {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(system_prompt) => Some(system_prompt.clone()),
|
||||
};
|
||||
if system_prompt.is_some() {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: system_prompt,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
||||
// don't send tools message and api response to chat gpt
|
||||
for m in callout_context.request_body.messages.iter() {
|
||||
// don't send api response and tool calls to upstream LLMs
|
||||
if m.role == TOOL_ROLE
|
||||
|| m.content.is_none()
|
||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
messages.push(m.clone());
|
||||
}
|
||||
let mut messages = self.filter_out_arch_messages(&callout_context);
|
||||
|
||||
let user_message = match messages.pop() {
|
||||
Some(user_message) => user_message,
|
||||
|
|
@ -988,6 +982,51 @@ impl StreamContext {
|
|||
self.resume_http_request();
|
||||
}
|
||||
|
||||
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
// add system prompt
|
||||
|
||||
let system_prompt = match callout_context.prompt_target_name.as_ref() {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(prompt_target_name) => {
|
||||
let prompt_system_prompt = self
|
||||
.prompt_targets
|
||||
.get(prompt_target_name)
|
||||
.unwrap()
|
||||
.clone()
|
||||
.system_prompt;
|
||||
match prompt_system_prompt {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(system_prompt) => Some(system_prompt),
|
||||
}
|
||||
}
|
||||
};
|
||||
if system_prompt.is_some() {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: system_prompt,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
||||
// don't send tools message and api response to chat gpt
|
||||
for m in callout_context.request_body.messages.iter() {
|
||||
// don't send api response and tool calls to upstream LLMs
|
||||
if m.role == TOOL_ROLE
|
||||
|| m.content.is_none()
|
||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
messages.push(m.clone());
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
|
||||
debug!(
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ services:
|
|||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../../chatbot_ui
|
||||
context: ../shared/chatbot_ui
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
|
|
|
|||
|
|
@ -5,16 +5,12 @@ FROM base AS builder
|
|||
WORKDIR /src
|
||||
|
||||
COPY requirements.txt /src/
|
||||
COPY workforce_data.json /src/
|
||||
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
|
||||
|
||||
COPY . /src
|
||||
|
||||
FROM python:3.10-slim AS output
|
||||
|
||||
COPY --from=builder /runtime /usr/local
|
||||
|
||||
COPY . /app
|
||||
WORKDIR /app
|
||||
COPY . /app
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]
|
||||
|
|
|
|||
|
|
@ -25,7 +25,5 @@ This demo showcases how the **Arch** can be used to build an HR agent to manage
|
|||
```sh
|
||||
sh run_demo.sh
|
||||
```
|
||||
3. Navigate to http://localhost:18080/
|
||||
3. Navigate to http://localhost:18080/agent/chat
|
||||
4. "Can you give me workforce data for asia?"
|
||||
|
||||

|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o-mini
|
||||
default: true
|
||||
|
||||
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
|
||||
|
|
@ -18,21 +18,17 @@ endpoints:
|
|||
# value could be ip address or a hostname with port
|
||||
# this could also be a list of endpoints for load balancing
|
||||
# for example endpoint: [ ip1:port, ip2:port ]
|
||||
endpoint: host.docker.internal:18083
|
||||
endpoint: host.docker.internal:18080
|
||||
# max time to wait for a connection to be established
|
||||
connect_timeout: 0.005s
|
||||
|
||||
# default system prompt used by all prompt targets
|
||||
system_prompt: |
|
||||
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workfoce planning. NOTHING ELSE. When you get data in json format, offer some summary but don't be too verbose.
|
||||
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workforce planning. Use following rules when responding,
|
||||
- when you get data in json format, offer some summary but don't be too verbose
|
||||
- be concise, to the point and do not over analyze the data
|
||||
|
||||
prompt_targets:
|
||||
- name: hr_qa
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/hr_qa
|
||||
description: Handle general Q/A related to HR.
|
||||
default: true
|
||||
- name: workforce
|
||||
description: Get workforce data like headcount and satisfacton levels by region and staffing type
|
||||
endpoint:
|
||||
|
|
@ -41,16 +37,16 @@ prompt_targets:
|
|||
parameters:
|
||||
- name: staffing_type
|
||||
type: str
|
||||
description: Staffing type like contract, fte or agency
|
||||
description: specific category or nature of employment used by an organization like fte, contract and agency
|
||||
required: true
|
||||
- name: region
|
||||
type: str
|
||||
required: true
|
||||
description: Geographical region for which you want workforce data like asia, europe, americas.
|
||||
- name: point_in_time
|
||||
- name: data_snapshot_days_ago
|
||||
type: int
|
||||
required: false
|
||||
description: the point in time for which to retrieve data. For e.g 0 days ago, 30 days ago, etc.
|
||||
description: the snapshot day for which you want workforce data.
|
||||
- name: slack_message
|
||||
endpoint:
|
||||
name: app_server
|
||||
|
|
|
|||
|
|
@ -4,22 +4,14 @@ services:
|
|||
context: .
|
||||
environment:
|
||||
- SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN:-None}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
- ../shared/chatbot_ui/common.py:/app/common.py
|
||||
ports:
|
||||
- "18083:80"
|
||||
- "18080:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../../chatbot_ui
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
|
|
|
|||
Binary file not shown.
|
Before Width: | Height: | Size: 456 KiB After Width: | Height: | Size: 549 KiB |
|
|
@ -1,28 +1,39 @@
|
|||
import os
|
||||
import json
|
||||
import pandas as pd
|
||||
import gradio as gr
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from openai import OpenAI
|
||||
from common import create_gradio_app
|
||||
|
||||
app = FastAPI()
|
||||
workforce_data_df = None
|
||||
demo_description = """This demo showcases how the **Arch** can be used to build an
|
||||
HR agent to manage workforce-related inquiries, workforce planning, and communication via Slack.
|
||||
It intelligently routes incoming prompts to the correct targets, providing concise and useful responses
|
||||
tailored for HR and workforce decision-making. """
|
||||
|
||||
with open("workforce_data.json") as file:
|
||||
workforce_data = json.load(file)
|
||||
workforce_data_df = pd.json_normalize(
|
||||
workforce_data, record_path=["regions"], meta=["point_in_time", "satisfaction"]
|
||||
workforce_data,
|
||||
record_path=["regions"],
|
||||
meta=["data_snapshot_days_ago", "satisfaction"],
|
||||
)
|
||||
|
||||
|
||||
# Define the request model
|
||||
class WorkforceRequset(BaseModel):
|
||||
class WorkforceRequest(BaseModel):
|
||||
region: str
|
||||
staffing_type: str
|
||||
point_in_time: Optional[int] = None
|
||||
data_snapshot_days_ago: Optional[int] = None
|
||||
|
||||
|
||||
class SlackRequest(BaseModel):
|
||||
|
|
@ -36,25 +47,6 @@ class WorkforceResponse(BaseModel):
|
|||
satisfaction: float
|
||||
|
||||
|
||||
# Post method for device summary
|
||||
@app.post("/agent/workforce")
|
||||
def get_workforce(request: WorkforceRequset):
|
||||
"""
|
||||
Endpoint to workforce data by region, staffing type at a given point in time.
|
||||
"""
|
||||
region = request.region.lower()
|
||||
staffing_type = request.staffing_type.lower()
|
||||
point_in_time = request.point_in_time if request.point_in_time else 0
|
||||
|
||||
response = {
|
||||
"region": region,
|
||||
"staffing_type": f"Staffing agency: {staffing_type}",
|
||||
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)][staffing_type].values[0])}",
|
||||
"satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)]['satisfaction'].values[0])}",
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/agent/slack_message")
|
||||
def send_slack_message(request: SlackRequest):
|
||||
"""
|
||||
|
|
@ -80,27 +72,38 @@ def send_slack_message(request: SlackRequest):
|
|||
print(f"Error sending message: {e.response['error']}")
|
||||
|
||||
|
||||
@app.post("/agent/hr_qa")
|
||||
async def general_hr_qa():
|
||||
# Post method for device summary
|
||||
@app.post("/agent/workforce")
|
||||
def get_workforce(request: WorkforceRequest):
|
||||
"""
|
||||
This method handles Q/A related to general issues in HR.
|
||||
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
|
||||
Endpoint to workforce data by region, staffing type at a given point in time.
|
||||
"""
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am a helpful HR agent, and I can help you plan for workforce related questions",
|
||||
},
|
||||
"finish_reason": "completed",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"model": "hr_agent",
|
||||
"usage": {"completion_tokens": 0},
|
||||
}
|
||||
region = request.region.lower()
|
||||
staffing_type = request.staffing_type.lower()
|
||||
data_snapshot_days_ago = (
|
||||
request.data_snapshot_days_ago
|
||||
if request.data_snapshot_days_ago
|
||||
else 0 # this param is not required.
|
||||
)
|
||||
|
||||
response = {
|
||||
"region": region,
|
||||
"staffing_type": f"Staffing agency: {staffing_type}",
|
||||
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)][staffing_type].values[0])}",
|
||||
"satisfaction": f"Satisfaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
client = OpenAI(
|
||||
api_key="--",
|
||||
base_url=CHAT_COMPLETION_ENDPOINT,
|
||||
)
|
||||
|
||||
gr.mount_gradio_app(
|
||||
app, create_gradio_app(demo_description, client), path="/agent/chat"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,13 @@
|
|||
fastapi
|
||||
uvicorn
|
||||
pydantic
|
||||
slack-sdk
|
||||
typing
|
||||
pandas
|
||||
gradio==5.3.0
|
||||
async_timeout==4.0.3
|
||||
loguru==0.7.2
|
||||
asyncio==3.4.3
|
||||
httpx==0.27.0
|
||||
python-dotenv==1.0.1
|
||||
pydantic==2.8.2
|
||||
openai==1.51.0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[
|
||||
{
|
||||
"point_in_time": 0,
|
||||
"data_snapshot_days_ago": 0,
|
||||
"regions": [
|
||||
{ "region": "asia", "contract": 100, "fte": 150, "agency": 2000 },
|
||||
{ "region": "europe", "contract": 80, "fte": 120, "agency": 2500 },
|
||||
|
|
@ -9,7 +9,7 @@
|
|||
"satisfaction": 3.5
|
||||
},
|
||||
{
|
||||
"point_in_time": 30,
|
||||
"data_snapshot_days_ago": 30,
|
||||
"regions": [
|
||||
{ "region": "asia", "contract": 110, "fte": 155, "agency": 1000 },
|
||||
{ "region": "europe", "contract": 85, "fte": 130, "agency": 1600 },
|
||||
|
|
@ -18,7 +18,7 @@
|
|||
"satisfaction": 4.0
|
||||
},
|
||||
{
|
||||
"point_in_time": 60,
|
||||
"data_snapshot_days_ago": 60,
|
||||
"regions": [
|
||||
{ "region": "asia", "contract": 115, "fte": 160, "agency": 500 },
|
||||
{ "region": "europe", "contract": 90, "fte": 140, "agency": 700 },
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ system_prompt: |
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ services:
|
|||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../../chatbot_ui
|
||||
context: ../shared/chatbot_ui
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18080:8080"
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ The assistant can perform several key operations, including rebooting devices, a
|
|||
```sh
|
||||
sh run_demo.sh
|
||||
```
|
||||
3. Navigate to http://localhost:18080/
|
||||
3. Navigate to http://localhost:18080/agent/chat
|
||||
4. Tell me what can you do for me?"
|
||||
|
||||
# Observability
|
||||
|
|
@ -39,4 +39,4 @@ Arch gateway publishes stats endpoint at http://localhost:19901/stats. In this d
|
|||
|
||||
Here is sample interaction
|
||||
|
||||
<img width="575" alt="image" src="https://github.com/user-attachments/assets/25d40f46-616e-41ea-be8e-1623055c84ec">
|
||||

|
||||
|
|
|
|||
|
|
@ -8,36 +8,17 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-3.5-turbo
|
||||
default: true
|
||||
|
||||
# default system prompt used by all prompt targets
|
||||
system_prompt: |
|
||||
You are a network assistant that just offers facts; not advice on manufacturers or purchasing decisions.
|
||||
You are a network assistant that helps operators with a better understanding of network traffic flow and perform actions on networking operations. No advice on manufacturers or purchasing decisions.
|
||||
|
||||
prompt_targets:
|
||||
- name: reboot_devices
|
||||
description: Reboot specific devices or device groups
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_reboot
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs) to reboot.
|
||||
required: true
|
||||
- name: time_range
|
||||
type: int
|
||||
description: Optional time range in days for reboot operations. Defaults to 7.
|
||||
- name: network_qa
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/network_summary
|
||||
description: Handle general Q/A related to networking.
|
||||
default: true
|
||||
- name: device_summary
|
||||
description: Retrieve statistics for specific devices within a time range
|
||||
description: Retrieve network statistics for specific devices within a time range
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_summary
|
||||
|
|
@ -46,9 +27,23 @@ prompt_targets:
|
|||
type: list
|
||||
description: A list of device identifiers (IDs) to retrieve statistics for.
|
||||
required: true # device_ids are required to get device statistics
|
||||
- name: time_range
|
||||
- name: days
|
||||
type: int
|
||||
description: Time range in days for which to gather device statistics. Defaults to 7.
|
||||
description: The number of days for which to gather device statistics.
|
||||
default: "7"
|
||||
- name: reboot_devices
|
||||
description: Reboot a list of devices
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_reboot
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs).
|
||||
required: true
|
||||
- name: days
|
||||
type: int
|
||||
description: A list of device identifiers (IDs)
|
||||
default: "7"
|
||||
|
||||
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
|
||||
|
|
@ -57,15 +52,6 @@ endpoints:
|
|||
# value could be ip address or a hostname with port
|
||||
# this could also be a list of endpoints for load balancing
|
||||
# for example endpoint: [ ip1:port, ip2:port ]
|
||||
endpoint: host.docker.internal:18083
|
||||
endpoint: host.docker.internal:18080
|
||||
# max time to wait for a connection to be established
|
||||
connect_timeout: 0.005s
|
||||
|
||||
ratelimits:
|
||||
- model: gpt-4
|
||||
selector:
|
||||
key: selector-key
|
||||
value: selector-value
|
||||
limit:
|
||||
tokens: 1
|
||||
unit: minute
|
||||
|
|
|
|||
|
|
@ -2,24 +2,14 @@ services:
|
|||
api_server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
environment:
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
- ../shared/chatbot_ui/common.py:/app/common.py
|
||||
ports:
|
||||
- "18083:80"
|
||||
- "18080:80"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
|
||||
interval: 5s
|
||||
retries: 20
|
||||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../../chatbot_ui
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
|
|
|
|||
BIN
demos/network_agent/image.png
Normal file
BIN
demos/network_agent/image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 636 KiB |
|
|
@ -1,8 +1,15 @@
|
|||
from openai import OpenAI
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
from common import create_gradio_app
|
||||
import gradio as gr
|
||||
import os
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
demo_description = """This demo illustrates how **Arch** can be used to perform function calling with network-related tasks.
|
||||
In this demo, you act as a **network assistant** that provides factual information, without offering advice on manufacturers or purchasing decisions."""
|
||||
|
||||
|
||||
# Define the request model
|
||||
|
|
@ -88,27 +95,15 @@ def get_device_summary(request: DeviceSummaryRequest):
|
|||
return DeviceSummaryResponse(statistics=statistics)
|
||||
|
||||
|
||||
@app.post("/agent/network_summary")
|
||||
async def policy_qa():
|
||||
"""
|
||||
This method handles Q/A related to general issues in networks.
|
||||
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
|
||||
"""
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them",
|
||||
},
|
||||
"finish_reason": "completed",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"model": "network_agent",
|
||||
"usage": {"completion_tokens": 0},
|
||||
}
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
client = OpenAI(
|
||||
api_key="--",
|
||||
base_url=CHAT_COMPLETION_ENDPOINT,
|
||||
)
|
||||
|
||||
gr.mount_gradio_app(
|
||||
app, create_gradio_app(demo_description, client), path="/agent/chat"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True)
|
||||
|
|
|
|||
|
|
@ -2,3 +2,12 @@ fastapi
|
|||
uvicorn
|
||||
pydantic
|
||||
typing
|
||||
pandas
|
||||
gradio==5.3.0
|
||||
async_timeout==4.0.3
|
||||
loguru==0.7.2
|
||||
asyncio==3.4.3
|
||||
httpx==0.27.0
|
||||
python-dotenv==1.0.1
|
||||
pydantic==2.8.2
|
||||
openai==1.51.0
|
||||
|
|
|
|||
171
demos/shared/chatbot_ui/common.py
Normal file
171
demos/shared/chatbot_ui/common.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import yaml
|
||||
import gradio as gr
|
||||
from typing import List, Optional, Tuple
|
||||
from functools import partial
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
GRADIO_CSS_STYLE = """
|
||||
.json-container {
|
||||
height: 80vh !important;
|
||||
overflow-y: auto !important;
|
||||
}
|
||||
.chatbot {
|
||||
height: calc(80vh - 100px) !important;
|
||||
overflow-y: auto !important;
|
||||
}
|
||||
footer {visibility: hidden}
|
||||
"""
|
||||
|
||||
|
||||
def chat(
|
||||
query: Optional[str],
|
||||
conversation: Optional[List[Tuple[str, str]]],
|
||||
history: List[dict],
|
||||
client,
|
||||
):
|
||||
history.append({"role": "user", "content": query})
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
# we select model from arch_config file
|
||||
model="--",
|
||||
messages=history,
|
||||
temperature=1.0,
|
||||
stream=True,
|
||||
)
|
||||
except Exception as e:
|
||||
# remove last user message in case of exception
|
||||
history.pop()
|
||||
log.info("Error calling gateway API: {}".format(e))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e))
|
||||
|
||||
conversation.append((query, ""))
|
||||
|
||||
for chunk in response:
|
||||
tokens = process_stream_chunk(chunk, history)
|
||||
if tokens:
|
||||
conversation[-1] = (
|
||||
conversation[-1][0],
|
||||
conversation[-1][1] + tokens,
|
||||
)
|
||||
|
||||
yield "", conversation, history
|
||||
|
||||
|
||||
def create_gradio_app(demo_description, client):
|
||||
with gr.Blocks(
|
||||
theme=gr.themes.Default(
|
||||
font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
|
||||
),
|
||||
fill_height=True,
|
||||
css=GRADIO_CSS_STYLE,
|
||||
) as demo:
|
||||
with gr.Row(equal_height=True):
|
||||
history = gr.State([])
|
||||
|
||||
with gr.Column(scale=1):
|
||||
gr.Markdown(demo_description),
|
||||
with gr.Accordion("Available Tools/APIs", open=True):
|
||||
with gr.Column(scale=1):
|
||||
gr.JSON(
|
||||
value=get_prompt_targets(),
|
||||
show_indices=False,
|
||||
elem_classes="json-container",
|
||||
min_height="80vh",
|
||||
)
|
||||
|
||||
with gr.Column(scale=2):
|
||||
chatbot = gr.Chatbot(
|
||||
label="Arch Chatbot",
|
||||
elem_classes="chatbot",
|
||||
)
|
||||
textbox = gr.Textbox(
|
||||
show_label=False,
|
||||
placeholder="Enter text and press enter",
|
||||
autofocus=True,
|
||||
elem_classes="textbox",
|
||||
)
|
||||
chat_with_client = partial(chat, client=client)
|
||||
|
||||
textbox.submit(
|
||||
chat_with_client,
|
||||
[textbox, chatbot, history],
|
||||
[textbox, chatbot, history],
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
def process_stream_chunk(chunk, history):
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.role and delta.role != history[-1]["role"]:
|
||||
# create new history item if role changes
|
||||
# this is likely due to arch tool call and api response
|
||||
history.append({"role": delta.role})
|
||||
|
||||
history[-1]["model"] = chunk.model
|
||||
# append tool calls to history if there are any in the chunk
|
||||
if delta.tool_calls:
|
||||
history[-1]["tool_calls"] = delta.tool_calls
|
||||
|
||||
if delta.content:
|
||||
# append content to the last history item
|
||||
history[-1]["content"] = history[-1].get("content", "") + delta.content
|
||||
# yield content if it is from assistant
|
||||
if history[-1]["role"] == "assistant":
|
||||
return delta.content
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def convert_prompt_target_to_openai_format(target):
|
||||
tool = {
|
||||
"description": target["description"],
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
|
||||
if "parameters" in target:
|
||||
for param_info in target["parameters"]:
|
||||
parameter = {
|
||||
"type": param_info["type"],
|
||||
"description": param_info["description"],
|
||||
}
|
||||
|
||||
for key in ["default", "format", "enum", "items", "minimum", "maximum"]:
|
||||
if key in param_info:
|
||||
parameter[key] = param_info[key]
|
||||
|
||||
tool["parameters"]["properties"][param_info["name"]] = parameter
|
||||
|
||||
required = param_info.get("required", False)
|
||||
if required:
|
||||
tool["parameters"]["required"].append(param_info["name"])
|
||||
|
||||
return {"name": target["name"], "info": tool}
|
||||
|
||||
|
||||
def get_prompt_targets():
|
||||
try:
|
||||
with open(os.getenv("ARCH_CONFIG", "arch_config.yaml"), "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
available_tools = []
|
||||
for target in config["prompt_targets"]:
|
||||
if not target.get("default", False):
|
||||
available_tools.append(
|
||||
convert_prompt_target_to_openai_format(target)
|
||||
)
|
||||
|
||||
return {tool["name"]: tool["info"] for tool in available_tools}
|
||||
except Exception as e:
|
||||
log.info(e)
|
||||
return None
|
||||
|
|
@ -8,7 +8,7 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-3.5-turbo
|
||||
default: true
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
stream: true
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ listen:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
stream: true
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ listener:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
stream: true
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ endpoints:
|
|||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: OPENAI_API_KEY
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-4o
|
||||
default: true
|
||||
stream: true
|
||||
|
|
@ -47,7 +47,7 @@ llm_providers:
|
|||
|
||||
- name: Mistral8x7b
|
||||
provider: mistral
|
||||
access_key: MISTRAL_API_KEY
|
||||
access_key: $MISTRAL_API_KEY
|
||||
model: mistral-8x7b
|
||||
|
||||
- name: MistralLocal7b
|
||||
|
|
|
|||
|
|
@ -6,6 +6,11 @@ log() {
|
|||
echo "$timestamp: $message"
|
||||
}
|
||||
|
||||
print_disk_usage() {
|
||||
echo free disk space
|
||||
df -h | grep "/$"
|
||||
}
|
||||
|
||||
wait_for_healthz() {
|
||||
local healthz_url="$1"
|
||||
local timeout_seconds="${2:-30}" # Default timeout of 30 seconds
|
||||
|
|
@ -28,6 +33,8 @@ wait_for_healthz() {
|
|||
return 1
|
||||
fi
|
||||
|
||||
print_disk_usage
|
||||
|
||||
sleep $sleep_between
|
||||
done
|
||||
}
|
||||
|
|
|
|||
39
e2e_tests/poetry.lock
generated
39
e2e_tests/poetry.lock
generated
|
|
@ -455,13 +455,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "7.4.4"
|
||||
version = "8.3.3"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"},
|
||||
{file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"},
|
||||
{file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"},
|
||||
{file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -469,11 +469,11 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
|||
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
|
||||
iniconfig = "*"
|
||||
packaging = "*"
|
||||
pluggy = ">=0.12,<2.0"
|
||||
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
pluggy = ">=1.5,<2"
|
||||
tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
|
|
@ -493,6 +493,23 @@ pytest = ">=4.6"
|
|||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-retry"
|
||||
version = "1.6.3"
|
||||
description = "Adds the ability to retry flaky tests in CI environments"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pytest_retry-1.6.3-py3-none-any.whl", hash = "sha256:e96f7df77ee70b0838d1085f9c3b8b5b7d74bf8947a0baf32e2b8c71b27683c8"},
|
||||
{file = "pytest_retry-1.6.3.tar.gz", hash = "sha256:36ccfa11c8c8f9ddad5e20375182146d040c20c4a791745139c5a99ddf1b557d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "flake8", "isort", "mypy"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-sugar"
|
||||
version = "1.0.0"
|
||||
|
|
@ -535,13 +552,13 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
|||
|
||||
[[package]]
|
||||
name = "selenium"
|
||||
version = "4.25.0"
|
||||
version = "4.26.0"
|
||||
description = "Official Python bindings for Selenium WebDriver"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "selenium-4.25.0-py3-none-any.whl", hash = "sha256:3798d2d12b4a570bc5790163ba57fef10b2afee958bf1d80f2a3cf07c4141f33"},
|
||||
{file = "selenium-4.25.0.tar.gz", hash = "sha256:95d08d3b82fb353f3c474895154516604c7f0e6a9a565ae6498ef36c9bac6921"},
|
||||
{file = "selenium-4.26.0-py3-none-any.whl", hash = "sha256:48013f36e812de5b3948ef53d04e73f77bc923ee3e1d7d99eaf0618179081b99"},
|
||||
{file = "selenium-4.26.0.tar.gz", hash = "sha256:f0780f85f10310aa5d085b81e79d73d3c93b83d8de121d0400d543a50ee963e8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -699,4 +716,4 @@ h11 = ">=0.9.0,<1"
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "6ae4fa6397091b87b63698201a08d7d97628ed65992d46514f118768b46b99ce"
|
||||
content-hash = "a40015b90325879e50f82cca6a26a730d763cad26589671df798832d41c42db3"
|
||||
|
|
|
|||
|
|
@ -9,11 +9,12 @@ package-mode = false
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
pytest = "^7.3.1"
|
||||
pytest = "^8.3.3"
|
||||
requests = "^2.29.0"
|
||||
selenium = "^4.11.2"
|
||||
pytest-sugar = "^1.0.0"
|
||||
deepdiff = "^8.0.1"
|
||||
pytest-retry = "^1.6.3"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest-cov = "^4.1.0"
|
||||
|
|
@ -21,3 +22,6 @@ pytest-cov = "^4.1.0"
|
|||
[tool.pytest.ini_options]
|
||||
python_files = ["test*.py"]
|
||||
addopts = ["-v", "-s"]
|
||||
retries = 2
|
||||
retry_delay = 0.5
|
||||
cumulative_timing = false
|
||||
|
|
|
|||
|
|
@ -4,6 +4,11 @@ set -e
|
|||
|
||||
. ./common_scripts.sh
|
||||
|
||||
print_disk_usage
|
||||
|
||||
mkdir -p ~/archgw_logs
|
||||
touch ~/archgw_logs/modelserver.log
|
||||
|
||||
print_debug() {
|
||||
log "Received signal to stop"
|
||||
log "Printing debug logs for model_server"
|
||||
|
|
@ -18,63 +23,62 @@ trap 'print_debug' INT TERM ERR
|
|||
|
||||
log starting > ../build.log
|
||||
|
||||
log building function_callling demo
|
||||
log ===============================
|
||||
log building and running function_callling demo
|
||||
log ===========================================
|
||||
cd ../demos/function_calling
|
||||
docker compose build -q
|
||||
|
||||
log starting the function_calling demo
|
||||
docker compose up -d
|
||||
docker compose up api_server --build -d
|
||||
cd -
|
||||
|
||||
log building model server
|
||||
log =====================
|
||||
print_disk_usage
|
||||
|
||||
log building and install model server
|
||||
log =================================
|
||||
cd ../model_server
|
||||
poetry install 2>&1 >> ../build.log
|
||||
log starting model server
|
||||
log =====================
|
||||
mkdir -p ~/archgw_logs
|
||||
touch ~/archgw_logs/modelserver.log
|
||||
poetry run archgw_modelserver restart &
|
||||
poetry install
|
||||
cd -
|
||||
|
||||
print_disk_usage
|
||||
|
||||
log building and installing archgw cli
|
||||
log ==================================
|
||||
cd ../arch/tools
|
||||
sh build_cli.sh
|
||||
cd -
|
||||
|
||||
print_disk_usage
|
||||
|
||||
log building docker image for arch gateway
|
||||
log ======================================
|
||||
cd ../
|
||||
archgw build
|
||||
cd -
|
||||
|
||||
print_disk_usage
|
||||
|
||||
log startup arch gateway with function calling demo
|
||||
cd ..
|
||||
tail -F ~/archgw_logs/modelserver.log &
|
||||
model_server_tail_pid=$!
|
||||
archgw down
|
||||
archgw up demos/function_calling/arch_config.yaml
|
||||
kill $model_server_tail_pid
|
||||
cd -
|
||||
|
||||
log building llm and prompt gateway rust modules
|
||||
log ============================================
|
||||
cd ../arch
|
||||
docker build -f Dockerfile .. -t katanemo/archgw -q
|
||||
log starting the arch gateway service
|
||||
log =================================
|
||||
docker compose -f docker-compose.e2e.yaml down
|
||||
log waiting for model service to be healthy
|
||||
wait_for_healthz "http://localhost:51000/healthz" 300
|
||||
kill $model_server_tail_pid
|
||||
docker compose -f docker-compose.e2e.yaml up -d
|
||||
log waiting for arch gateway service to be healthy
|
||||
wait_for_healthz "http://localhost:10000/healthz" 60
|
||||
log waiting for arch gateway service to be healthy
|
||||
cd -
|
||||
print_disk_usage
|
||||
|
||||
log running e2e tests
|
||||
log =================
|
||||
poetry install 2>&1 >> ../build.log
|
||||
poetry install
|
||||
poetry run pytest
|
||||
|
||||
log shutting down the arch gateway service
|
||||
log ======================================
|
||||
cd ../arch
|
||||
docker compose -f docker-compose.e2e.yaml stop 2>&1 >> ../build.log
|
||||
cd ../
|
||||
archgw down
|
||||
cd -
|
||||
|
||||
log shutting down the function_calling demo
|
||||
log =======================================
|
||||
cd ../demos/function_calling
|
||||
docker compose down 2>&1 >> ../build.log
|
||||
cd -
|
||||
|
||||
log shutting down the model server
|
||||
log ==============================
|
||||
cd ../model_server
|
||||
poetry run archgw_modelserver stop 2>&1 >> ../build.log
|
||||
docker compose down
|
||||
cd -
|
||||
|
|
|
|||
BIN
image.png
Normal file
BIN
image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 207 KiB |
2
model_server/.vscode/launch.json
vendored
2
model_server/.vscode/launch.json
vendored
|
|
@ -9,7 +9,7 @@
|
|||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": ["app.main:app","--reload", "--port", "8000"]
|
||||
"args": ["app.main:app","--reload", "--port", "51000"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import importlib
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
|
|
@ -7,6 +8,15 @@ import tempfile
|
|||
import subprocess
|
||||
import logging
|
||||
|
||||
|
||||
def get_version():
|
||||
try:
|
||||
version = importlib.metadata.version("archgw_modelserver")
|
||||
return version
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
return "version not found"
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
|
|
@ -14,6 +24,7 @@ logging.basicConfig(
|
|||
|
||||
log = logging.getLogger("model_server.cli")
|
||||
log.setLevel(logging.INFO)
|
||||
log.info(f"model server version: {get_version()}")
|
||||
|
||||
|
||||
def run_server(port=51000):
|
||||
|
|
@ -37,8 +48,9 @@ def run_server(port=51000):
|
|||
def start_server(port=51000):
|
||||
"""Start the Uvicorn server"""
|
||||
log.info(
|
||||
"Starting model server - loading some awesomeness, this may take some time :)"
|
||||
"starting model server - loading some awesomeness, this may take some time :)"
|
||||
)
|
||||
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"python",
|
||||
|
|
@ -61,7 +73,7 @@ def start_server(port=51000):
|
|||
log.info(f"Model server started with PID {process.pid}")
|
||||
else:
|
||||
# Add model_server boot-up logs
|
||||
log.info("Model server - Didn't Sart In Time. Shutting Down")
|
||||
log.info("model server - didn't start in time, shutting down")
|
||||
process.terminate()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -67,12 +67,16 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
|
||||
)
|
||||
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
try:
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"model_server <= arch_function: error: {e}")
|
||||
raise
|
||||
|
||||
tool_calls = const.arch_function_hanlder.extract_tool_calls(
|
||||
resp.choices[0].message.content
|
||||
|
|
|
|||
1192
model_server/poetry.lock
generated
1192
model_server/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "archgw_modelserver"
|
||||
version = "0.0.4"
|
||||
version = "0.1.1"
|
||||
description = "A model server for serving models"
|
||||
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
|
||||
license = "Apache 2.0"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue