Merge branch 'main' into adil/signoz_tracing

This commit is contained in:
Adil Hafeez 2024-11-04 16:05:20 -08:00
commit c18dc04a7d
51 changed files with 1593 additions and 695 deletions

View file

@ -8,7 +8,8 @@ on:
jobs:
test:
runs-on: ubuntu-latest
runs-on: ubuntu-latest-m
# runs-on: gh-large-150gb-ssd
steps:
- name: Checkout code
@ -29,4 +30,5 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
run: |
cd e2e_tests && bash run_e2e_tests.sh
python -mvenv venv
source venv/bin/activate && cd e2e_tests && bash run_e2e_tests.sh

View file

@ -1,6 +1,9 @@
<p>
<img src="docs/source/_static/img/arch-logo.png" alt="Arch Gateway Logo" title="Arch Gateway Logo">
</p>
![alt text](image.png)
[![pre-commit](https://github.com/katanemo/arch/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/katanemo/arch/actions/workflows/pre-commit.yml)
[![rust tests (prompt and llm gateway)](https://github.com/katanemo/arch/actions/workflows/rust_tests.yml/badge.svg)](https://github.com/katanemo/arch/actions/workflows/rust_tests.yml)
[![e2e tests](https://github.com/katanemo/arch/actions/workflows/e2e_tests.yml/badge.svg)](https://github.com/katanemo/arch/actions/workflows/e2e_tests.yml)
[![Build and Deploy Documentation](https://github.com/katanemo/arch/actions/workflows/static.yml/badge.svg)](https://github.com/katanemo/arch/actions/workflows/static.yml)
## Build fast, robust, and personalized AI agents.
@ -68,28 +71,43 @@ listener:
llm_providers:
- name: OpenAI
provider: openai
access_key: OPENAI_API_KEY
model: gpt-4o
access_key: $OPENAI_API_KEY
model: gpt-3.5-turbo
default: true
# default system prompt used by all prompt targets
system_prompt: |
You are a network assistant that just offers facts; not advice on manufacturers or purchasing decisions.
You are a network assistant that helps operators with a better understanding of network traffic flow and perform actions on networking operations. No advice on manufacturers or purchasing decisions.
prompt_targets:
- name: reboot_devices
description: Reboot specific devices or device groups
path: /agent/device_reboot
parameters:
- name: device_ids
type: list
description: A list of device identifiers (IDs) to reboot.
required: false
- name: device_group
type: str
description: The name of the device group to reboot
required: false
- name: device_summary
description: Retrieve network statistics for specific devices within a time range
endpoint:
name: app_server
path: /agent/device_summary
parameters:
- name: device_ids
type: list
description: A list of device identifiers (IDs) to retrieve statistics for.
required: true # device_ids are required to get device statistics
- name: days
type: int
description: The number of days for which to gather device statistics.
default: "7"
- name: reboot_devices
description: Reboot a list of devices
endpoint:
name: app_server
path: /agent/device_reboot
parameters:
- name: device_ids
type: list
description: A list of device identifiers (IDs).
required: true
- name: days
type: int
description: A list of device identifiers (IDs)
default: "7"
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
endpoints:
@ -97,7 +115,7 @@ endpoints:
# value could be ip address or a hostname with port
# this could also be a list of endpoints for load balancing
# for example endpoint: [ ip1:port, ip2:port ]
endpoint: 127.0.0.1:80
endpoint: host.docker.internal:18083
# max time to wait for a connection to be established
connect_timeout: 0.005s
```
@ -106,20 +124,23 @@ endpoints:
Make outbound calls via Arch
```python
import openai
# Set the OpenAI API base URL to the Arch gateway endpoint
openai.api_base = "http://127.0.0.1:12000/v1"
# No need to set a specific openai.api_key since it's configured in Arch's gateway
openai.api_key = "null"
from openai import OpenAI
# Use the OpenAI client as usual
response = openai.Completion.create(
model="text-davinci-003",
prompt="What is the capital of France?"
client = OpenAI(
# No need to set a specific openai.api_key since it's configured in Arch's gateway
api_key = '--',
# Set the OpenAI API base URL to the Arch gateway endpoint
base_url = "http://127.0.0.1:12000/v1"
)
print("OpenAI Response:", response.choices[0].text.strip())
response = client.chat.completions.create(
# we select model from arch_config file
model="--",
messages=[{"role": "user", "content": "What is the capital of France?"}],
)
print("OpenAI Response:", response.choices[0].message.content)
```

View file

@ -13,7 +13,7 @@ FROM envoyproxy/envoy:v1.31-latest as envoy
#Build config generator, so that we have a single build image for both Rust and Python
FROM python:3-slim as arch
RUN apt-get update && apt-get install -y gettext-base && apt-get clean && rm -rf /var/lib/apt/lists/*
RUN apt-get update && apt-get install -y gettext-base curl && apt-get clean && rm -rf /var/lib/apt/lists/*
COPY --from=builder /arch/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm
COPY --from=builder /arch/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm

View file

@ -11,6 +11,11 @@ services:
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
- ~/archgw_logs:/var/log/
env_file:
- stage.env
- env.list
extra_hosts:
- "host.docker.internal:host-gateway"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:10000/healthz"]
interval: 30s
timeout: 10s
retries: 3

View file

@ -3,7 +3,6 @@
# Define paths
source_schema="../arch_config_schema.yaml"
source_compose="../docker-compose.yaml"
source_stage_env="../stage.env"
destination_dir="config"
# Ensure the destination directory exists only if it doesn't already
@ -15,7 +14,7 @@ fi
# Copy the files
cp "$source_schema" "$destination_dir/arch_config_schema.yaml"
cp "$source_compose" "$destination_dir/docker-compose.yaml"
cp "$source_stage_env" "$destination_dir/stage.env"
touch "$destination_dir/env.list"
# Print success message
echo "Files copied successfully!"

View file

@ -34,7 +34,7 @@ def validate_and_render_schema():
try:
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
except Exception as e:
print(e)
print(str(e))
exit(1) # validate_prompt_config failed. Exit
with open(ARCH_CONFIG_FILE, "r") as file:
@ -73,7 +73,6 @@ def validate_and_render_schema():
print("updated clusters", inferred_clusters)
config_yaml = add_secret_key_to_llm_providers(config_yaml)
arch_llm_providers = config_yaml["llm_providers"]
arch_tracing = config_yaml.get("tracing", {})
arch_config_string = yaml.dump(config_yaml)

View file

@ -58,6 +58,8 @@ def main(ctx, version):
click.echo(f"archgw cli version: {get_version()}")
ctx.exit()
log.info(f"Starting archgw cli version: {get_version()}")
if ctx.invoked_subcommand is None:
click.echo("""Arch (The Intelligent Prompt Gateway) CLI""")
click.echo(logo)
@ -68,7 +70,7 @@ def main(ctx, version):
@click.option(
"--service",
default=SERVICE_ALL,
help="Optioanl parameter to specify which service to build. Options are model_server, archgw",
help="Optional parameter to specify which service to build. Options are model_server, archgw",
)
def build(service):
"""Build Arch from source. Must be in root of cloned repo."""
@ -168,7 +170,7 @@ def up(file, path, service):
arch_config_schema_file=arch_schema_config,
)
except Exception as e:
log.info(f"Exiting archgw up: {e}")
log.info(f"Exiting archgw up: validation failed")
sys.exit(1)
log.info("Starging arch model server and arch gateway")
@ -178,6 +180,12 @@ def up(file, path, service):
env = os.environ.copy()
# check if access_keys are present in the config file
access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file)
# remove duplicates
access_keys = set(access_keys)
# remove the $ from the access_keys
access_keys = [item[1:] if item.startswith("$") else item for item in access_keys]
if access_keys:
if file:
app_env_file = os.path.join(
@ -186,6 +194,7 @@ def up(file, path, service):
else:
app_env_file = os.path.abspath(os.path.join(path, ".env"))
print(f"app_env_file: {app_env_file}")
if not os.path.exists(
app_env_file
): # check to see if the environment variables in the current environment or not
@ -205,7 +214,7 @@ def up(file, path, service):
env_stage[access_key] = env_file_dict[access_key]
with open(
pkg_resources.resource_filename(__name__, "../config/stage.env"), "w"
pkg_resources.resource_filename(__name__, "../config/env.list"), "w"
) as file:
for key, value in env_stage.items():
file.write(f"{key}={value}\n")

28
arch/tools/poetry.lock generated
View file

@ -13,7 +13,7 @@ files = [
[[package]]
name = "archgw_modelserver"
version = "0.0.4"
version = "0.1.1"
description = "A model server for serving models"
optional = false
python-versions = "*"
@ -250,13 +250,13 @@ tqdm = ["tqdm"]
[[package]]
name = "huggingface-hub"
version = "0.26.1"
version = "0.26.2"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "huggingface_hub-0.26.1-py3-none-any.whl", hash = "sha256:5927a8fc64ae68859cd954b7cc29d1c8390a5e15caba6d3d349c973be8fdacf3"},
{file = "huggingface_hub-0.26.1.tar.gz", hash = "sha256:414c0d9b769eecc86c70f9d939d0f48bb28e8461dd1130021542eff0212db890"},
{file = "huggingface_hub-0.26.2-py3-none-any.whl", hash = "sha256:98c2a5a8e786c7b2cb6fdeb2740893cba4d53e312572ed3d8afafda65b128c46"},
{file = "huggingface_hub-0.26.2.tar.gz", hash = "sha256:b100d853465d965733964d123939ba287da60a547087783ddff8a323f340332b"},
]
[package.dependencies]
@ -765,33 +765,33 @@ files = [
[[package]]
name = "setuptools"
version = "75.2.0"
version = "75.3.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"},
{file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"},
{file = "setuptools-75.3.0-py3-none-any.whl", hash = "sha256:f2504966861356aa38616760c0f66568e535562374995367b4e69c7143cf6bcd"},
{file = "setuptools-75.3.0.tar.gz", hash = "sha256:fba5dd4d766e97be1b1681d98712680ae8f2f26d7881245f2ce9e40714f1a686"},
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"]
core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.12.*)", "pytest-mypy"]
[[package]]
name = "tqdm"
version = "4.66.5"
version = "4.66.6"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
files = [
{file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"},
{file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"},
{file = "tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"},
{file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"},
]
[package.dependencies]
@ -834,4 +834,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "e51783523cbe087cb1db94e874a28564b43af03a1689523d3738b212e288f64b"
content-hash = "c6d4df2015f02a8105934690d43d2133771c358f2f393a7a8a5a1e0df1c0a55b"

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw"
version = "0.0.5"
version = "0.1.0"
description = "Python-based CLI tool to manage Arch Gateway."
authors = ["Katanemo Labs, Inc."]
packages = [
@ -12,7 +12,6 @@ include = [
# Include package data (docker-compose.yaml and other files)
"config/docker-compose.yaml",
"config/arch_config_schema.yaml",
"config/stage.env"
]
[tool.poetry.dependencies]
@ -22,8 +21,8 @@ pydantic = "^2.9.2"
click = "^8.1.7"
jinja2 = "^3.1.4"
jsonschema = "^4.23.0"
setuptools = "75.2.0"
archgw_modelserver = "0.0.4"
setuptools = "75.3.0"
archgw_modelserver = "0.1.1"
huggingface_hub = "^0.26.0"
[tool.poetry.scripts]

View file

@ -16,10 +16,6 @@
"name": "model_server",
"path": "model_server"
},
{
"name": "chatbot_ui",
"path": "chatbot_ui"
},
{
"name": "e2e_tests",
"path": "e2e_tests"

View file

@ -1,77 +0,0 @@
import json
import logging
import os
import yaml
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
log = logging.getLogger(__name__)
def process_stream_chunk(chunk, history):
    """Fold one streamed chat-completion chunk into the running ``history``.

    Args:
        chunk: A streamed completion chunk exposing ``model`` and ``choices``;
            only ``choices[0].delta`` (``role``, ``content``, ``tool_calls``)
            is read.
        history: List of message dicts, mutated in place. Its last entry is the
            message currently being assembled.

    Returns:
        The delta's text when it was appended to an ``assistant`` message,
        otherwise ``None``.

    Raises:
        IndexError: If ``history`` is empty and the delta carries no role,
            because there is no message to attach the chunk to.
    """
    # Some streams emit chunks whose ``choices`` list is empty (for example a
    # trailing usage-only chunk); there is nothing to merge in that case.
    if not chunk.choices:
        return None

    delta = chunk.choices[0].delta
    if delta.role and (not history or delta.role != history[-1]["role"]):
        # create new history item if role changes (or history is still empty)
        # this is likely due to arch tool call and api response
        history.append({"role": delta.role})

    current = history[-1]
    current["model"] = chunk.model

    # record tool calls on the current item if the chunk carries any
    if delta.tool_calls:
        current["tool_calls"] = delta.tool_calls

    if delta.content:
        # accumulate streamed text on the current item
        current["content"] = current.get("content", "") + delta.content
        # only assistant text is surfaced to the caller
        if current["role"] == "assistant":
            return delta.content

    return None
def convert_prompt_target_to_openai_format(target):
    """Convert one prompt-target config entry into a tool description.

    Args:
        target: Mapping with ``name`` and ``description`` and, optionally,
            ``parameters`` (a list of mappings, each with ``name``, ``type``
            and ``description`` plus optional schema keys).

    Returns:
        ``{"name": <target name>, "info": <tool dict>}`` where the tool dict
        holds the description and an object schema with ``properties`` and
        ``required`` built from the target's parameters.

    Raises:
        KeyError: If ``target`` lacks ``name``/``description`` or a parameter
            lacks ``name``/``type``/``description``.
    """
    # Optional schema keys copied through verbatim when a parameter sets them.
    passthrough_keys = ("default", "format", "enum", "items", "minimum", "maximum")

    properties = {}
    required = []
    # ``or []`` also covers a ``parameters`` key that is present but empty
    # (None), which would otherwise make the loop raise TypeError.
    for param_info in target.get("parameters") or []:
        parameter = {
            "type": param_info["type"],
            "description": param_info["description"],
        }
        for key in passthrough_keys:
            if key in param_info:
                parameter[key] = param_info[key]

        properties[param_info["name"]] = parameter

        if param_info.get("required", False):
            required.append(param_info["name"])

    tool = {
        "description": target["description"],
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
        },
    }
    return {"name": target["name"], "info": tool}
def get_prompt_targets():
    """Load the non-default prompt targets from the Arch config file.

    The config path comes from the ``ARCH_CONFIG`` environment variable and
    falls back to ``arch_config.yaml``.

    Returns:
        A dict mapping each non-default prompt target's name to its converted
        tool description, or ``None`` if reading or converting the config
        fails for any reason (the error is logged).
    """
    config_path = os.getenv("ARCH_CONFIG", "arch_config.yaml")
    try:
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)

            tools_by_name = {}
            for target in config["prompt_targets"]:
                # targets flagged as default are not offered as tools
                if target.get("default", False):
                    continue
                converted = convert_prompt_target_to_openai_format(target)
                tools_by_name[converted["name"]] = converted["info"]

            return tools_by_name
    except Exception as e:
        # best-effort: callers receive None when the config cannot be used
        log.info(e)
        return None

View file

@ -319,7 +319,14 @@ impl HttpContext for StreamContext {
self.arch_state = Some(Vec::new());
}
let mut data = serde_json::from_str(&body_utf8).unwrap();
let mut data = match serde_json::from_str(&body_utf8) {
Ok(data) => data,
Err(e) => {
warn!("could not deserialize response: {}", e);
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
}
};
// use serde::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata

View file

@ -458,65 +458,93 @@ impl StreamContext {
// it may be that arch fc is handling the conversation for parameter collection
if arch_assistant {
info!("arch fc is engaged in parameter collection");
} else {
if let Some(default_prompt_target) = self
.prompt_targets
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!(
"default prompt target found, forwarding request to default prompt target"
);
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
} else if let Some(default_prompt_target) = self
.prompt_targets
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found, forwarding request to default prompt target");
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let mut headers = vec![
(":method", "POST"),
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
let mut headers = vec![
(":method", "POST"),
(ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
];
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
headers,
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
);
}
if self.request_id.is_some() {
headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap()));
}
if self.traceparent.is_some() {
headers.push((TRACE_PARENT_HEADER, self.traceparent.as_ref().unwrap()));
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&upstream_path,
headers,
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if let Err(e) = self.http_call(call_args, callout_context) {
warn!("error dispatching default prompt target request: {}", e);
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
);
}
return;
} else {
// if no default prompt target is found and similarity score is low send response to upstream llm
// removing tool calls and tool response
let messages = self.filter_out_arch_messages(&callout_context);
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
debug!(
"archgw (low similarity score) => llm request: {}",
llm_request_str
);
self.set_http_request_body(
0,
self.request_body_size,
&llm_request_str.into_bytes(),
);
self.resume_http_request();
return;
}
@ -862,7 +890,7 @@ impl StreamContext {
);
debug!(
"archgw => api call, endpoint: {}/{}, body: {}",
"archgw => api call, endpoint: {}{}, body: {}",
endpoint.name.as_str(),
path,
tool_params_json_str
@ -901,42 +929,8 @@ impl StreamContext {
"archgw <= api call response: {}",
self.tool_call_response.as_ref().unwrap()
);
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
.get(&prompt_target_name)
.unwrap()
.clone();
let mut messages: Vec<Message> = Vec::new();
// add system prompt
let system_prompt = match prompt_target.system_prompt.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt.clone()),
};
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: system_prompt,
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
let mut messages = self.filter_out_arch_messages(&callout_context);
let user_message = match messages.pop() {
Some(user_message) => user_message,
@ -988,6 +982,51 @@ impl StreamContext {
self.resume_http_request();
}
/// Builds the message list to forward to the upstream LLM.
///
/// The result starts with a system message when a system prompt is
/// available: the prompt of the target named by
/// `callout_context.prompt_target_name` is preferred, and the gateway-wide
/// `self.system_prompt` is used otherwise. It is followed by the request's
/// messages, skipping every message that has the tool role, has no content,
/// or carries a non-empty list of tool calls.
///
/// A prompt target name that is not present in `self.prompt_targets` falls
/// back to the gateway-wide system prompt instead of panicking.
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
    let mut messages: Vec<Message> = Vec::new();

    // Resolve the system prompt: target-specific first, then the global one.
    // Only the prompt string is cloned, not the whole prompt target.
    let system_prompt = callout_context
        .prompt_target_name
        .as_ref()
        .and_then(|name| self.prompt_targets.get(name))
        .and_then(|target| target.system_prompt.clone())
        .or_else(|| self.system_prompt.as_ref().clone());

    if system_prompt.is_some() {
        messages.push(Message {
            role: SYSTEM_ROLE.to_string(),
            content: system_prompt,
            model: None,
            tool_calls: None,
            tool_call_id: None,
        });
    }

    // don't send api responses and tool calls to upstream LLMs
    let forwarded = callout_context
        .request_body
        .messages
        .iter()
        .filter(|m| {
            let has_tool_calls = m
                .tool_calls
                .as_ref()
                .map_or(false, |calls| !calls.is_empty());
            !(m.role == TOOL_ROLE || m.content.is_none() || has_tool_calls)
        })
        .cloned();
    messages.extend(forwarded);

    messages
}
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!(

View file

@ -12,7 +12,7 @@ services:
chatbot_ui:
build:
context: ../../chatbot_ui
context: ../shared/chatbot_ui
ports:
- "18080:8080"
environment:

View file

@ -5,16 +5,12 @@ FROM base AS builder
WORKDIR /src
COPY requirements.txt /src/
COPY workforce_data.json /src/
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
COPY . /src
FROM python:3.10-slim AS output
COPY --from=builder /runtime /usr/local
COPY . /app
WORKDIR /app
COPY . /app
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]

View file

@ -25,7 +25,5 @@ This demo showcases how the **Arch** can be used to build an HR agent to manage
```sh
sh run_demo.sh
```
3. Navigate to http://localhost:18080/
3. Navigate to http://localhost:18080/agent/chat
4. "Can you give me workforce data for asia?"
![alt text](image.png)

View file

@ -8,8 +8,8 @@ listener:
llm_providers:
- name: OpenAI
provider: openai
access_key: OPENAI_API_KEY
model: gpt-4o
access_key: $OPENAI_API_KEY
model: gpt-4o-mini
default: true
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
@ -18,21 +18,17 @@ endpoints:
# value could be ip address or a hostname with port
# this could also be a list of endpoints for load balancing
# for example endpoint: [ ip1:port, ip2:port ]
endpoint: host.docker.internal:18083
endpoint: host.docker.internal:18080
# max time to wait for a connection to be established
connect_timeout: 0.005s
# default system prompt used by all prompt targets
system_prompt: |
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workfoce planning. NOTHING ELSE. When you get data in json format, offer some summary but don't be too verbose.
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workforce planning. Use following rules when responding,
- when you get data in json format, offer some summary but don't be too verbose
- be concise, to the point and do not over analyze the data
prompt_targets:
- name: hr_qa
endpoint:
name: app_server
path: /agent/hr_qa
description: Handle general Q/A related to HR.
default: true
- name: workforce
description: Get workforce data like headcount and satisfaction levels by region and staffing type
endpoint:
@ -41,16 +37,16 @@ prompt_targets:
parameters:
- name: staffing_type
type: str
description: Staffing type like contract, fte or agency
description: specific category or nature of employment used by an organization like fte, contract and agency
required: true
- name: region
type: str
required: true
description: Geographical region for which you want workforce data like asia, europe, americas.
- name: point_in_time
- name: data_snapshot_days_ago
type: int
required: false
description: the point in time for which to retrieve data. For e.g 0 days ago, 30 days ago, etc.
description: the snapshot day for which you want workforce data.
- name: slack_message
endpoint:
name: app_server

View file

@ -4,22 +4,14 @@ services:
context: .
environment:
- SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN:-None}
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
volumes:
- ./arch_config.yaml:/app/arch_config.yaml
- ../shared/chatbot_ui/common.py:/app/common.py
ports:
- "18083:80"
- "18080:80"
healthcheck:
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
interval: 5s
retries: 20
chatbot_ui:
build:
context: ../../chatbot_ui
ports:
- "18080:8080"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml

Binary file not shown.

Before

Width:  |  Height:  |  Size: 456 KiB

After

Width:  |  Height:  |  Size: 549 KiB

Before After
Before After

View file

@ -1,28 +1,39 @@
import os
import json
import pandas as pd
import gradio as gr
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
from enum import Enum
from typing import List, Optional, Tuple
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from openai import OpenAI
from common import create_gradio_app
app = FastAPI()
workforce_data_df = None
demo_description = """This demo showcases how the **Arch** can be used to build an
HR agent to manage workforce-related inquiries, workforce planning, and communication via Slack.
It intelligently routes incoming prompts to the correct targets, providing concise and useful responses
tailored for HR and workforce decision-making. """
with open("workforce_data.json") as file:
workforce_data = json.load(file)
workforce_data_df = pd.json_normalize(
workforce_data, record_path=["regions"], meta=["point_in_time", "satisfaction"]
workforce_data,
record_path=["regions"],
meta=["data_snapshot_days_ago", "satisfaction"],
)
# Define the request model
class WorkforceRequset(BaseModel):
class WorkforceRequest(BaseModel):
region: str
staffing_type: str
point_in_time: Optional[int] = None
data_snapshot_days_ago: Optional[int] = None
class SlackRequest(BaseModel):
@ -36,25 +47,6 @@ class WorkforceResponse(BaseModel):
satisfaction: float
# Post method for device summary
@app.post("/agent/workforce")
def get_workforce(request: WorkforceRequset):
"""
Endpoint to workforce data by region, staffing type at a given point in time.
"""
region = request.region.lower()
staffing_type = request.staffing_type.lower()
point_in_time = request.point_in_time if request.point_in_time else 0
response = {
"region": region,
"staffing_type": f"Staffing agency: {staffing_type}",
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)][staffing_type].values[0])}",
"satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['point_in_time']==point_in_time)]['satisfaction'].values[0])}",
}
return response
@app.post("/agent/slack_message")
def send_slack_message(request: SlackRequest):
"""
@ -80,27 +72,38 @@ def send_slack_message(request: SlackRequest):
print(f"Error sending message: {e.response['error']}")
@app.post("/agent/hr_qa")
async def general_hr_qa():
# Post method for workforce data
@app.post("/agent/workforce")
def get_workforce(request: WorkforceRequest):
"""
This method handles Q/A related to general issues in HR.
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
Endpoint to retrieve workforce data by region and staffing type at a given point in time.
"""
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "I am a helpful HR agent, and I can help you plan for workforce related questions",
},
"finish_reason": "completed",
"index": 0,
}
],
"model": "hr_agent",
"usage": {"completion_tokens": 0},
}
region = request.region.lower()
staffing_type = request.staffing_type.lower()
data_snapshot_days_ago = (
request.data_snapshot_days_ago
if request.data_snapshot_days_ago
else 0 # this param is not required.
)
response = {
"region": region,
"staffing_type": f"Staffing agency: {staffing_type}",
"headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)][staffing_type].values[0])}",
"satisfaction": f"Satisfaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}",
}
return response
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
client = OpenAI(
api_key="--",
base_url=CHAT_COMPLETION_ENDPOINT,
)
gr.mount_gradio_app(
app, create_gradio_app(demo_description, client), path="/agent/chat"
)
if __name__ == "__main__":
app.run(debug=True)

View file

@ -1,6 +1,13 @@
fastapi
uvicorn
pydantic
slack-sdk
typing
pandas
gradio==5.3.0
async_timeout==4.0.3
loguru==0.7.2
asyncio==3.4.3
httpx==0.27.0
python-dotenv==1.0.1
pydantic==2.8.2
openai==1.51.0

View file

@ -1,6 +1,6 @@
[
{
"point_in_time": 0,
"data_snapshot_days_ago": 0,
"regions": [
{ "region": "asia", "contract": 100, "fte": 150, "agency": 2000 },
{ "region": "europe", "contract": 80, "fte": 120, "agency": 2500 },
@ -9,7 +9,7 @@
"satisfaction": 3.5
},
{
"point_in_time": 30,
"data_snapshot_days_ago": 30,
"regions": [
{ "region": "asia", "contract": 110, "fte": 155, "agency": 1000 },
{ "region": "europe", "contract": 85, "fte": 130, "agency": 1600 },
@ -18,7 +18,7 @@
"satisfaction": 4.0
},
{
"point_in_time": 60,
"data_snapshot_days_ago": 60,
"regions": [
{ "region": "asia", "contract": 115, "fte": 160, "agency": 500 },
{ "region": "europe", "contract": 90, "fte": 140, "agency": 700 },

View file

@ -10,7 +10,7 @@ system_prompt: |
llm_providers:
- name: OpenAI
provider: openai
access_key: OPENAI_API_KEY
access_key: $OPENAI_API_KEY
model: gpt-4o
default: true

View file

@ -12,7 +12,7 @@ services:
chatbot_ui:
build:
context: ../../chatbot_ui
context: ../shared/chatbot_ui
dockerfile: Dockerfile
ports:
- "18080:8080"

View file

@ -24,7 +24,7 @@ The assistant can perform several key operations, including rebooting devices, a
```sh
sh run_demo.sh
```
3. Navigate to http://localhost:18080/
3. Navigate to http://localhost:18080/agent/chat
4. Ask: "What can you do for me?"
# Observability
@ -39,4 +39,4 @@ Arch gateway publishes stats endpoint at http://localhost:19901/stats. In this d
Here is sample interaction
<img width="575" alt="image" src="https://github.com/user-attachments/assets/25d40f46-616e-41ea-be8e-1623055c84ec">
![alt text](image.png)

View file

@ -8,36 +8,17 @@ listener:
llm_providers:
- name: OpenAI
provider: openai
access_key: OPENAI_API_KEY
access_key: $OPENAI_API_KEY
model: gpt-3.5-turbo
default: true
# default system prompt used by all prompt targets
system_prompt: |
You are a network assistant that just offers facts; not advice on manufacturers or purchasing decisions.
You are a network assistant that helps operators with a better understanding of network traffic flow and perform actions on networking operations. No advice on manufacturers or purchasing decisions.
prompt_targets:
- name: reboot_devices
description: Reboot specific devices or device groups
endpoint:
name: app_server
path: /agent/device_reboot
parameters:
- name: device_ids
type: list
description: A list of device identifiers (IDs) to reboot.
required: true
- name: time_range
type: int
description: Optional time range in days for reboot operations. Defaults to 7.
- name: network_qa
endpoint:
name: app_server
path: /agent/network_summary
description: Handle general Q/A related to networking.
default: true
- name: device_summary
description: Retrieve statistics for specific devices within a time range
description: Retrieve network statistics for specific devices within a time range
endpoint:
name: app_server
path: /agent/device_summary
@ -46,9 +27,23 @@ prompt_targets:
type: list
description: A list of device identifiers (IDs) to retrieve statistics for.
required: true # device_ids are required to get device statistics
- name: time_range
- name: days
type: int
description: Time range in days for which to gather device statistics. Defaults to 7.
description: The number of days for which to gather device statistics.
default: "7"
- name: reboot_devices
description: Reboot a list of devices
endpoint:
name: app_server
path: /agent/device_reboot
parameters:
- name: device_ids
type: list
description: A list of device identifiers (IDs).
required: true
- name: days
type: int
description: Time range in days for the reboot operation.
default: "7"
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
@ -57,15 +52,6 @@ endpoints:
# value could be ip address or a hostname with port
# this could also be a list of endpoints for load balancing
# for example endpoint: [ ip1:port, ip2:port ]
endpoint: host.docker.internal:18083
endpoint: host.docker.internal:18080
# max time to wait for a connection to be established
connect_timeout: 0.005s
ratelimits:
- model: gpt-4
selector:
key: selector-key
value: selector-value
limit:
tokens: 1
unit: minute

View file

@ -2,24 +2,14 @@ services:
api_server:
build:
context: .
dockerfile: Dockerfile
environment:
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
volumes:
- ./arch_config.yaml:/app/arch_config.yaml
- ../shared/chatbot_ui/common.py:/app/common.py
ports:
- "18083:80"
- "18080:80"
healthcheck:
test: ["CMD", "curl" ,"http://localhost:80/healthz"]
interval: 5s
retries: 20
chatbot_ui:
build:
context: ../../chatbot_ui
dockerfile: Dockerfile
ports:
- "18080:8080"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml

Binary file not shown.

After

Width:  |  Height:  |  Size: 636 KiB

View file

@ -1,8 +1,15 @@
from openai import OpenAI
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List, Optional
from common import create_gradio_app
import gradio as gr
import os
app = FastAPI()
demo_description = """This demo illustrates how **Arch** can be used to perform function calling with network-related tasks.
In this demo, you act as a **network assistant** that provides factual information, without offering advice on manufacturers or purchasing decisions."""
# Define the request model
@ -88,27 +95,15 @@ def get_device_summary(request: DeviceSummaryRequest):
return DeviceSummaryResponse(statistics=statistics)
@app.post("/agent/network_summary")
async def policy_qa():
"""
This method handles Q/A related to general issues in networks.
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
"""
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them",
},
"finish_reason": "completed",
"index": 0,
}
],
"model": "network_agent",
"usage": {"completion_tokens": 0},
}
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
client = OpenAI(
api_key="--",
base_url=CHAT_COMPLETION_ENDPOINT,
)
gr.mount_gradio_app(
app, create_gradio_app(demo_description, client), path="/agent/chat"
)
if __name__ == "__main__":
app.run(debug=True)

View file

@ -2,3 +2,12 @@ fastapi
uvicorn
pydantic
typing
pandas
gradio==5.3.0
async_timeout==4.0.3
loguru==0.7.2
asyncio==3.4.3
httpx==0.27.0
python-dotenv==1.0.1
pydantic==2.8.2
openai==1.51.0

View file

@ -0,0 +1,171 @@
import json
import logging
import os
import yaml
import gradio as gr
from typing import List, Optional, Tuple
from functools import partial
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
log = logging.getLogger(__name__)
GRADIO_CSS_STYLE = """
.json-container {
height: 80vh !important;
overflow-y: auto !important;
}
.chatbot {
height: calc(80vh - 100px) !important;
overflow-y: auto !important;
}
footer {visibility: hidden}
"""
def chat(
    query: Optional[str],
    conversation: Optional[List[Tuple[str, str]]],
    history: List[dict],
    client,
):
    """Stream one chat turn through the gateway and update the UI state.

    This is a generator. Every item it yields is a
    ``(textbox_value, conversation, history)`` tuple in which the textbox
    value is always ``""``.

    Args:
        query: Text submitted by the user.
        conversation: ``(user, assistant)`` pairs displayed by the chatbot
            widget. ``None`` is treated as an empty conversation.
        history: Message dicts passed to the gateway as ``messages``. It is
            mutated in place: the user message is appended here and streamed
            chunks are merged in by ``process_stream_chunk``.
        client: Object exposing ``chat.completions.create``.

    Raises:
        gr.Error: When the call to the gateway fails. The user message that
            was just appended to ``history`` is removed before raising.
    """
    if conversation is None:
        # The parameter is declared Optional; start a fresh conversation
        # instead of failing on ``None.append``.
        conversation = []

    history.append({"role": "user", "content": query})

    try:
        response = client.chat.completions.create(
            # we select model from arch_config file
            model="--",
            messages=history,
            temperature=1.0,
            stream=True,
        )
    except Exception as e:
        # remove last user message in case of exception
        history.pop()
        log.error("Error calling gateway API: {}".format(e))
        raise gr.Error("Error calling gateway API: {}".format(e)) from e

    conversation.append((query, ""))
    # Publish the user's turn right away so the outputs are updated even if
    # the stream never produces any assistant tokens.
    yield "", conversation, history

    for chunk in response:
        tokens = process_stream_chunk(chunk, history)
        if tokens:
            user_text, assistant_text = conversation[-1]
            conversation[-1] = (user_text, assistant_text + tokens)
            yield "", conversation, history
# Build the Gradio Blocks UI for a demo and return it.
#
# demo_description: text rendered through gr.Markdown.
# client: object bound into the chat handler as its `client` argument.
# Returns the gr.Blocks instance created below (`demo`).
#
# NOTE(review): indentation is not visible in this view, so the nesting of the
# layout contexts below is assumed to follow the order of the `with`
# statements — confirm against the real file before relying on it.
def create_gradio_app(demo_description, client):
with gr.Blocks(
theme=gr.themes.Default(
font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "Arial", "sans-serif"]
),
fill_height=True,
css=GRADIO_CSS_STYLE,
) as demo:
with gr.Row(equal_height=True):
# State component holding the message history; starts as an empty list.
history = gr.State([])
with gr.Column(scale=1):
# The trailing comma turns this statement into a one-element tuple expression;
# the component is still created by the call.
gr.Markdown(demo_description),
with gr.Accordion("Available Tools/APIs", open=True):
with gr.Column(scale=1):
# JSON view of whatever get_prompt_targets() returns.
gr.JSON(
value=get_prompt_targets(),
show_indices=False,
elem_classes="json-container",
min_height="80vh",
)
with gr.Column(scale=2):
chatbot = gr.Chatbot(
label="Arch Chatbot",
elem_classes="chatbot",
)
textbox = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter",
autofocus=True,
elem_classes="textbox",
)
# Bind the client so the submit handler only receives the component values.
chat_with_client = partial(chat, client=client)
# The same three components are used as both inputs and outputs of the handler.
textbox.submit(
chat_with_client,
[textbox, chatbot, history],
[textbox, chatbot, history],
)
return demo
def process_stream_chunk(chunk, history):
    """Merge one streamed completion chunk into ``history``.

    Args:
        chunk: Streamed chunk exposing ``choices`` (each with a ``delta``
            that has ``role``, ``content`` and ``tool_calls``) and ``model``.
        history: Non-empty list of message dicts; mutated in place.

    Returns:
        The chunk's content when it belongs to an ``assistant`` message,
        otherwise ``None``.
    """
    # A streamed chunk may carry no choices at all (for example a trailing
    # usage-only chunk); there is nothing to merge in that case.
    if not chunk.choices:
        return None

    delta = chunk.choices[0].delta

    if delta.role and delta.role != history[-1]["role"]:
        # create new history item if role changes
        # this is likely due to arch tool call and api response
        history.append({"role": delta.role})

    current = history[-1]
    current["model"] = chunk.model

    # append tool calls to history if there are any in the chunk
    if delta.tool_calls:
        current["tool_calls"] = delta.tool_calls

    if delta.content:
        # append content to the last history item
        current["content"] = current.get("content", "") + delta.content
        # only assistant content is handed back to the caller
        if current["role"] == "assistant":
            return delta.content

    return None
def convert_prompt_target_to_openai_format(target):
    """Translate one prompt-target mapping into a tool description.

    Args:
        target: Mapping with ``name`` and ``description`` keys and an
            optional ``parameters`` list. Each parameter mapping provides
            ``name``, ``type`` and ``description`` and may also provide
            ``required`` and any of ``default``, ``format``, ``enum``,
            ``items``, ``minimum``, ``maximum``.

    Returns:
        ``{"name": <target name>, "info": <tool dict>}`` where the tool dict
        holds the description and an object-typed ``parameters`` schema.
    """
    passthrough_keys = ("default", "format", "enum", "items", "minimum", "maximum")

    properties = {}
    required_names = []
    info = {
        "description": target["description"],
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required_names,
        },
    }

    for spec in target.get("parameters", ()):
        entry = {"type": spec["type"], "description": spec["description"]}
        # copy over only the optional schema keys that are present
        entry.update((key, spec[key]) for key in passthrough_keys if key in spec)

        param_name = spec["name"]
        properties[param_name] = entry
        if spec.get("required", False):
            required_names.append(param_name)

    return {"name": target["name"], "info": info}
def get_prompt_targets():
    """Load the non-default prompt targets from the gateway config file.

    The file path comes from the ``ARCH_CONFIG`` environment variable and
    falls back to ``arch_config.yaml``.

    Returns:
        A dict mapping each target name to its converted tool description,
        or ``None`` if anything goes wrong while reading or converting the
        config (the exception is logged).
    """
    config_path = os.getenv("ARCH_CONFIG", "arch_config.yaml")
    try:
        with open(config_path, "r") as handle:
            config = yaml.safe_load(handle)

        # targets flagged as default are not listed as tools
        converted = [
            convert_prompt_target_to_openai_format(entry)
            for entry in config["prompt_targets"]
            if not entry.get("default", False)
        ]
        return {item["name"]: item["info"] for item in converted}
    except Exception as e:
        log.info(e)
        return None

View file

@ -8,7 +8,7 @@ listener:
llm_providers:
- name: OpenAI
provider: openai
access_key: OPENAI_API_KEY
access_key: $OPENAI_API_KEY
model: gpt-3.5-turbo
default: true

View file

@ -10,7 +10,7 @@ listener:
llm_providers:
- name: OpenAI
provider: openai
access_key: OPENAI_API_KEY
access_key: $OPENAI_API_KEY
model: gpt-4o
default: true
stream: true

View file

@ -10,7 +10,7 @@ listen:
llm_providers:
- name: OpenAI
provider: openai
access_key: OPENAI_API_KEY
access_key: $OPENAI_API_KEY
model: gpt-4o
default: true
stream: true

View file

@ -10,7 +10,7 @@ listener:
llm_providers:
- name: OpenAI
provider: openai
access_key: OPENAI_API_KEY
access_key: $OPENAI_API_KEY
model: gpt-4o
default: true
stream: true

View file

@ -32,7 +32,7 @@ endpoints:
llm_providers:
- name: OpenAI
provider: openai
access_key: OPENAI_API_KEY
access_key: $OPENAI_API_KEY
model: gpt-4o
default: true
stream: true
@ -47,7 +47,7 @@ llm_providers:
- name: Mistral8x7b
provider: mistral
access_key: MISTRAL_API_KEY
access_key: $MISTRAL_API_KEY
model: mistral-8x7b
- name: MistralLocal7b

View file

@ -6,6 +6,11 @@ log() {
echo "$timestamp: $message"
}
# Print the `df -h` line(s) whose mount point is `/` (diagnostic output only).
print_disk_usage() {
  echo free disk space
  # Scripts that source this file may run under `set -e`; without `|| true`
  # a grep miss would become the function's status and abort the caller.
  df -h | grep "/$" || true
}
wait_for_healthz() {
local healthz_url="$1"
local timeout_seconds="${2:-30}" # Default timeout of 30 seconds
@ -28,6 +33,8 @@ wait_for_healthz() {
return 1
fi
print_disk_usage
sleep $sleep_between
done
}

39
e2e_tests/poetry.lock generated
View file

@ -455,13 +455,13 @@ files = [
[[package]]
name = "pytest"
version = "7.4.4"
version = "8.3.3"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"},
{file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"},
{file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"},
{file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"},
]
[package.dependencies]
@ -469,11 +469,11 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
pluggy = ">=1.5,<2"
tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-cov"
@ -493,6 +493,23 @@ pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
[[package]]
name = "pytest-retry"
version = "1.6.3"
description = "Adds the ability to retry flaky tests in CI environments"
optional = false
python-versions = ">=3.9"
files = [
{file = "pytest_retry-1.6.3-py3-none-any.whl", hash = "sha256:e96f7df77ee70b0838d1085f9c3b8b5b7d74bf8947a0baf32e2b8c71b27683c8"},
{file = "pytest_retry-1.6.3.tar.gz", hash = "sha256:36ccfa11c8c8f9ddad5e20375182146d040c20c4a791745139c5a99ddf1b557d"},
]
[package.dependencies]
pytest = ">=7.0.0"
[package.extras]
dev = ["black", "flake8", "isort", "mypy"]
[[package]]
name = "pytest-sugar"
version = "1.0.0"
@ -535,13 +552,13 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "selenium"
version = "4.25.0"
version = "4.26.0"
description = "Official Python bindings for Selenium WebDriver"
optional = false
python-versions = ">=3.8"
files = [
{file = "selenium-4.25.0-py3-none-any.whl", hash = "sha256:3798d2d12b4a570bc5790163ba57fef10b2afee958bf1d80f2a3cf07c4141f33"},
{file = "selenium-4.25.0.tar.gz", hash = "sha256:95d08d3b82fb353f3c474895154516604c7f0e6a9a565ae6498ef36c9bac6921"},
{file = "selenium-4.26.0-py3-none-any.whl", hash = "sha256:48013f36e812de5b3948ef53d04e73f77bc923ee3e1d7d99eaf0618179081b99"},
{file = "selenium-4.26.0.tar.gz", hash = "sha256:f0780f85f10310aa5d085b81e79d73d3c93b83d8de121d0400d543a50ee963e8"},
]
[package.dependencies]
@ -699,4 +716,4 @@ h11 = ">=0.9.0,<1"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "6ae4fa6397091b87b63698201a08d7d97628ed65992d46514f118768b46b99ce"
content-hash = "a40015b90325879e50f82cca6a26a730d763cad26589671df798832d41c42db3"

View file

@ -9,11 +9,12 @@ package-mode = false
[tool.poetry.dependencies]
python = "^3.10"
pytest = "^7.3.1"
pytest = "^8.3.3"
requests = "^2.29.0"
selenium = "^4.11.2"
pytest-sugar = "^1.0.0"
deepdiff = "^8.0.1"
pytest-retry = "^1.6.3"
[tool.poetry.dev-dependencies]
pytest-cov = "^4.1.0"
@ -21,3 +22,6 @@ pytest-cov = "^4.1.0"
[tool.pytest.ini_options]
python_files = ["test*.py"]
addopts = ["-v", "-s"]
retries = 2
retry_delay = 0.5
cumulative_timing = false

View file

@ -4,6 +4,11 @@ set -e
. ./common_scripts.sh
print_disk_usage
mkdir -p ~/archgw_logs
touch ~/archgw_logs/modelserver.log
print_debug() {
log "Received signal to stop"
log "Printing debug logs for model_server"
@ -18,63 +23,62 @@ trap 'print_debug' INT TERM ERR
log starting > ../build.log
log building function_callling demo
log ===============================
log building and running function_callling demo
log ===========================================
cd ../demos/function_calling
docker compose build -q
log starting the function_calling demo
docker compose up -d
docker compose up api_server --build -d
cd -
log building model server
log =====================
print_disk_usage
log building and install model server
log =================================
cd ../model_server
poetry install 2>&1 >> ../build.log
log starting model server
log =====================
mkdir -p ~/archgw_logs
touch ~/archgw_logs/modelserver.log
poetry run archgw_modelserver restart &
poetry install
cd -
print_disk_usage
log building and installing archgw cli
log ==================================
cd ../arch/tools
sh build_cli.sh
cd -
print_disk_usage
log building docker image for arch gateway
log ======================================
cd ../
archgw build
cd -
print_disk_usage
log startup arch gateway with function calling demo
cd ..
tail -F ~/archgw_logs/modelserver.log &
model_server_tail_pid=$!
archgw down
archgw up demos/function_calling/arch_config.yaml
kill $model_server_tail_pid
cd -
log building llm and prompt gateway rust modules
log ============================================
cd ../arch
docker build -f Dockerfile .. -t katanemo/archgw -q
log starting the arch gateway service
log =================================
docker compose -f docker-compose.e2e.yaml down
log waiting for model service to be healthy
wait_for_healthz "http://localhost:51000/healthz" 300
kill $model_server_tail_pid
docker compose -f docker-compose.e2e.yaml up -d
log waiting for arch gateway service to be healthy
wait_for_healthz "http://localhost:10000/healthz" 60
log waiting for arch gateway service to be healthy
cd -
print_disk_usage
log running e2e tests
log =================
poetry install 2>&1 >> ../build.log
poetry install
poetry run pytest
log shutting down the arch gateway service
log ======================================
cd ../arch
docker compose -f docker-compose.e2e.yaml stop 2>&1 >> ../build.log
cd ../
archgw down
cd -
log shutting down the function_calling demo
log =======================================
cd ../demos/function_calling
docker compose down 2>&1 >> ../build.log
cd -
log shutting down the model server
log ==============================
cd ../model_server
poetry run archgw_modelserver stop 2>&1 >> ../build.log
docker compose down
cd -

BIN
image.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 207 KiB

View file

@ -9,7 +9,7 @@
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"args": ["app.main:app","--reload", "--port", "8000"]
"args": ["app.main:app","--reload", "--port", "51000"]
}
]
}

View file

@ -1,3 +1,4 @@
import importlib
import sys
import os
import time
@ -7,6 +8,15 @@ import tempfile
import subprocess
import logging
def get_version():
    """Return the installed version of the ``archgw_modelserver`` package.

    Returns:
        The version string from the package metadata, or the literal
        ``"version not found"`` when the package is not installed.
    """
    # Import the submodule explicitly: a bare ``import importlib`` does not
    # guarantee that ``importlib.metadata`` has been loaded.
    from importlib import metadata

    try:
        return metadata.version("archgw_modelserver")
    except metadata.PackageNotFoundError:
        return "version not found"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@ -14,6 +24,7 @@ logging.basicConfig(
log = logging.getLogger("model_server.cli")
log.setLevel(logging.INFO)
log.info(f"model server version: {get_version()}")
def run_server(port=51000):
@ -37,8 +48,9 @@ def run_server(port=51000):
def start_server(port=51000):
"""Start the Uvicorn server"""
log.info(
"Starting model server - loading some awesomeness, this may take some time :)"
"starting model server - loading some awesomeness, this may take some time :)"
)
process = subprocess.Popen(
[
"python",
@ -61,7 +73,7 @@ def start_server(port=51000):
log.info(f"Model server started with PID {process.pid}")
else:
# Add model_server boot-up logs
log.info("Model server - Didn't Sart In Time. Shutting Down")
log.info("model server - didn't start in time, shutting down")
process.terminate()

View file

@ -67,12 +67,16 @@ async def chat_completion(req: ChatMessage, res: Response):
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
)
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=const.arch_function_generation_params,
)
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=const.arch_function_generation_params,
)
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
tool_calls = const.arch_function_hanlder.extract_tool_calls(
resp.choices[0].message.content

1192
model_server/poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw_modelserver"
version = "0.0.4"
version = "0.1.1"
description = "A model server for serving models"
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
license = "Apache 2.0"