Merge branch 'main' into metrics

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-12-09 10:36:04 -08:00
commit 0939c72bcc
56 changed files with 6168 additions and 1902 deletions

243
README.md
View file

@ -33,94 +33,184 @@ Engineered with purpose-built LLMs, Arch handles the critical but undifferentiat
To get in touch with us, please join our [discord server](https://discord.gg/pGZf2gcwEc). We will be monitoring that actively and offering support there.
## Demos
* [Weather Forecast](demos/weather_forecast/README.md) - Walk through of the core function calling capabilities of of arch gateway using weather forecasting service
* [Weather Forecast](demos/weather_forecast/README.md) - Walk through of the core function calling capabilities of arch gateway using weather forecasting service
* [Insurance Agent](demos/insurance_agent/README.md) - Build a full insurance agent with Arch
* [Network Agent](demos/network_agent/README.md) - Build a networking co-pilot/agent agent with Arch
## Quickstart
Follow this guide to learn how to quickly set up Arch and integrate it into your generative AI applications.
Follow this quickstart guide to use arch gateway to build a simple AI agent. Laster in the section we will see how you can Arch Gateway to manage access keys, provide unified access to upstream LLMs and to provide e2e observability.
### Prerequisites
Before you begin, ensure you have the following:
- `Docker` & `Python` installed on your system
- `API Keys` for LLM providers (if using external LLMs)
### Step 1: Install Arch
Arch's CLI allows you to manage and interact with the Arch gateway efficiently. To install the CLI, simply run the following command:
Tip: We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch. This ensures that archgw and its dependencies do not interfere with other packages on your system.
Make sure you have following utilities installed before proceeding further,
1. [Docker System](https://docs.docker.com/get-started/get-docker/) (v24)
2. [Docker compose](https://docs.docker.com/compose/install/) (v2.29)
3. [Python](https://www.python.org/downloads/) (v3.12)
4. [Poetry](https://python-poetry.org/docs/#installing-with-the-official-installer) (v1.8.3. *Note: only needed for local development*)
Arch's CLI allows you to manage and interact with the Arch gateway efficiently. To install the CLI, simply run the following command:
> [!TIP]
> We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch. This ensures that archgw and its dependencies do not interfere with other packages on your system.
```console
$ python -m venv venv
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
$ pip install archgw
$ pip install archgw==0.1.6
```
### Step 2: Configure Arch with your application
### Build AI Agent with Arch Gateway
Arch operates based on a configuration file where you can define LLM providers, prompt targets, guardrails, etc.
Below is an example configuration to get you started:
In following quickstart we will show you how easy it is to build AI agent with Arch gateway. We will build a currency exchange agent using following simple steps. For this demo we will use `https://api.frankfurter.dev/` to fetch latest price for currencies and assume USD as base currency.
#### Step 1. Create arch config file
Create `arch_config.yaml` file with following content,
```yaml
version: v0.1
listener:
address: 127.0.0.1
port: 8080 #If you configure port 443, you'll need to update the listener with tls_certificates
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
# Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
llm_providers:
- name: OpenAI
provider: openai
- name: gpt-4o
access_key: $OPENAI_API_KEY
model: gpt-3.5-turbo
default: true
provider: openai
model: gpt-4o
# default system prompt used by all prompt targets
system_prompt: |
You are a network assistant that helps operators with a better understanding of network traffic flow and perform actions on networking operations. No advice on manufacturers or purchasing decisions.
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
prompt_targets:
- name: device_summary
description: Retrieve network statistics for specific devices within a time range
endpoint:
name: app_server
path: /agent/device_summary
parameters:
- name: device_ids
type: list
description: A list of device identifiers (IDs) to retrieve statistics for.
required: true # device_ids are required to get device statistics
- name: days
type: int
description: The number of days for which to gather device statistics.
default: "7"
- name: currency_exchange
description: Get currency exchange rate from USD to other currencies
parameters:
- name: currency_symbol
description: the currency that needs conversion
required: true
type: str
in_path: true
endpoint:
name: frankfurther_api
path: /v1/latest?base=USD&symbols={currency_symbol}
system_prompt: |
You are a helpful assistant. Show me the currency symbol you want to convert from USD.
- name: get_supported_currencies
description: Get list of supported currencies for conversion
endpoint:
name: frankfurther_api
path: /v1/currencies
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
endpoints:
app_server:
# value could be ip address or a hostname with port
# this could also be a list of endpoints for load balancing
# for example endpoint: [ ip1:port, ip2:port ]
endpoint: host.docker.internal:18083
# max time to wait for a connection to be established
connect_timeout: 0.005s
frankfurther_api:
endpoint: api.frankfurter.dev:443
protocol: https
```
### Step 3: Using OpenAI Client with Arch as an Egress Gateway
Make outbound calls via Arch
#### Step 2. Start arch gateway with currency conversion config
```sh
$ archgw up arch_config.yaml
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.1.5
...
2024-12-05 16:56:28,485 - cli.utils - INFO - Schema validation successful!
2024-12-05 16:56:28,485 - cli.main - INFO - Starging arch model server and arch gateway
...
2024-12-05 16:56:51,647 - cli.core - INFO - Container is healthy!
```
Once the gateway is up you can start interacting with at port 10000 using openai chat completion API.
Some of the sample queries you can ask could be `what is currency rate for gbp?` or `show me list of currencies for conversion`.
#### Step 3. Interacting with gateway using curl command
Here is a sample curl command you can use to interact,
```bash
$ curl --header 'Content-Type: application/json' \
--data '{"messages": [{"role": "user","content": "what is exchange rate for gbp"}]}' \
http://localhost:10000/v1/chat/completions | jq ".choices[0].message.content"
"As of the date provided in your context, December 5, 2024, the exchange rate for GBP (British Pound) from USD (United States Dollar) is 0.78558. This means that 1 USD is equivalent to 0.78558 GBP."
```
And to get list of supported currencies,
```bash
$ curl --header 'Content-Type: application/json' \
--data '{"messages": [{"role": "user","content": "show me list of currencies that are supported for conversion"}]}' \
http://localhost:10000/v1/chat/completions | jq ".choices[0].message.content"
"Here is a list of the currencies that are supported for conversion from USD, along with their symbols:\n\n1. AUD - Australian Dollar\n2. BGN - Bulgarian Lev\n3. BRL - Brazilian Real\n4. CAD - Canadian Dollar\n5. CHF - Swiss Franc\n6. CNY - Chinese Renminbi Yuan\n7. CZK - Czech Koruna\n8. DKK - Danish Krone\n9. EUR - Euro\n10. GBP - British Pound\n11. HKD - Hong Kong Dollar\n12. HUF - Hungarian Forint\n13. IDR - Indonesian Rupiah\n14. ILS - Israeli New Sheqel\n15. INR - Indian Rupee\n16. ISK - Icelandic Króna\n17. JPY - Japanese Yen\n18. KRW - South Korean Won\n19. MXN - Mexican Peso\n20. MYR - Malaysian Ringgit\n21. NOK - Norwegian Krone\n22. NZD - New Zealand Dollar\n23. PHP - Philippine Peso\n24. PLN - Polish Złoty\n25. RON - Romanian Leu\n26. SEK - Swedish Krona\n27. SGD - Singapore Dollar\n28. THB - Thai Baht\n29. TRY - Turkish Lira\n30. USD - United States Dollar\n31. ZAR - South African Rand\n\nIf you want to convert USD to any of these currencies, you can select the one you are interested in."
```
### Use Arch Gateway as LLM Router
#### Step 1. Create arch config file
Arch operates based on a configuration file where you can define LLM providers, prompt targets, guardrails, etc. Below is an example configuration that defines openai and mistral LLM providers.
Create `arch_config.yaml` file with following content:
```yaml
version: v0.1
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
llm_providers:
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-4o
default: true
- name: ministral-3b
access_key: $MISTRAL_API_KEY
provider: mistral
model: ministral-3b-latest
```
#### Step 2. Start arch gateway
Once the config file is created ensure that you have env vars setup for `MISTRAL_API_KEY` and `OPENAI_API_KEY` (or these are defined in `.env` file).
Start arch gateway,
```
$ archgw up arch_config.yaml
2024-12-05 11:24:51,288 - cli.main - INFO - Starting archgw cli version: 0.1.5
2024-12-05 11:24:51,825 - cli.utils - INFO - Schema validation successful!
2024-12-05 11:24:51,825 - cli.main - INFO - Starting arch model server and arch gateway
...
2024-12-05 11:25:16,131 - cli.core - INFO - Container is healthy!
```
### Step 3: Interact with LLM
#### Step 3.1: Using OpenAI python client
Make outbound calls via Arch gateway
```python
from openai import OpenAI
@ -143,12 +233,59 @@ print("OpenAI Response:", response.choices[0].message.content)
```
### [Observability](https://docs.archgw.com/guides/observability/observability.html)
#### Step 3.2: Using curl command
```
$ curl --header 'Content-Type: application/json' \
--data '{"messages": [{"role": "user","content": "What is the capital of France?"}]}' \
http://localhost:12000/v1/chat/completions
{
...
"model": "gpt-4o-2024-08-06",
"choices": [
{
...
"message": {
"role": "assistant",
"content": "The capital of France is Paris.",
},
}
],
...
}
```
You can override model selection using `x-arch-llm-provider-hint` header. For example if you want to use mistral using following curl command,
```
$ curl --header 'Content-Type: application/json' \
--header 'x-arch-llm-provider-hint: ministral-3b' \
--data '{"messages": [{"role": "user","content": "What is the capital of France?"}]}' \
http://localhost:12000/v1/chat/completions
{
...
"model": "ministral-3b-latest",
"choices": [
{
"message": {
"role": "assistant",
"content": "The capital of France is Paris. It is the most populous city in France and is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. Paris is also a major global center for art, fashion, gastronomy, and culture.",
},
...
}
],
...
}
```
## [Observability](https://docs.archgw.com/guides/observability/observability.html)
Arch is designed to support best-in class observability by supporting open standards. Please read our [docs](https://docs.archgw.com/guides/observability/observability.html) on observability for more details on tracing, metrics, and logs. The screenshot below is from our integration with Signoz (among others)
![alt text](docs/source/_static/img/tracing.png)
### Contribution
## Contribution
We would love feedback on our [Roadmap](https://github.com/orgs/katanemo/projects/1) and we welcome contributions to **Arch**!
Whether you're fixing bugs, adding new features, improving documentation, or creating tutorials, your help is much appreciated.
Please visit our [Contribution Guide](CONTRIBUTING.md) for more details

View file

@ -28,6 +28,11 @@ properties:
type: string
connect_timeout:
type: string
protocol:
type: string
enum:
- http
- https
additionalProperties: false
required:
- endpoint
@ -92,6 +97,8 @@ properties:
type: array
items:
type: string
in_path:
type: boolean
additionalProperties: false
required:
- name
@ -104,6 +111,11 @@ properties:
type: string
path:
type: string
http_method:
type: string
enum:
- GET
- POST
additionalProperties: false
required:
- name

View file

@ -500,7 +500,18 @@ static_resources:
socket_address:
address: {{ cluster.endpoint }}
port_value: {{ cluster.port }}
hostname: {{ cluster.name }}
hostname: {{ cluster.endpoint }}
{% if cluster.protocol == "https" %}
transport_socket:
name: envoy.transport_sockets.tls
typed_config:
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
sni: {{ cluster.endpoint }}
common_tls_context:
tls_params:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
{% endif %}
{% endfor %}
- name: arch_internal
connect_timeout: 5s

View file

@ -19,7 +19,7 @@ source venv/bin/activate
### Step 3: Run the build script
```bash
pip install archgw
pip install archgw==0.1.6
```
## Uninstall Instructions: archgw CLI
@ -48,7 +48,7 @@ source venv/bin/activate
### Step 3: Run the build script
```bash
sh build_cli.sh
poetry install
```
### Step 4: build Arch

View file

@ -49,7 +49,9 @@ def validate_and_render_schema():
if "prompt_targets" in config_yaml:
for prompt_target in config_yaml["prompt_targets"]:
name = prompt_target.get("endpoint", {}).get("name", "")
name = prompt_target.get("endpoint", {}).get("name", None)
if not name:
continue
if name not in inferred_clusters:
inferred_clusters[name] = {
"name": name,

View file

@ -196,7 +196,7 @@ def start_arch_modelserver():
subprocess.run(
["archgw_modelserver", "restart"], check=True, start_new_session=True
)
log.info("Successfull ran model_server")
log.info("Successfully ran model_server")
except subprocess.CalledProcessError as e:
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1)
@ -212,7 +212,7 @@ def stop_arch_modelserver():
["archgw_modelserver", "stop"],
check=True,
)
log.info("Successfull stopped the archgw model_server")
log.info("Successfully stopped the archgw model_server")
except subprocess.CalledProcessError as e:
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1)

View file

@ -171,7 +171,7 @@ def up(file, path, service):
log.info(f"Error: {str(e)}")
sys.exit(1)
log.info("Starging arch model server and arch gateway")
log.info("Starting arch model server and arch gateway")
# Set the ARCH_CONFIG_FILE environment variable
env_stage = {}

3882
arch/tools/poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw"
version = "0.1.5"
version = "0.1.6"
description = "Python-based CLI tool to manage Arch Gateway."
authors = ["Katanemo Labs, Inc."]
packages = [
@ -10,7 +10,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.12"
archgw_modelserver = "0.1.5"
archgw_modelserver = "0.1.6"
pyyaml = "^6.0.2"
pydantic = "^2.10.1"
click = "^8.1.7"

View file

@ -1,7 +1,23 @@
use common::{
common_types::open_ai::Message,
use std::collections::HashMap;
use crate::{
api::open_ai::Message,
consts::{ARCH_MODEL_PREFIX, HALLUCINATION_TEMPLATE, USER_ROLE},
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationRequest {
pub prompt: String,
pub parameters: HashMap<String, String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationResponse {
pub params_scores: HashMap<String, f64>,
pub model: String,
}
pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
let mut arch_assistant = false;
@ -42,7 +58,7 @@ pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
#[cfg(test)]
mod test {
use common::common_types::open_ai::Message;
use crate::api::open_ai::Message;
use pretty_assertions::assert_eq;
use super::extract_messages_for_hallucination;

View file

@ -0,0 +1,4 @@
pub mod hallucination;
pub mod open_ai;
pub mod prompt_guard;
pub mod zero_shot;

View file

@ -0,0 +1,653 @@
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsRequest {
#[serde(default)]
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ChatCompletionTool>>,
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolType {
#[serde(rename = "function")]
Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionTool {
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionDefinition,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameters {
pub properties: HashMap<String, FunctionParameter>,
}
impl Serialize for FunctionParameters {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// select all requried parameters
let required: Vec<&String> = self
.properties
.iter()
.filter(|(_, v)| v.required.unwrap_or(false))
.map(|(k, _)| k)
.collect();
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("properties", &self.properties)?;
if !required.is_empty() {
map.serialize_entry("required", &required)?;
}
map.end()
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameter {
#[serde(rename = "type")]
#[serde(default = "ParameterType::string")]
pub parameter_type: ParameterType,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
}
impl Serialize for FunctionParameter {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(5))?;
map.serialize_entry("type", &self.parameter_type)?;
map.serialize_entry("description", &self.description)?;
if let Some(enum_values) = &self.enum_values {
map.serialize_entry("enum", enum_values)?;
}
if let Some(default) = &self.default {
map.serialize_entry("default", default)?;
}
map.end()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ParameterType {
#[serde(rename = "int")]
Int,
#[serde(rename = "float")]
Float,
#[serde(rename = "bool")]
Bool,
#[serde(rename = "str")]
String,
#[serde(rename = "list")]
List,
#[serde(rename = "dict")]
Dict,
}
impl From<String> for ParameterType {
fn from(s: String) -> Self {
match s.as_str() {
"int" => ParameterType::Int,
"integer" => ParameterType::Int,
"float" => ParameterType::Float,
"bool" => ParameterType::Bool,
"boolean" => ParameterType::Bool,
"str" => ParameterType::String,
"string" => ParameterType::String,
"list" => ParameterType::List,
"array" => ParameterType::List,
"dict" => ParameterType::Dict,
"dictionary" => ParameterType::Dict,
_ => ParameterType::String,
}
}
}
impl ParameterType {
pub fn string() -> ParameterType {
ParameterType::String
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub finish_reason: String,
pub index: usize,
pub message: Message,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionCallDetail,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
pub name: String,
pub arguments: HashMap<String, Value>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolCallState {
pub key: String,
pub message: Option<Message>,
pub tool_call: FunctionCallDetail,
pub tool_response: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ArchState {
ToolCall(Vec<ToolCallState>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
pub usage: Option<Usage>,
pub choices: Vec<Choice>,
pub model: String,
pub metadata: Option<HashMap<String, String>>,
}
impl ChatCompletionsResponse {
pub fn new(message: String) -> Self {
ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
role: ASSISTANT_ROLE.to_string(),
content: Some(message),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
},
index: 0,
finish_reason: "done".to_string(),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
metadata: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub completion_tokens: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
pub choices: Vec<ChunkChoice>,
}
impl ChatCompletionStreamResponse {
pub fn new(
response: Option<String>,
role: Option<String>,
model: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
) -> Self {
ChatCompletionStreamResponse {
model,
choices: vec![ChunkChoice {
delta: Delta {
role,
content: response,
tool_calls,
model: None,
tool_call_id: None,
},
finish_reason: None,
}],
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ChatCompletionChunkResponseError {
#[error("failed to deserialize")]
Deserialization(#[from] serde_json::Error),
#[error("empty content in data chunk")]
EmptyContent,
#[error("no chunks present")]
NoChunks,
}
pub struct ChatCompletionStreamResponseServerEvents {
pub events: Vec<ChatCompletionStreamResponse>,
}
impl Display for ChatCompletionStreamResponseServerEvents {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let tokens_str = self
.events
.iter()
.map(|response_chunk| {
if response_chunk.choices.is_empty() {
return "".to_string();
}
response_chunk.choices[0]
.delta
.content
.clone()
.unwrap_or("".to_string())
})
.collect::<Vec<String>>()
.join("");
write!(f, "{}", tokens_str)
}
}
impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
type Error = ChatCompletionChunkResponseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line.get(6..).unwrap())
.filter(|data_chunk| *data_chunk != "[DONE]")
.map(serde_json::from_str::<ChatCompletionStreamResponse>)
.collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
Ok(ChatCompletionStreamResponseServerEvents {
events: response_chunks.into(),
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
pub delta: Delta,
// TODO: could this be an enum?
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
}
#[cfg(test)]
mod test {
use super::{ChatCompletionStreamResponseServerEvents, Message};
use pretty_assertions::assert_eq;
use std::collections::HashMap;
const TOOL_SERIALIZED: &str = r#"{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "What city do you want to know the weather for?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "weather_forecast",
"description": "function to retrieve weather forecast",
"parameters": {
"properties": {
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
},
"required": [
"city"
]
}
}
}
],
"stream": true,
"stream_options": {
"include_usage": true
}
}"#;
#[test]
fn test_tool_type_request() {
use super::{
ChatCompletionTool, ChatCompletionsRequest, FunctionDefinition, FunctionParameter,
FunctionParameters, ParameterType, StreamOptions, ToolType,
};
let mut properties = HashMap::new();
properties.insert(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
);
let function_definition = FunctionDefinition {
name: "weather_forecast".to_string(),
description: "function to retrieve weather forecast".to_string(),
parameters: FunctionParameters { properties },
};
let chat_completions_request = ChatCompletionsRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: Some("What city do you want to know the weather for?".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: Some(vec![ChatCompletionTool {
tool_type: ToolType::Function,
function: function_definition,
}]),
stream: true,
stream_options: Some(StreamOptions {
include_usage: true,
}),
metadata: None,
};
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
println!("{}", serialized);
assert_eq!(TOOL_SERIALIZED, serialized);
}
#[test]
fn test_parameter_types() {
use super::{FunctionParameter, ParameterType};
const PARAMETER_SERIALZIED: &str = r#"{
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
}"#;
let properties = HashMap::from([(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
)]);
let serialized = serde_json::to_string_pretty(&properties).unwrap();
assert_eq!(PARAMETER_SERIALZIED, serialized);
// ensure that if type is missing it is set to string
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
{
"city": {
"description": "city for weather forecast"
}
}"#;
let missing_type_deserialized: HashMap<String, FunctionParameter> =
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
println!("{:?}", missing_type_deserialized);
assert_eq!(
missing_type_deserialized
.get("city")
.unwrap()
.parameter_type,
ParameterType::String
);
}
#[test]
fn stream_chunk_parse() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]}
"#;
let sever_events =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 5);
assert_eq!(
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
""
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"Hello"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"!"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" How"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" can"
);
assert_eq!(sever_events.to_string(), "Hello! How can");
}
#[test]
fn stream_chunk_parse_done() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
data: [DONE]
"#;
let sever_events: ChatCompletionStreamResponseServerEvents =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 6);
assert_eq!(
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" I"
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" assist"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" you"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" today"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"?"
);
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
assert_eq!(sever_events.to_string(), " I assist you today?");
}
#[test]
fn stream_chunk_parse_mistral() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
data: [DONE]
"#;
let sever_events: ChatCompletionStreamResponseServerEvents =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 11);
assert_eq!(
sever_events.to_string(),
"Hello! How can I assist you today?"
);
}
}

View file

@ -0,0 +1,25 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PromptGuardTask {
#[serde(rename = "jailbreak")]
Jailbreak,
#[serde(rename = "toxicity")]
Toxicity,
#[serde(rename = "both")]
Both,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardRequest {
pub input: String,
pub task: PromptGuardTask,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardResponse {
pub toxic_prob: Option<f64>,
pub jailbreak_prob: Option<f64>,
pub toxic_verdict: Option<bool>,
pub jailbreak_verdict: Option<bool>,
}

View file

@ -0,0 +1,18 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationRequest {
pub input: String,
pub labels: Vec<String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationResponse {
pub predicted_class: String,
pub predicted_class_score: f64,
pub scores: HashMap<String, f64>,
pub model: String,
}

View file

@ -1,743 +0,0 @@
use crate::configuration::PromptTarget;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub prompt_target: PromptTarget,
}
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum EmbeddingType {
Name,
Description,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
pub id: String,
pub payload: HashMap<String, String>,
pub vector: Vec<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoreVectorEmbeddingsRequest {
pub points: Vec<VectorPoint>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointResult {
pub id: String,
pub version: i32,
pub score: f64,
pub payload: HashMap<String, String>,
}
pub mod open_ai {
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsRequest {
#[serde(default)]
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ChatCompletionTool>>,
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolType {
#[serde(rename = "function")]
Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionTool {
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionDefinition,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameters {
pub properties: HashMap<String, FunctionParameter>,
}
impl Serialize for FunctionParameters {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// select all requried parameters
let required: Vec<&String> = self
.properties
.iter()
.filter(|(_, v)| v.required.unwrap_or(false))
.map(|(k, _)| k)
.collect();
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("properties", &self.properties)?;
if !required.is_empty() {
map.serialize_entry("required", &required)?;
}
map.end()
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameter {
#[serde(rename = "type")]
#[serde(default = "ParameterType::string")]
pub parameter_type: ParameterType,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
}
impl Serialize for FunctionParameter {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(5))?;
map.serialize_entry("type", &self.parameter_type)?;
map.serialize_entry("description", &self.description)?;
if let Some(enum_values) = &self.enum_values {
map.serialize_entry("enum", enum_values)?;
}
if let Some(default) = &self.default {
map.serialize_entry("default", default)?;
}
map.end()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ParameterType {
#[serde(rename = "int")]
Int,
#[serde(rename = "float")]
Float,
#[serde(rename = "bool")]
Bool,
#[serde(rename = "str")]
String,
#[serde(rename = "list")]
List,
#[serde(rename = "dict")]
Dict,
}
impl From<String> for ParameterType {
fn from(s: String) -> Self {
match s.as_str() {
"int" => ParameterType::Int,
"integer" => ParameterType::Int,
"float" => ParameterType::Float,
"bool" => ParameterType::Bool,
"boolean" => ParameterType::Bool,
"str" => ParameterType::String,
"string" => ParameterType::String,
"list" => ParameterType::List,
"array" => ParameterType::List,
"dict" => ParameterType::Dict,
"dictionary" => ParameterType::Dict,
_ => ParameterType::String,
}
}
}
impl ParameterType {
pub fn string() -> ParameterType {
ParameterType::String
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub finish_reason: String,
pub index: usize,
pub message: Message,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionCallDetail,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
pub name: String,
pub arguments: HashMap<String, Value>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolCallState {
pub key: String,
pub message: Option<Message>,
pub tool_call: FunctionCallDetail,
pub tool_response: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ArchState {
ToolCall(Vec<ToolCallState>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
pub usage: Option<Usage>,
pub choices: Vec<Choice>,
pub model: String,
pub metadata: Option<HashMap<String, String>>,
}
impl ChatCompletionsResponse {
pub fn new(message: String) -> Self {
ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
role: ASSISTANT_ROLE.to_string(),
content: Some(message),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
},
index: 0,
finish_reason: "done".to_string(),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
metadata: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub completion_tokens: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
pub choices: Vec<ChunkChoice>,
}
impl ChatCompletionStreamResponse {
pub fn new(
response: Option<String>,
role: Option<String>,
model: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
) -> Self {
ChatCompletionStreamResponse {
model,
choices: vec![ChunkChoice {
delta: Delta {
role,
content: response,
tool_calls,
model: None,
tool_call_id: None,
},
finish_reason: None,
}],
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ChatCompletionChunkResponseError {
#[error("failed to deserialize")]
Deserialization(#[from] serde_json::Error),
#[error("empty content in data chunk")]
EmptyContent,
#[error("no chunks present")]
NoChunks,
}
pub struct ChatCompletionStreamResponseServerEvents {
pub events: Vec<ChatCompletionStreamResponse>,
}
impl Display for ChatCompletionStreamResponseServerEvents {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let tokens_str = self
.events
.iter()
.map(|response_chunk| {
if response_chunk.choices.is_empty() {
return "".to_string();
}
response_chunk.choices[0]
.delta
.content
.clone()
.unwrap_or("".to_string())
})
.collect::<Vec<String>>()
.join("");
write!(f, "{}", tokens_str)
}
}
impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
type Error = ChatCompletionChunkResponseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line.get(6..).unwrap())
.filter(|data_chunk| *data_chunk != "[DONE]")
.map(serde_json::from_str::<ChatCompletionStreamResponse>)
.collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
Ok(ChatCompletionStreamResponseServerEvents {
events: response_chunks.into(),
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
pub delta: Delta,
// TODO: could this be an enum?
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationRequest {
pub input: String,
pub labels: Vec<String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationResponse {
pub predicted_class: String,
pub predicted_class_score: f64,
pub scores: HashMap<String, f64>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationRequest {
pub prompt: String,
pub parameters: HashMap<String, String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationResponse {
pub params_scores: HashMap<String, f64>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PromptGuardTask {
#[serde(rename = "jailbreak")]
Jailbreak,
#[serde(rename = "toxicity")]
Toxicity,
#[serde(rename = "both")]
Both,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardRequest {
pub input: String,
pub task: PromptGuardTask,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardResponse {
pub toxic_prob: Option<f64>,
pub jailbreak_prob: Option<f64>,
pub toxic_verdict: Option<bool>,
pub jailbreak_verdict: Option<bool>,
}
#[cfg(test)]
mod test {
use crate::common_types::open_ai::{ChatCompletionStreamResponseServerEvents, Message};
use pretty_assertions::assert_eq;
use std::collections::HashMap;
const TOOL_SERIALIZED: &str = r#"{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "What city do you want to know the weather for?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "weather_forecast",
"description": "function to retrieve weather forecast",
"parameters": {
"properties": {
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
},
"required": [
"city"
]
}
}
}
],
"stream": true,
"stream_options": {
"include_usage": true
}
}"#;
#[test]
fn test_tool_type_request() {
use super::open_ai::{
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
};
let mut properties = HashMap::new();
properties.insert(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
);
let function_definition = FunctionDefinition {
name: "weather_forecast".to_string(),
description: "function to retrieve weather forecast".to_string(),
parameters: super::open_ai::FunctionParameters { properties },
};
let chat_completions_request = ChatCompletionsRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: Some("What city do you want to know the weather for?".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: Some(vec![super::open_ai::ChatCompletionTool {
tool_type: ToolType::Function,
function: function_definition,
}]),
stream: true,
stream_options: Some(super::open_ai::StreamOptions {
include_usage: true,
}),
metadata: None,
};
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
println!("{}", serialized);
assert_eq!(TOOL_SERIALIZED, serialized);
}
#[test]
fn test_parameter_types() {
use super::open_ai::{FunctionParameter, ParameterType};
const PARAMETER_SERIALZIED: &str = r#"{
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
}"#;
let properties = HashMap::from([(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
)]);
let serialized = serde_json::to_string_pretty(&properties).unwrap();
assert_eq!(PARAMETER_SERIALZIED, serialized);
// ensure that if type is missing it is set to string
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
{
"city": {
"description": "city for weather forecast"
}
}"#;
let missing_type_deserialized: HashMap<String, FunctionParameter> =
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
println!("{:?}", missing_type_deserialized);
assert_eq!(
missing_type_deserialized
.get("city")
.unwrap()
.parameter_type,
ParameterType::String
);
}
#[test]
fn stream_chunk_parse() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]}
"#;
let sever_events =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 5);
assert_eq!(
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
""
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"Hello"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"!"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" How"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" can"
);
assert_eq!(sever_events.to_string(), "Hello! How can");
}
#[test]
fn stream_chunk_parse_done() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
data: [DONE]
"#;
let sever_events: ChatCompletionStreamResponseServerEvents =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 6);
assert_eq!(
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" I"
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" assist"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" you"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" today"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"?"
);
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
assert_eq!(sever_events.to_string(), " I assist you today?");
}
#[test]
fn stream_chunk_parse_mistral() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
data: [DONE]
"#;
let sever_events: ChatCompletionStreamResponseServerEvents =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 11);
assert_eq!(
sever_events.to_string(),
"Hello! How can I assist you today?"
);
}
}

View file

@ -2,6 +2,22 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
pub listener: Listener,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
pub prompt_guards: Option<PromptGuards>,
pub prompt_targets: Option<Vec<PromptTarget>>,
pub error_target: Option<ErrorTargetDetail>,
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Overrides {
pub prompt_target_intent_matching_threshold: Option<f64>,
@ -22,22 +38,6 @@ pub enum GatewayMode {
Prompt,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
pub listener: Listener,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
pub prompt_guards: Option<PromptGuards>,
pub prompt_targets: Option<Vec<PromptTarget>>,
pub error_target: Option<ErrorTargetDetail>,
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorTargetDetail {
pub endpoint: Option<EndpointDetails>,
@ -179,8 +179,6 @@ impl Display for LlmProvider {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub endpoint: Option<String>,
// pub connect_timeout: Option<DurationString>,
// pub timeout: Option<DurationString>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -193,12 +191,33 @@ pub struct Parameter {
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
pub default: Option<String>,
pub in_path: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum HttpMethod {
#[default]
#[serde(rename = "GET")]
Get,
#[serde(rename = "POST")]
Post,
}
impl Display for HttpMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HttpMethod::Get => write!(f, "GET"),
HttpMethod::Post => write!(f, "POST"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointDetails {
pub name: String,
pub path: Option<String>,
#[serde(rename = "http_method")]
pub method: Option<HttpMethod>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -1,6 +1,6 @@
use proxy_wasm::types::Status;
use crate::{common_types::open_ai::ChatCompletionChunkResponseError, ratelimit};
use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit};
#[derive(thiserror::Error, Debug)]
pub enum ClientError {

View file

@ -1,4 +1,4 @@
pub mod common_types;
pub mod api;
pub mod configuration;
pub mod consts;
pub mod embeddings;
@ -11,3 +11,4 @@ pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
pub mod path;

82
crates/common/src/path.rs Normal file
View file

@ -0,0 +1,82 @@
use std::collections::HashMap;
pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
let mut result = String::new();
let mut in_param = false;
let mut current_param = String::new();
for c in path.chars() {
if c == '{' {
in_param = true;
} else if c == '}' {
in_param = false;
let param_name = current_param.clone();
if let Some(value) = params.get(&param_name) {
result.push_str(value);
} else {
return Err(format!("Missing value for parameter `{}`", param_name));
}
current_param.clear();
} else {
if in_param {
current_param.push(c);
} else {
result.push(c);
}
}
}
Ok(result)
}
#[cfg(test)]
mod test {
#[test]
fn test_replace_path() {
let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}";
let params = vec![("cluster_name".to_string(), "test1".to_string())]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string())
);
let path = "/cluster.open-cluster-management.io/v1/managedclusters";
let params = vec![].into_iter().collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string())
);
let path = "/foo/{bar}/baz";
let params = vec![("bar".to_string(), "qux".to_string())]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/foo/qux/baz".to_string())
);
let path = "/foo/{bar}/baz/{qux}";
let params = vec![
("bar".to_string(), "qux".to_string()),
("qux".to_string(), "quux".to_string()),
]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Ok("/foo/qux/baz/quux".to_string())
);
let path = "/foo/{bar}/baz/{qux}";
let params = vec![("bar".to_string(), "qux".to_string())]
.into_iter()
.collect();
assert_eq!(
super::replace_params_in_path(path, &params),
Err("Missing value for parameter `qux`".to_string())
);
}
}

View file

@ -1,5 +1,5 @@
use crate::metrics::Metrics;
use common::common_types::open_ai::{
use common::api::open_ai::{
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
Message, StreamOptions,
};

View file

@ -0,0 +1,5 @@
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum EmbeddingType {
Name,
Description,
}

View file

@ -1,6 +1,6 @@
use crate::embeddings::EmbeddingType;
use crate::metrics::Metrics;
use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType;
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL;

View file

@ -1,10 +1,8 @@
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
use common::{
common_types::{
open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
},
PromptGuardRequest, PromptGuardTask,
api::{
open_ai::{self, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest},
prompt_guard::{PromptGuardRequest, PromptGuardTask},
},
consts::{
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
@ -324,7 +322,7 @@ impl HttpContext for StreamContext {
),
];
let mut response_str = to_server_events(chunks);
let mut response_str = open_ai::to_server_events(chunks);
// append the original response from the model to the stream
response_str.push_str(&body_utf8);
self.set_http_response_body(0, body_size, response_str.as_bytes());

View file

@ -3,8 +3,8 @@ use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod context;
mod embeddings;
mod filter_context;
mod hallucination;
mod http_context;
mod metrics;
mod stream_context;

View file

@ -1,16 +1,18 @@
use crate::embeddings::EmbeddingType;
use crate::filter_context::EmbeddingsStore;
use crate::hallucination::extract_messages_for_hallucination;
use crate::metrics::Metrics;
use acap::cos;
use common::common_types::open_ai::{
use common::api::hallucination::{
extract_messages_for_hallucination, HallucinationClassificationRequest,
HallucinationClassificationResponse,
};
use common::api::open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
PromptGuardResponse, ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use common::api::prompt_guard::PromptGuardResponse;
use common::api::zero_shot::{ZeroShotClassificationRequest, ZeroShotClassificationResponse};
use common::configuration::{Overrides, PromptGuards, PromptTarget, Tracing};
use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
@ -30,6 +32,7 @@ use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, trace, warn};
use proxy_wasm::traits::*;
use serde_yaml::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
@ -893,9 +896,37 @@ impl StreamContext {
let endpoint = prompt_target.endpoint.unwrap();
let path: String = endpoint.path.unwrap_or(String::from("/"));
// only add params that are of string, number and bool type
let url_params = tool_params
.iter()
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
.map(|(key, value)| match value {
Value::Number(n) => (key.clone(), n.to_string()),
Value::String(s) => (key.clone(), s.clone()),
Value::Bool(b) => (key.clone(), b.to_string()),
Value::Null => todo!(),
Value::Sequence(_) => todo!(),
Value::Mapping(_) => todo!(),
Value::Tagged(_) => todo!(),
})
.collect::<HashMap<String, String>>();
let path = match common::path::replace_params_in_path(&path, &url_params) {
Ok(path) => path,
Err(e) => {
return self.send_server_error(
ServerError::BadRequest {
why: format!("error replacing params in path: {}", e),
},
Some(StatusCode::BAD_REQUEST),
);
}
};
let http_method = endpoint.method.unwrap_or_default().to_string();
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()),
(":method", "POST"),
(":method", &http_method),
(":path", &path),
(":authority", endpoint.name.as_str()),
("content-type", "application/json"),

View file

@ -1,11 +1,14 @@
use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
use common::api::hallucination::HallucinationClassificationResponse;
use common::api::open_ai::{
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
};
use common::api::prompt_guard::PromptGuardResponse;
use common::api::zero_shot::ZeroShotClassificationResponse;
use common::configuration::Configuration;
use common::embeddings::{
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
};
use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
@ -360,6 +363,7 @@ prompt_targets:
endpoint:
name: api_server
path: /weather
http_method: POST
system_prompt: |
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
- Use farenheight for temperature

View file

@ -0,0 +1 @@
This demo shows how you can use a publicly hosted rest api and interact it using arch gateway.

View file

@ -0,0 +1,52 @@
version: v0.1
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
llm_providers:
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-4o
system_prompt: |
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
prompt_targets:
- name: currency_exchange
description: Get currency exchange rate from USD to other currencies
parameters:
- name: currency_symbol
description: the currency that needs conversion
required: true
type: str
in_path: true
endpoint:
name: frankfurther_api
path: /v1/latest?base=USD&symbols={currency_symbol}
system_prompt: |
You are a helpful assistant. Show me the currency symbol you want to convert from USD.
- name: get_supported_currencies
description: Get list of supported currencies for conversion
endpoint:
name: frankfurther_api
path: /v1/currencies
endpoints:
frankfurther_api:
endpoint: api.frankfurter.dev:443
protocol: https
tracing:
random_sampling: 100
trace_arch_internal: true

View file

@ -0,0 +1,21 @@
services:
chatbot_ui:
build:
context: ../shared/chatbot_ui
ports:
- "18080:8080"
environment:
# this is only because we are running the sample app in the same docker container environemtn as archgw
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml
jaeger:
build:
context: ../shared/jaeger
ports:
- "16686:16686"
- "4317:4317"
- "4318:4318"

View file

@ -34,6 +34,7 @@ prompt_targets:
endpoint:
name: app_server
path: /agent/workforce
http_method: POST
parameters:
- name: staffing_type
type: str
@ -52,6 +53,7 @@ prompt_targets:
endpoint:
name: app_server
path: /agent/slack_message
http_method: POST
description: sends a slack message on a channel
parameters:
- name: slack_message

View file

@ -29,6 +29,7 @@ prompt_targets:
endpoint:
name: app_server
path: /policy/qa
http_method: POST
description: Handle general Q/A related to insurance.
default: true
@ -37,6 +38,7 @@ prompt_targets:
endpoint:
name: app_server
path: /policy/coverage
http_method: POST
parameters:
- name: policy_type
type: str
@ -48,6 +50,7 @@ prompt_targets:
endpoint:
name: app_server
path: /policy/initiate
http_method: POST
description: Start a policy coverage for car, boat, motorcycle or house.
parameters:
- name: policy_type
@ -64,6 +67,7 @@ prompt_targets:
endpoint:
name: app_server
path: /policy/claim
http_method: POST
description: Update the notes on the claim
parameters:
- name: claim_id
@ -79,6 +83,7 @@ prompt_targets:
endpoint:
name: app_server
path: /policy/deductible
http_method: POST
description: Update the deductible amount for a specific policy coverage.
parameters:
- name: policy_id

View file

@ -22,6 +22,7 @@ prompt_targets:
endpoint:
name: app_server
path: /agent/device_summary
http_method: POST
parameters:
- name: device_ids
type: list
@ -36,6 +37,7 @@ prompt_targets:
endpoint:
name: app_server
path: /agent/device_reboot
http_method: POST
parameters:
- name: device_ids
type: list

View file

@ -0,0 +1,5 @@
FROM otel/opentelemetry-collector:latest
COPY otel-collector-config.yaml /etc/otel-collector-config.yaml
ENTRYPOINT ["/otelcol", "--config=/etc/otel-collector-config.yaml"]

View file

@ -0,0 +1,24 @@
receivers:
otlp:
protocols:
grpc:
endpoint: 0.0.0.0:4317
http:
endpoint: 0.0.0.0:4318
exporters:
otlphttp:
endpoint: "https://logfire-api.pydantic.dev"
headers:
Authorization: "${LOGFIRE_API_KEY}"
processors:
batch:
timeout: 5s
service:
pipelines:
traces:
receivers: [otlp]
processors: [batch]
exporters: [otlphttp]

View file

@ -1,27 +1,63 @@
# Function calling
This demo shows how you can use Arch's core function calling capabilites.
# Starting the demo
1. Please make sure the [pre-requisites](https://github.com/katanemo/arch/?tab=readme-ov-file#prerequisites) are installed correctly
2. Start Arch
3.
```sh
3. ```sh
sh run_demo.sh
```
4. Navigate to http://localhost:18080/
5. You can type in queries like "how is the weather?"
# Observability
Arch gateway publishes stats endpoint at http://localhost:19901/stats. In this demo we are using prometheus to pull stats from arch and we are using grafana to visalize the stats in dashboard. To see grafana dashboard follow instructions below,
1. Start grafana and prometheus using following command
```yaml
docker compose --profile monitoring up
```
1. Navigate to http://localhost:3000/ to open grafana UI (use admin/grafana as credentials)
1. From grafana left nav click on dashboards and select "Intelligent Gateway Overview" to view arch gateway stats
2. Navigate to http://localhost:3000/ to open grafana UI (use admin/grafana as credentials)
3. From grafana left nav click on dashboards and select "Intelligent Gateway Overview" to view arch gateway stats
Here is a sample interaction,
<img width="575" alt="image" src="https://github.com/user-attachments/assets/e0929490-3eb2-4130-ae87-a732aea4d059">
## Tracing
To see a tracing dashboard follow instructions below,
1. For Jaeger, you can either use the default run_demo.sh script or run the following command:
```sh
sh run_demo.sh jaeger
```
2. For Logfire, first make sure to add a LOGFIRE_API_KEY to the .env file. You can either use the default run_demo.sh script or run the following command:
```sh
sh run_demo.sh logfire
```
3. For Signoz, you can either use the default run_demo.sh script or run the following command:
```sh
sh run_demo.sh signoz
```
If using Jaeger, navigate to http://localhost:16686/ to open Jaeger UI
If using Signoz, navigate to http://localhost:3301/ to open Signoz UI
If using Logfire, navigate to your logfire dashboard that you got the write key from to view the dashboard
### Stopping Demo
1. To end the demo, run the following command:
```sh
sh run_demo.sh down
```

View file

@ -60,6 +60,7 @@ prompt_targets:
endpoint:
name: weather_forecast_service
path: /weather
http_method: POST
- name: default_target
default: true
@ -67,6 +68,7 @@ prompt_targets:
endpoint:
name: weather_forecast_service
path: /default_target
http_method: POST
system_prompt: |
You are a helpful assistant! Summarize the user's request and provide a helpful response.
# if it is set to false arch will send response that it received from this prompt target to the user

View file

@ -0,0 +1,41 @@
services:
weather_forecast_service:
build:
context: ./
environment:
- OLTP_HOST=http://jaeger:4317
extra_hosts:
- "host.docker.internal:host-gateway"
ports:
- "18083:80"
chatbot_ui:
build:
context: ../shared/chatbot_ui
ports:
- "18080:8080"
environment:
# this is only because we are running the sample app in the same docker container environemtn as archgw
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml
jaeger:
build:
context: ../shared/jaeger
ports:
- "16686:16686"
- "4317:4317"
- "4318:4318"
prometheus:
build:
context: ../shared/prometheus
grafana:
build:
context: ../shared/grafana
ports:
- "3000:3000"

View file

@ -0,0 +1,46 @@
services:
weather_forecast_service:
build:
context: ./
environment:
- OLTP_HOST=http://otel-collector:4317
extra_hosts:
- "host.docker.internal:host-gateway"
ports:
- "18083:80"
chatbot_ui:
build:
context: ../shared/chatbot_ui
ports:
- "18080:8080"
environment:
# this is only because we are running the sample app in the same docker container environment as archgw
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- ./arch_config.yaml:/app/arch_config.yaml
otel-collector:
build:
context: ../shared/logfire/
ports:
- "4317:4317"
- "4318:4318"
volumes:
- ../shared/logfire/otel-collector-config.yaml:/etc/otel-collector-config.yaml
env_file:
- .env
environment:
- LOGFIRE_API_KEY
prometheus:
build:
context: ../shared/prometheus
grafana:
build:
context: ../shared/grafana
ports:
- "3000:3000"

View file

@ -4,7 +4,7 @@ include:
services:
weather_forecast_service:
build:
context: ../shared/weather_forecast_service
context: .
environment:
- OLTP_HOST=http://otel-collector:4317
extra_hosts:

View file

@ -1,7 +1,7 @@
services:
weather_forecast_service:
build:
context: ../shared/weather_forecast_service
context: ./
environment:
- OLTP_HOST=http://jaeger:4317
extra_hosts:

View file

@ -1,47 +1,95 @@
#!/bin/bash
set -e
# Function to load environment variables from the .env file
load_env() {
if [ -f ".env" ]; then
export $(grep -v '^#' .env | xargs)
fi
}
# Function to determine the docker-compose file based on the argument
get_compose_file() {
case "$1" in
jaeger)
echo "docker-compose-jaeger.yaml"
;;
logfire)
echo "docker-compose-logfire.yaml"
;;
signoz)
echo "docker-compose-signoz.yaml"
;;
*)
echo "docker-compose.yaml"
;;
esac
}
# Function to start the demo
start_demo() {
# Step 1: Check if .env file exists
# Step 1: Determine the docker-compose file
COMPOSE_FILE=$(get_compose_file "$1" 2>/dev/null)
# Step 2: Check if .env file exists
if [ -f ".env" ]; then
echo ".env file already exists. Skipping creation."
else
# Step 2: Create `.env` file and set OpenAI key
# Step 3: Check for required environment variables
if [ -z "$OPENAI_API_KEY" ]; then
echo "Error: OPENAI_API_KEY environment variable is not set for the demo."
exit 1
fi
if [ "$1" == "logfire" ] && [ -z "$LOGFIRE_API_KEY" ]; then
echo "Error: LOGFIRE_API_KEY environment variable is required for Logfire."
exit 1
fi
# Create .env file
echo "Creating .env file..."
echo "OPENAI_API_KEY=$OPENAI_API_KEY" > .env
echo ".env file created with OPENAI_API_KEY."
if [ "$1" == "logfire" ]; then
echo "LOGFIRE_API_KEY=$LOGFIRE_API_KEY" >> .env
fi
echo ".env file created with required API keys."
fi
# Step 3: Start Arch
load_env
if [ "$1" == "logfire" ] && [ -z "$LOGFIRE_API_KEY" ]; then
echo "Error: LOGFIRE_API_KEY environment variable is required for Logfire."
exit 1
fi
# Step 4: Start Arch
echo "Starting Arch with arch_config.yaml..."
archgw up arch_config.yaml
# Step 4: Start Network Agent
echo "Starting Network Agent using Docker Compose..."
docker compose up -d # Run in detached mode
# Step 5: Start Network Agent with the chosen Docker Compose file
echo "Starting Network Agent with $COMPOSE_FILE..."
docker compose -f "$COMPOSE_FILE" up -d # Run in detached mode
}
# Function to stop the demo
stop_demo() {
# Step 1: Stop Docker Compose services
echo "Stopping Network Agent using Docker Compose..."
docker compose down
echo "Stopping all Docker Compose services..."
# Step 2: Stop Arch
# Stop all services by iterating through all configurations
for compose_file in ./docker-compose*.yaml; do
echo "Stopping services in $compose_file..."
docker compose -f "$compose_file" down
done
# Stop Arch
echo "Stopping Arch..."
archgw down
}
# Main script logic
if [ "$1" == "down" ]; then
# Call stop_demo with the second argument as the demo to stop
stop_demo
else
# Default action is to bring the demo up
start_demo
# Use the argument (jaeger, logfire, signoz) to determine the compose file
start_demo "$1"
fi

View file

@ -1,29 +0,0 @@
# Function calling
This demo shows how you can use Arch's core function calling capabilites.
# Starting the demo
1. Please make sure the [pre-requisites](https://github.com/katanemo/arch/?tab=readme-ov-file#prerequisites) are installed correctly
2. Start Arch
3.
```sh
sh run_demo.sh
```
4. Navigate to http://localhost:18080/
5. You can type in queries like "how is the weather?"
# Observability
Arch gateway publishes stats endpoint at http://localhost:19901/stats. In this demo we are using prometheus to pull stats from arch and we are using grafana to visalize the stats in dashboard. To see grafana dashboard follow instructions below,
1. Start grafana and prometheus using following command
```yaml
docker compose --profile monitoring up
```
1. Navigate to http://localhost:3000/ to open grafana UI (use admin/grafana as credentials)
1. From grafana left nav click on dashboards and select "Intelligent Gateway Overview" to view arch gateway stats
Here is a sample interaction,
<img width="575" alt="image" src="https://github.com/user-attachments/assets/e0929490-3eb2-4130-ae87-a732aea4d059">
1. Signoz UI: http://localhost:3301

View file

@ -1,72 +0,0 @@
version: "0.1-beta"
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
endpoints:
weather_forecast_service:
endpoint: host.docker.internal:18083
connect_timeout: 0.005s
overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.6
llm_providers:
- name: gpt-4o-mini
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-4o-mini
default: true
- name: gpt-3.5-turbo-0125
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-3.5-turbo-0125
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-4o
system_prompt: |
You are a helpful assistant.
prompt_targets:
- name: weather_forecast
description: Check weather information for a given city.
parameters:
- name: city
description: the name of the city
required: true
type: str
- name: days
description: the number of days
type: int
required: true
- name: units
description: the temperature unit, e.g., Celsius and Fahrenheit
type: str
default: Fahrenheit
endpoint:
name: weather_forecast_service
path: /weather
- name: default_target
default: true
description: This is the default target for all unmatched prompts.
endpoint:
name: weather_forecast_service
path: /default_target
system_prompt: |
You are a helpful assistant! Summarize the user's request and provide a helpful response.
# if it is set to false arch will send response that it received from this prompt target to the user
# if true arch will forward the response to the default LLM
auto_llm_dispatch_on_response: false
tracing:
random_sampling: 100
# trace_arch: true

View file

@ -1,43 +0,0 @@
version: v0.1
listener:
address: 127.0.0.1
port: 8080 #If you configure port 443, you'll need to update the listener with tls_certificates
message_format: huggingface
# Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
llm_providers:
- name: OpenAI
provider: openai
access_key: $OPENAI_API_KEY
model: gpt-3.5-turbo
default: true
# default system prompt used by all prompt targets
system_prompt: |
You are a network assistant that helps operators with a better understanding of network traffic flow and perform actions on networking operations. No advice on manufacturers or purchasing decisions.
prompt_targets:
- name: device_summary
description: Retrieve network statistics for specific devices within a time range
endpoint:
name: app_server
path: /agent/device_summary
parameters:
- name: device_ids
type: list
description: A list of device identifiers (IDs) to retrieve statistics for.
required: true # device_ids are required to get device statistics
- name: days
type: int
description: The number of days for which to gather device statistics.
default: "7"
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
endpoints:
app_server:
# value could be ip address or a hostname with port
# this could also be a list of endpoints for load balancing
# for example endpoint: [ ip1:port, ip2:port ]
endpoint: host.docker.internal:18083
# max time to wait for a connection to be established
connect_timeout: 0.005s

View file

@ -7,74 +7,255 @@ Follow this guide to learn how to quickly set up Arch and integrate it into your
Prerequisites
----------------------------
-------------
Before you begin, ensure you have the following:
.. vale Vale.Spelling = NO
1. `Docker System <https://docs.docker.com/get-started/get-docker/>`_ (v24)
2. `Docker compose <https://docs.docker.com/compose/install/>`_ (v2.29)
3. `Python <https://www.python.org/downloads/>`_ (v3.12)
- ``Docker`` & ``Python`` installed on your system
- ``API Keys`` for LLM providers (if using external LLMs)
The fastest way to get started using Arch is to use `katanemo/archgw <https://hub.docker.com/r/katanemo/archgw>`_ pre-built binaries.
You can also build it from source.
Step 1: Install Arch
----------------------------
Arch's CLI allows you to manage and interact with the Arch gateway efficiently. To install the CLI, simply
run the following command:
.. code-block:: console
$ pip install archgw
This will install the archgw command-line tool globally on your system.
Arch's CLI allows you to manage and interact with the Arch gateway efficiently. To install the CLI, simply run the following command:
.. tip::
We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch.
This ensures that `archgw` and its dependencies do not interfere with other packages on your system.
To create and activate a virtual environment, you can run the following commands:
.. code-block:: console
$ python -m venv venv
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
$ pip install archgw
Step 2: Config Arch
-------------------
Arch operates based on a configuration file where you can define LLM providers, prompt targets, and guardrails, etc.
Below is an example configuration to get you started, including:
.. vale Vale.Spelling = NO
- ``endpoints``: Specifies where Arch listens for incoming prompts.
- ``system_prompts``: Defines predefined prompts to set the context for interactions.
- ``llm_providers``: Lists the LLM providers Arch can route prompts to.
- ``prompt_guards``: Sets up rules to detect and reject undesirable prompts.
- ``prompt_targets``: Defines endpoints that handle specific types of prompts.
- ``error_target``: Specifies where to route errors for handling.
.. literalinclude:: includes/quickstart.yaml
:language: yaml
Step 3: Start Arch Gateway
--------------------------
We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch. This ensures that ``archgw`` and its dependencies do not interfere with other packages on your system.
.. code-block:: console
$ archgw up [path_to_config]
$ python -m venv venv
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
$ pip install archgw==0.1.6
For detailed usage please refer to the `Support <https://github.com/katanemo/arch/blob/main/arch/tools/README.md#setup-instructionsuser-archgw-cli>`
Build AI Agent with Arch Gateway
--------------------------------
In the following quickstart, we will show you how easy it is to build an AI agent with the Arch gateway. We will build a currency exchange agent using the following simple steps. For this demo, we will use `https://api.frankfurter.dev/` to fetch the latest prices for currencies and assume USD as the base currency.
Step 1. Create arch config file
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Create ``arch_config.yaml`` file with the following content:
.. code-block:: yaml
version: v0.1
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
llm_providers:
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-4o
system_prompt: |
You are a helpful assistant.
prompt_guards:
input_guards:
jailbreak:
on_exception:
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
prompt_targets:
- name: currency_exchange
description: Get currency exchange rate from USD to other currencies
parameters:
- name: currency_symbol
description: the currency that needs conversion
required: true
type: str
in_path: true
endpoint:
name: frankfurther_api
path: /v1/latest?base=USD&symbols={currency_symbol}
system_prompt: |
You are a helpful assistant. Show me the currency symbol you want to convert from USD.
- name: get_supported_currencies
description: Get list of supported currencies for conversion
endpoint:
name: frankfurther_api
path: /v1/currencies
endpoints:
frankfurther_api:
endpoint: api.frankfurter.dev:443
protocol: https
Step 2. Start arch gateway with currency conversion config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: sh
$ archgw up arch_config.yaml
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.1.5
...
2024-12-05 16:56:28,485 - cli.utils - INFO - Schema validation successful!
2024-12-05 16:56:28,485 - cli.main - INFO - Starting arch model server and arch gateway
...
2024-12-05 16:56:51,647 - cli.core - INFO - Container is healthy!
Once the gateway is up, you can start interacting with it at port 10000 using the OpenAI chat completion API.
Some sample queries you can ask include: ``what is currency rate for gbp?`` or ``show me list of currencies for conversion``.
Step 3. Interacting with gateway using curl command
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Here is a sample curl command you can use to interact:
.. code-block:: bash
$ curl --header 'Content-Type: application/json' \
--data '{"messages": [{"role": "user","content": "what is exchange rate for gbp"}]}' \
http://localhost:10000/v1/chat/completions | jq ".choices[0].message.content"
"As of the date provided in your context, December 5, 2024, the exchange rate for GBP (British Pound) from USD (United States Dollar) is 0.78558. This means that 1 USD is equivalent to 0.78558 GBP."
And to get the list of supported currencies:
.. code-block:: bash
$ curl --header 'Content-Type: application/json' \
--data '{"messages": [{"role": "user","content": "show me list of currencies that are supported for conversion"}]}' \
http://localhost:10000/v1/chat/completions | jq ".choices[0].message.content"
"Here is a list of the currencies that are supported for conversion from USD, along with their symbols:\n\n1. AUD - Australian Dollar\n2. BGN - Bulgarian Lev\n3. BRL - Brazilian Real\n4. CAD - Canadian Dollar\n5. CHF - Swiss Franc\n6. CNY - Chinese Renminbi Yuan\n7. CZK - Czech Koruna\n8. DKK - Danish Krone\n9. EUR - Euro\n10. GBP - British Pound\n11. HKD - Hong Kong Dollar\n12. HUF - Hungarian Forint\n13. IDR - Indonesian Rupiah\n14. ILS - Israeli New Sheqel\n15. INR - Indian Rupee\n16. ISK - Icelandic Króna\n17. JPY - Japanese Yen\n18. KRW - South Korean Won\n19. MXN - Mexican Peso\n20. MYR - Malaysian Ringgit\n21. NOK - Norwegian Krone\n22. NZD - New Zealand Dollar\n23. PHP - Philippine Peso\n24. PLN - Polish Złoty\n25. RON - Romanian Leu\n26. SEK - Swedish Krona\n27. SGD - Singapore Dollar\n28. THB - Thai Baht\n29. TRY - Turkish Lira\n30. USD - United States Dollar\n31. ZAR - South African Rand\n\nIf you want to convert USD to any of these currencies, you can select the one you are interested in."
Use Arch Gateway as LLM Router
------------------------------
Step 1. Create arch config file
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Arch operates based on a configuration file where you can define LLM providers, prompt targets, guardrails, etc. Below is an example configuration that defines OpenAI and Mistral LLM providers.
Create ``arch_config.yaml`` file with the following content:
.. code-block:: yaml
version: v0.1
listener:
address: 0.0.0.0
port: 10000
message_format: huggingface
connect_timeout: 0.005s
llm_providers:
- name: gpt-4o
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-4o
default: true
- name: ministral-3b
access_key: $MISTRAL_API_KEY
provider: mistral
model: ministral-3b-latest
Step 2. Start arch gateway
~~~~~~~~~~~~~~~~~~~~~~~~~~
Once the config file is created, ensure that you have environment variables set up for ``MISTRAL_API_KEY`` and ``OPENAI_API_KEY`` (or these are defined in a ``.env`` file).
Start the Arch gateway:
.. code-block:: console
$ archgw up arch_config.yaml
2024-12-05 11:24:51,288 - cli.main - INFO - Starting archgw cli version: 0.1.5
2024-12-05 11:24:51,825 - cli.utils - INFO - Schema validation successful!
2024-12-05 11:24:51,825 - cli.main - INFO - Starting arch model server and arch gateway
...
2024-12-05 11:25:16,131 - cli.core - INFO - Container is healthy!
Step 3: Interact with LLM
~~~~~~~~~~~~~~~~~~~~~~~~~
Step 3.1: Using OpenAI Python client
++++++++++++++++++++++++++++++++++++
Make outbound calls via the Arch gateway:
.. code-block:: python
from openai import OpenAI
# Use the OpenAI client as usual
client = OpenAI(
# No need to set a specific openai.api_key since it's configured in Arch's gateway
api_key='--',
# Set the OpenAI API base URL to the Arch gateway endpoint
base_url="http://127.0.0.1:12000/v1"
)
response = client.chat.completions.create(
# we select model from arch_config file
model="--",
messages=[{"role": "user", "content": "What is the capital of France?"}],
)
print("OpenAI Response:", response.choices[0].message.content)
Step 3.2: Using curl command
++++++++++++++++++++++++++++
.. code-block:: bash
$ curl --header 'Content-Type: application/json' \
--data '{"messages": [{"role": "user","content": "What is the capital of France?"}]}' \
http://localhost:12000/v1/chat/completions
{
...
"model": "gpt-4o-2024-08-06",
"choices": [
{
...
"message": {
"role": "assistant",
"content": "The capital of France is Paris.",
},
}
],
}
You can override model selection using the ``x-arch-llm-provider-hint`` header. For example, to use Mistral, use the following curl command:
.. code-block:: bash
$ curl --header 'Content-Type: application/json' \
--header 'x-arch-llm-provider-hint: ministral-3b' \
--data '{"messages": [{"role": "user","content": "What is the capital of France?"}]}' \
http://localhost:12000/v1/chat/completions
{
...
"model": "ministral-3b-latest",
"choices": [
{
"message": {
"role": "assistant",
"content": "The capital of France is Paris. It is the most populous city in France and is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. Paris is also a major global center for art, fashion, gastronomy, and culture.",
},
...
}
],
...
}
Next Steps
-------------------
==========
Congratulations! You've successfully set up Arch and made your first prompt-based request. To further enhance your GenAI applications, explore the following resources:

View file

@ -31,6 +31,7 @@ prompt_targets:
endpoint:
name: app_server
path: /agent/summary
http_method: POST
# Arch uses the default LLM and treats the response from the endpoint as the prompt to send to the LLM
auto_llm_dispatch_on_response: true
# override system prompt for this prompt target
@ -41,6 +42,7 @@ prompt_targets:
endpoint:
name: app_server
path: /agent/action
http_method: POST
parameters:
- name: device_id
type: str

View file

@ -77,6 +77,7 @@ prompt_targets:
endpoint:
name: app_server
path: /agent/summary
http_method: POST
# Arch uses the default LLM and treats the response from the endpoint as the prompt to send to the LLM
auto_llm_dispatch_on_response: true
# override system prompt for this prompt target

View file

@ -39,7 +39,7 @@ cd -
log building and installing archgw cli
log ==================================
cd ../arch/tools
sh build_cli.sh
poetry install
cd -
log building docker image for arch gateway

View file

@ -220,6 +220,12 @@ async def hallucination(req: HallucinationRequest, res: Response):
if "messages" in req.parameters:
req.parameters.pop("messages")
if not req.parameters or len(req.parameters) == 0:
return {
"params_scores": {},
"model": req.model,
}
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
predictions = classifier(

1444
model_server/poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "archgw_modelserver"
version = "0.1.5"
version = "0.1.6"
description = "A model server for serving models"
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
license = "Apache 2.0"
@ -29,7 +29,7 @@ openai = "1.50.2"
tf-keras = "*"
onnx = "1.17.0"
onnxruntime = "1.19.2"
httpx = "*"
httpx = "0.27.2" # https://community.openai.com/t/typeerror-asyncclient-init-got-an-unexpected-keyword-argument-proxies/1040287
pytest-asyncio = "*"
pytest = "*"
opentelemetry-api = "^1.28.0"