mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Merge branch 'main' into metrics
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
commit
0939c72bcc
56 changed files with 6168 additions and 1902 deletions
243
README.md
243
README.md
|
|
@ -33,94 +33,184 @@ Engineered with purpose-built LLMs, Arch handles the critical but undifferentiat
|
|||
To get in touch with us, please join our [discord server](https://discord.gg/pGZf2gcwEc). We will be monitoring that actively and offering support there.
|
||||
|
||||
## Demos
|
||||
* [Weather Forecast](demos/weather_forecast/README.md) - Walk through of the core function calling capabilities of of arch gateway using weather forecasting service
|
||||
* [Weather Forecast](demos/weather_forecast/README.md) - Walk through of the core function calling capabilities of arch gateway using weather forecasting service
|
||||
* [Insurance Agent](demos/insurance_agent/README.md) - Build a full insurance agent with Arch
|
||||
* [Network Agent](demos/network_agent/README.md) - Build a networking co-pilot/agent agent with Arch
|
||||
|
||||
## Quickstart
|
||||
|
||||
Follow this guide to learn how to quickly set up Arch and integrate it into your generative AI applications.
|
||||
Follow this quickstart guide to use arch gateway to build a simple AI agent. Laster in the section we will see how you can Arch Gateway to manage access keys, provide unified access to upstream LLMs and to provide e2e observability.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before you begin, ensure you have the following:
|
||||
|
||||
- `Docker` & `Python` installed on your system
|
||||
- `API Keys` for LLM providers (if using external LLMs)
|
||||
|
||||
|
||||
### Step 1: Install Arch
|
||||
|
||||
Arch's CLI allows you to manage and interact with the Arch gateway efficiently. To install the CLI, simply run the following command:
|
||||
Tip: We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch. This ensures that archgw and its dependencies do not interfere with other packages on your system.
|
||||
|
||||
Make sure you have following utilities installed before proceeding further,
|
||||
|
||||
1. [Docker System](https://docs.docker.com/get-started/get-docker/) (v24)
|
||||
2. [Docker compose](https://docs.docker.com/compose/install/) (v2.29)
|
||||
3. [Python](https://www.python.org/downloads/) (v3.12)
|
||||
4. [Poetry](https://python-poetry.org/docs/#installing-with-the-official-installer) (v1.8.3. *Note: only needed for local development*)
|
||||
|
||||
Arch's CLI allows you to manage and interact with the Arch gateway efficiently. To install the CLI, simply run the following command:
|
||||
|
||||
> [!TIP]
|
||||
> We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch. This ensures that archgw and its dependencies do not interfere with other packages on your system.
|
||||
|
||||
```console
|
||||
$ python -m venv venv
|
||||
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
|
||||
$ pip install archgw
|
||||
$ pip install archgw==0.1.6
|
||||
```
|
||||
|
||||
### Step 2: Configure Arch with your application
|
||||
### Build AI Agent with Arch Gateway
|
||||
|
||||
Arch operates based on a configuration file where you can define LLM providers, prompt targets, guardrails, etc.
|
||||
Below is an example configuration to get you started:
|
||||
In following quickstart we will show you how easy it is to build AI agent with Arch gateway. We will build a currency exchange agent using following simple steps. For this demo we will use `https://api.frankfurter.dev/` to fetch latest price for currencies and assume USD as base currency.
|
||||
|
||||
#### Step 1. Create arch config file
|
||||
|
||||
Create `arch_config.yaml` file with following content,
|
||||
|
||||
```yaml
|
||||
version: v0.1
|
||||
|
||||
listener:
|
||||
address: 127.0.0.1
|
||||
port: 8080 #If you configure port 443, you'll need to update the listener with tls_certificates
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
# Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
|
||||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-3.5-turbo
|
||||
default: true
|
||||
provider: openai
|
||||
model: gpt-4o
|
||||
|
||||
# default system prompt used by all prompt targets
|
||||
system_prompt: |
|
||||
You are a network assistant that helps operators with a better understanding of network traffic flow and perform actions on networking operations. No advice on manufacturers or purchasing decisions.
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
|
||||
|
||||
prompt_targets:
|
||||
- name: device_summary
|
||||
description: Retrieve network statistics for specific devices within a time range
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_summary
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs) to retrieve statistics for.
|
||||
required: true # device_ids are required to get device statistics
|
||||
- name: days
|
||||
type: int
|
||||
description: The number of days for which to gather device statistics.
|
||||
default: "7"
|
||||
- name: currency_exchange
|
||||
description: Get currency exchange rate from USD to other currencies
|
||||
parameters:
|
||||
- name: currency_symbol
|
||||
description: the currency that needs conversion
|
||||
required: true
|
||||
type: str
|
||||
in_path: true
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/latest?base=USD&symbols={currency_symbol}
|
||||
system_prompt: |
|
||||
You are a helpful assistant. Show me the currency symbol you want to convert from USD.
|
||||
|
||||
- name: get_supported_currencies
|
||||
description: Get list of supported currencies for conversion
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/currencies
|
||||
|
||||
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
|
||||
endpoints:
|
||||
app_server:
|
||||
# value could be ip address or a hostname with port
|
||||
# this could also be a list of endpoints for load balancing
|
||||
# for example endpoint: [ ip1:port, ip2:port ]
|
||||
endpoint: host.docker.internal:18083
|
||||
# max time to wait for a connection to be established
|
||||
connect_timeout: 0.005s
|
||||
frankfurther_api:
|
||||
endpoint: api.frankfurter.dev:443
|
||||
protocol: https
|
||||
```
|
||||
### Step 3: Using OpenAI Client with Arch as an Egress Gateway
|
||||
|
||||
Make outbound calls via Arch
|
||||
#### Step 2. Start arch gateway with currency conversion config
|
||||
|
||||
```sh
|
||||
|
||||
$ archgw up arch_config.yaml
|
||||
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.1.5
|
||||
...
|
||||
2024-12-05 16:56:28,485 - cli.utils - INFO - Schema validation successful!
|
||||
2024-12-05 16:56:28,485 - cli.main - INFO - Starging arch model server and arch gateway
|
||||
...
|
||||
2024-12-05 16:56:51,647 - cli.core - INFO - Container is healthy!
|
||||
|
||||
```
|
||||
|
||||
Once the gateway is up you can start interacting with at port 10000 using openai chat completion API.
|
||||
|
||||
Some of the sample queries you can ask could be `what is currency rate for gbp?` or `show me list of currencies for conversion`.
|
||||
|
||||
#### Step 3. Interacting with gateway using curl command
|
||||
|
||||
Here is a sample curl command you can use to interact,
|
||||
|
||||
```bash
|
||||
$ curl --header 'Content-Type: application/json' \
|
||||
--data '{"messages": [{"role": "user","content": "what is exchange rate for gbp"}]}' \
|
||||
http://localhost:10000/v1/chat/completions | jq ".choices[0].message.content"
|
||||
|
||||
"As of the date provided in your context, December 5, 2024, the exchange rate for GBP (British Pound) from USD (United States Dollar) is 0.78558. This means that 1 USD is equivalent to 0.78558 GBP."
|
||||
|
||||
```
|
||||
|
||||
And to get list of supported currencies,
|
||||
|
||||
```bash
|
||||
$ curl --header 'Content-Type: application/json' \
|
||||
--data '{"messages": [{"role": "user","content": "show me list of currencies that are supported for conversion"}]}' \
|
||||
http://localhost:10000/v1/chat/completions | jq ".choices[0].message.content"
|
||||
|
||||
"Here is a list of the currencies that are supported for conversion from USD, along with their symbols:\n\n1. AUD - Australian Dollar\n2. BGN - Bulgarian Lev\n3. BRL - Brazilian Real\n4. CAD - Canadian Dollar\n5. CHF - Swiss Franc\n6. CNY - Chinese Renminbi Yuan\n7. CZK - Czech Koruna\n8. DKK - Danish Krone\n9. EUR - Euro\n10. GBP - British Pound\n11. HKD - Hong Kong Dollar\n12. HUF - Hungarian Forint\n13. IDR - Indonesian Rupiah\n14. ILS - Israeli New Sheqel\n15. INR - Indian Rupee\n16. ISK - Icelandic Króna\n17. JPY - Japanese Yen\n18. KRW - South Korean Won\n19. MXN - Mexican Peso\n20. MYR - Malaysian Ringgit\n21. NOK - Norwegian Krone\n22. NZD - New Zealand Dollar\n23. PHP - Philippine Peso\n24. PLN - Polish Złoty\n25. RON - Romanian Leu\n26. SEK - Swedish Krona\n27. SGD - Singapore Dollar\n28. THB - Thai Baht\n29. TRY - Turkish Lira\n30. USD - United States Dollar\n31. ZAR - South African Rand\n\nIf you want to convert USD to any of these currencies, you can select the one you are interested in."
|
||||
|
||||
```
|
||||
|
||||
### Use Arch Gateway as LLM Router
|
||||
|
||||
#### Step 1. Create arch config file
|
||||
|
||||
Arch operates based on a configuration file where you can define LLM providers, prompt targets, guardrails, etc. Below is an example configuration that defines openai and mistral LLM providers.
|
||||
|
||||
Create `arch_config.yaml` file with following content:
|
||||
|
||||
```yaml
|
||||
version: v0.1
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-4o
|
||||
default: true
|
||||
|
||||
- name: ministral-3b
|
||||
access_key: $MISTRAL_API_KEY
|
||||
provider: mistral
|
||||
model: ministral-3b-latest
|
||||
```
|
||||
|
||||
#### Step 2. Start arch gateway
|
||||
|
||||
Once the config file is created ensure that you have env vars setup for `MISTRAL_API_KEY` and `OPENAI_API_KEY` (or these are defined in `.env` file).
|
||||
|
||||
Start arch gateway,
|
||||
|
||||
```
|
||||
$ archgw up arch_config.yaml
|
||||
2024-12-05 11:24:51,288 - cli.main - INFO - Starting archgw cli version: 0.1.5
|
||||
2024-12-05 11:24:51,825 - cli.utils - INFO - Schema validation successful!
|
||||
2024-12-05 11:24:51,825 - cli.main - INFO - Starting arch model server and arch gateway
|
||||
...
|
||||
2024-12-05 11:25:16,131 - cli.core - INFO - Container is healthy!
|
||||
```
|
||||
|
||||
### Step 3: Interact with LLM
|
||||
|
||||
#### Step 3.1: Using OpenAI python client
|
||||
|
||||
Make outbound calls via Arch gateway
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
|
@ -143,12 +233,59 @@ print("OpenAI Response:", response.choices[0].message.content)
|
|||
|
||||
```
|
||||
|
||||
### [Observability](https://docs.archgw.com/guides/observability/observability.html)
|
||||
#### Step 3.2: Using curl command
|
||||
```
|
||||
$ curl --header 'Content-Type: application/json' \
|
||||
--data '{"messages": [{"role": "user","content": "What is the capital of France?"}]}' \
|
||||
http://localhost:12000/v1/chat/completions
|
||||
|
||||
{
|
||||
...
|
||||
"model": "gpt-4o-2024-08-06",
|
||||
"choices": [
|
||||
{
|
||||
...
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "The capital of France is Paris.",
|
||||
},
|
||||
}
|
||||
],
|
||||
...
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
You can override model selection using `x-arch-llm-provider-hint` header. For example if you want to use mistral using following curl command,
|
||||
|
||||
```
|
||||
$ curl --header 'Content-Type: application/json' \
|
||||
--header 'x-arch-llm-provider-hint: ministral-3b' \
|
||||
--data '{"messages": [{"role": "user","content": "What is the capital of France?"}]}' \
|
||||
http://localhost:12000/v1/chat/completions
|
||||
{
|
||||
...
|
||||
"model": "ministral-3b-latest",
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "The capital of France is Paris. It is the most populous city in France and is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. Paris is also a major global center for art, fashion, gastronomy, and culture.",
|
||||
},
|
||||
...
|
||||
}
|
||||
],
|
||||
...
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## [Observability](https://docs.archgw.com/guides/observability/observability.html)
|
||||
Arch is designed to support best-in class observability by supporting open standards. Please read our [docs](https://docs.archgw.com/guides/observability/observability.html) on observability for more details on tracing, metrics, and logs. The screenshot below is from our integration with Signoz (among others)
|
||||
|
||||

|
||||
|
||||
### Contribution
|
||||
## Contribution
|
||||
We would love feedback on our [Roadmap](https://github.com/orgs/katanemo/projects/1) and we welcome contributions to **Arch**!
|
||||
Whether you're fixing bugs, adding new features, improving documentation, or creating tutorials, your help is much appreciated.
|
||||
Please visit our [Contribution Guide](CONTRIBUTING.md) for more details
|
||||
|
|
|
|||
|
|
@ -28,6 +28,11 @@ properties:
|
|||
type: string
|
||||
connect_timeout:
|
||||
type: string
|
||||
protocol:
|
||||
type: string
|
||||
enum:
|
||||
- http
|
||||
- https
|
||||
additionalProperties: false
|
||||
required:
|
||||
- endpoint
|
||||
|
|
@ -92,6 +97,8 @@ properties:
|
|||
type: array
|
||||
items:
|
||||
type: string
|
||||
in_path:
|
||||
type: boolean
|
||||
additionalProperties: false
|
||||
required:
|
||||
- name
|
||||
|
|
@ -104,6 +111,11 @@ properties:
|
|||
type: string
|
||||
path:
|
||||
type: string
|
||||
http_method:
|
||||
type: string
|
||||
enum:
|
||||
- GET
|
||||
- POST
|
||||
additionalProperties: false
|
||||
required:
|
||||
- name
|
||||
|
|
|
|||
|
|
@ -500,7 +500,18 @@ static_resources:
|
|||
socket_address:
|
||||
address: {{ cluster.endpoint }}
|
||||
port_value: {{ cluster.port }}
|
||||
hostname: {{ cluster.name }}
|
||||
hostname: {{ cluster.endpoint }}
|
||||
{% if cluster.protocol == "https" %}
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
sni: {{ cluster.endpoint }}
|
||||
common_tls_context:
|
||||
tls_params:
|
||||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
- name: arch_internal
|
||||
connect_timeout: 5s
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ source venv/bin/activate
|
|||
|
||||
### Step 3: Run the build script
|
||||
```bash
|
||||
pip install archgw
|
||||
pip install archgw==0.1.6
|
||||
```
|
||||
|
||||
## Uninstall Instructions: archgw CLI
|
||||
|
|
@ -48,7 +48,7 @@ source venv/bin/activate
|
|||
|
||||
### Step 3: Run the build script
|
||||
```bash
|
||||
sh build_cli.sh
|
||||
poetry install
|
||||
```
|
||||
|
||||
### Step 4: build Arch
|
||||
|
|
|
|||
|
|
@ -49,7 +49,9 @@ def validate_and_render_schema():
|
|||
|
||||
if "prompt_targets" in config_yaml:
|
||||
for prompt_target in config_yaml["prompt_targets"]:
|
||||
name = prompt_target.get("endpoint", {}).get("name", "")
|
||||
name = prompt_target.get("endpoint", {}).get("name", None)
|
||||
if not name:
|
||||
continue
|
||||
if name not in inferred_clusters:
|
||||
inferred_clusters[name] = {
|
||||
"name": name,
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ def start_arch_modelserver():
|
|||
subprocess.run(
|
||||
["archgw_modelserver", "restart"], check=True, start_new_session=True
|
||||
)
|
||||
log.info("Successfull ran model_server")
|
||||
log.info("Successfully ran model_server")
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
|
@ -212,7 +212,7 @@ def stop_arch_modelserver():
|
|||
["archgw_modelserver", "stop"],
|
||||
check=True,
|
||||
)
|
||||
log.info("Successfull stopped the archgw model_server")
|
||||
log.info("Successfully stopped the archgw model_server")
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ def up(file, path, service):
|
|||
log.info(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
log.info("Starging arch model server and arch gateway")
|
||||
log.info("Starting arch model server and arch gateway")
|
||||
|
||||
# Set the ARCH_CONFIG_FILE environment variable
|
||||
env_stage = {}
|
||||
|
|
|
|||
3882
arch/tools/poetry.lock
generated
3882
arch/tools/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "archgw"
|
||||
version = "0.1.5"
|
||||
version = "0.1.6"
|
||||
description = "Python-based CLI tool to manage Arch Gateway."
|
||||
authors = ["Katanemo Labs, Inc."]
|
||||
packages = [
|
||||
|
|
@ -10,7 +10,7 @@ readme = "README.md"
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.12"
|
||||
archgw_modelserver = "0.1.5"
|
||||
archgw_modelserver = "0.1.6"
|
||||
pyyaml = "^6.0.2"
|
||||
pydantic = "^2.10.1"
|
||||
click = "^8.1.7"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,23 @@
|
|||
use common::{
|
||||
common_types::open_ai::Message,
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
api::open_ai::Message,
|
||||
consts::{ARCH_MODEL_PREFIX, HALLUCINATION_TEMPLATE, USER_ROLE},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationRequest {
|
||||
pub prompt: String,
|
||||
pub parameters: HashMap<String, String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationResponse {
|
||||
pub params_scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
|
||||
let mut arch_assistant = false;
|
||||
|
|
@ -42,7 +58,7 @@ pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use common::common_types::open_ai::Message;
|
||||
use crate::api::open_ai::Message;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::extract_messages_for_hallucination;
|
||||
4
crates/common/src/api/mod.rs
Normal file
4
crates/common/src/api/mod.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pub mod hallucination;
|
||||
pub mod open_ai;
|
||||
pub mod prompt_guard;
|
||||
pub mod zero_shot;
|
||||
653
crates/common/src/api/open_ai.rs
Normal file
653
crates/common/src/api/open_ai.rs
Normal file
|
|
@ -0,0 +1,653 @@
|
|||
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt::Display,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ChatCompletionTool>>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ToolType {
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionTool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameters {
|
||||
pub properties: HashMap<String, FunctionParameter>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameters {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
// select all requried parameters
|
||||
let required: Vec<&String> = self
|
||||
.properties
|
||||
.iter()
|
||||
.filter(|(_, v)| v.required.unwrap_or(false))
|
||||
.map(|(k, _)| k)
|
||||
.collect();
|
||||
let mut map = serializer.serialize_map(Some(2))?;
|
||||
map.serialize_entry("properties", &self.properties)?;
|
||||
if !required.is_empty() {
|
||||
map.serialize_entry("required", &required)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameter {
|
||||
#[serde(rename = "type")]
|
||||
#[serde(default = "ParameterType::string")]
|
||||
pub parameter_type: ParameterType,
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<String>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameter {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_map(Some(5))?;
|
||||
map.serialize_entry("type", &self.parameter_type)?;
|
||||
map.serialize_entry("description", &self.description)?;
|
||||
if let Some(enum_values) = &self.enum_values {
|
||||
map.serialize_entry("enum", enum_values)?;
|
||||
}
|
||||
if let Some(default) = &self.default {
|
||||
map.serialize_entry("default", default)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum ParameterType {
|
||||
#[serde(rename = "int")]
|
||||
Int,
|
||||
#[serde(rename = "float")]
|
||||
Float,
|
||||
#[serde(rename = "bool")]
|
||||
Bool,
|
||||
#[serde(rename = "str")]
|
||||
String,
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
#[serde(rename = "dict")]
|
||||
Dict,
|
||||
}
|
||||
|
||||
impl From<String> for ParameterType {
|
||||
fn from(s: String) -> Self {
|
||||
match s.as_str() {
|
||||
"int" => ParameterType::Int,
|
||||
"integer" => ParameterType::Int,
|
||||
"float" => ParameterType::Float,
|
||||
"bool" => ParameterType::Bool,
|
||||
"boolean" => ParameterType::Bool,
|
||||
"str" => ParameterType::String,
|
||||
"string" => ParameterType::String,
|
||||
"list" => ParameterType::List,
|
||||
"array" => ParameterType::List,
|
||||
"dict" => ParameterType::Dict,
|
||||
"dictionary" => ParameterType::Dict,
|
||||
_ => ParameterType::String,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ParameterType {
|
||||
pub fn string() -> ParameterType {
|
||||
ParameterType::String
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub finish_reason: String,
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionCallDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ToolCallState {
|
||||
pub key: String,
|
||||
pub message: Option<Message>,
|
||||
pub tool_call: FunctionCallDetail,
|
||||
pub tool_response: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ArchState {
|
||||
ToolCall(Vec<ToolCallState>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub usage: Option<Usage>,
|
||||
pub choices: Vec<Choice>,
|
||||
pub model: String,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl ChatCompletionsResponse {
|
||||
pub fn new(message: String) -> Self {
|
||||
ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: Some(message),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
index: 0,
|
||||
finish_reason: "done".to_string(),
|
||||
}],
|
||||
usage: None,
|
||||
model: ARCH_FC_MODEL_NAME.to_string(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub completion_tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionStreamResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
impl ChatCompletionStreamResponse {
|
||||
pub fn new(
|
||||
response: Option<String>,
|
||||
role: Option<String>,
|
||||
model: Option<String>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
) -> Self {
|
||||
ChatCompletionStreamResponse {
|
||||
model,
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role,
|
||||
content: response,
|
||||
tool_calls,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChatCompletionChunkResponseError {
|
||||
#[error("failed to deserialize")]
|
||||
Deserialization(#[from] serde_json::Error),
|
||||
#[error("empty content in data chunk")]
|
||||
EmptyContent,
|
||||
#[error("no chunks present")]
|
||||
NoChunks,
|
||||
}
|
||||
|
||||
pub struct ChatCompletionStreamResponseServerEvents {
|
||||
pub events: Vec<ChatCompletionStreamResponse>,
|
||||
}
|
||||
|
||||
impl Display for ChatCompletionStreamResponseServerEvents {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let tokens_str = self
|
||||
.events
|
||||
.iter()
|
||||
.map(|response_chunk| {
|
||||
if response_chunk.choices.is_empty() {
|
||||
return "".to_string();
|
||||
}
|
||||
response_chunk.choices[0]
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or("".to_string())
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("");
|
||||
|
||||
write!(f, "{}", tokens_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
|
||||
type Error = ChatCompletionChunkResponseError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
|
||||
.lines()
|
||||
.filter(|line| line.starts_with("data: "))
|
||||
.map(|line| line.get(6..).unwrap())
|
||||
.filter(|data_chunk| *data_chunk != "[DONE]")
|
||||
.map(serde_json::from_str::<ChatCompletionStreamResponse>)
|
||||
.collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
|
||||
|
||||
Ok(ChatCompletionStreamResponseServerEvents {
|
||||
events: response_chunks.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub delta: Delta,
|
||||
// TODO: could this be an enum?
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Delta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
|
||||
let mut response_str = String::new();
|
||||
for chunk in chunks.iter() {
|
||||
response_str.push_str("data: ");
|
||||
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
||||
response_str.push_str("\n\n");
|
||||
}
|
||||
response_str
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{ChatCompletionStreamResponseServerEvents, Message};
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
|
||||
const TOOL_SERIALIZED: &str = r#"{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What city do you want to know the weather for?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"description": "function to retrieve weather forecast",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"city"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"stream_options": {
|
||||
"include_usage": true
|
||||
}
|
||||
}"#;
|
||||
|
||||
#[test]
|
||||
fn test_tool_type_request() {
|
||||
use super::{
|
||||
ChatCompletionTool, ChatCompletionsRequest, FunctionDefinition, FunctionParameter,
|
||||
FunctionParameters, ParameterType, StreamOptions, ToolType,
|
||||
};
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
let function_definition = FunctionDefinition {
|
||||
name: "weather_forecast".to_string(),
|
||||
description: "function to retrieve weather forecast".to_string(),
|
||||
parameters: FunctionParameters { properties },
|
||||
};
|
||||
|
||||
let chat_completions_request = ChatCompletionsRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What city do you want to know the weather for?".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
tools: Some(vec![ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: function_definition,
|
||||
}]),
|
||||
stream: true,
|
||||
stream_options: Some(StreamOptions {
|
||||
include_usage: true,
|
||||
}),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
|
||||
println!("{}", serialized);
|
||||
assert_eq!(TOOL_SERIALIZED, serialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameter_types() {
|
||||
use super::{FunctionParameter, ParameterType};
|
||||
|
||||
const PARAMETER_SERIALZIED: &str = r#"{
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let properties = HashMap::from([(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
},
|
||||
)]);
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&properties).unwrap();
|
||||
assert_eq!(PARAMETER_SERIALZIED, serialized);
|
||||
|
||||
// ensure that if type is missing it is set to string
|
||||
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
|
||||
{
|
||||
"city": {
|
||||
"description": "city for weather forecast"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let missing_type_deserialized: HashMap<String, FunctionParameter> =
|
||||
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
|
||||
println!("{:?}", missing_type_deserialized);
|
||||
assert_eq!(
|
||||
missing_type_deserialized
|
||||
.get("city")
|
||||
.unwrap()
|
||||
.parameter_type,
|
||||
ParameterType::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
"#;
|
||||
|
||||
let sever_events =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 5);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
""
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[1].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"Hello"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[2].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"!"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[3].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" How"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[4].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" can"
|
||||
);
|
||||
assert_eq!(sever_events.to_string(), "Hello! How can");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_done() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 6);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" I"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[1].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" assist"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[2].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" you"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[3].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" today"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[4].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"?"
|
||||
);
|
||||
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
|
||||
|
||||
assert_eq!(sever_events.to_string(), " I assist you today?");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_mistral() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 11);
|
||||
|
||||
assert_eq!(
|
||||
sever_events.to_string(),
|
||||
"Hello! How can I assist you today?"
|
||||
);
|
||||
}
|
||||
}
|
||||
25
crates/common/src/api/prompt_guard.rs
Normal file
25
crates/common/src/api/prompt_guard.rs
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PromptGuardTask {
|
||||
#[serde(rename = "jailbreak")]
|
||||
Jailbreak,
|
||||
#[serde(rename = "toxicity")]
|
||||
Toxicity,
|
||||
#[serde(rename = "both")]
|
||||
Both,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardRequest {
|
||||
pub input: String,
|
||||
pub task: PromptGuardTask,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardResponse {
|
||||
pub toxic_prob: Option<f64>,
|
||||
pub jailbreak_prob: Option<f64>,
|
||||
pub toxic_verdict: Option<bool>,
|
||||
pub jailbreak_verdict: Option<bool>,
|
||||
}
|
||||
18
crates/common/src/api/zero_shot.rs
Normal file
18
crates/common/src/api/zero_shot.rs
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationRequest {
|
||||
pub input: String,
|
||||
pub labels: Vec<String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationResponse {
|
||||
pub predicted_class: String,
|
||||
pub predicted_class_score: f64,
|
||||
pub scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
|
@ -1,743 +0,0 @@
|
|||
use crate::configuration::PromptTarget;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub prompt_target: PromptTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub enum EmbeddingType {
|
||||
Name,
|
||||
Description,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorPoint {
|
||||
pub id: String,
|
||||
pub payload: HashMap<String, String>,
|
||||
pub vector: Vec<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StoreVectorEmbeddingsRequest {
|
||||
pub points: Vec<VectorPoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointResult {
|
||||
pub id: String,
|
||||
pub version: i32,
|
||||
pub score: f64,
|
||||
pub payload: HashMap<String, String>,
|
||||
}
|
||||
|
||||
pub mod open_ai {
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt::Display,
|
||||
};
|
||||
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
|
||||
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ChatCompletionTool>>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ToolType {
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionTool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameters {
|
||||
pub properties: HashMap<String, FunctionParameter>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameters {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
// select all requried parameters
|
||||
let required: Vec<&String> = self
|
||||
.properties
|
||||
.iter()
|
||||
.filter(|(_, v)| v.required.unwrap_or(false))
|
||||
.map(|(k, _)| k)
|
||||
.collect();
|
||||
let mut map = serializer.serialize_map(Some(2))?;
|
||||
map.serialize_entry("properties", &self.properties)?;
|
||||
if !required.is_empty() {
|
||||
map.serialize_entry("required", &required)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameter {
|
||||
#[serde(rename = "type")]
|
||||
#[serde(default = "ParameterType::string")]
|
||||
pub parameter_type: ParameterType,
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<String>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameter {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_map(Some(5))?;
|
||||
map.serialize_entry("type", &self.parameter_type)?;
|
||||
map.serialize_entry("description", &self.description)?;
|
||||
if let Some(enum_values) = &self.enum_values {
|
||||
map.serialize_entry("enum", enum_values)?;
|
||||
}
|
||||
if let Some(default) = &self.default {
|
||||
map.serialize_entry("default", default)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum ParameterType {
|
||||
#[serde(rename = "int")]
|
||||
Int,
|
||||
#[serde(rename = "float")]
|
||||
Float,
|
||||
#[serde(rename = "bool")]
|
||||
Bool,
|
||||
#[serde(rename = "str")]
|
||||
String,
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
#[serde(rename = "dict")]
|
||||
Dict,
|
||||
}
|
||||
|
||||
impl From<String> for ParameterType {
|
||||
fn from(s: String) -> Self {
|
||||
match s.as_str() {
|
||||
"int" => ParameterType::Int,
|
||||
"integer" => ParameterType::Int,
|
||||
"float" => ParameterType::Float,
|
||||
"bool" => ParameterType::Bool,
|
||||
"boolean" => ParameterType::Bool,
|
||||
"str" => ParameterType::String,
|
||||
"string" => ParameterType::String,
|
||||
"list" => ParameterType::List,
|
||||
"array" => ParameterType::List,
|
||||
"dict" => ParameterType::Dict,
|
||||
"dictionary" => ParameterType::Dict,
|
||||
_ => ParameterType::String,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ParameterType {
|
||||
pub fn string() -> ParameterType {
|
||||
ParameterType::String
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub finish_reason: String,
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionCallDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ToolCallState {
|
||||
pub key: String,
|
||||
pub message: Option<Message>,
|
||||
pub tool_call: FunctionCallDetail,
|
||||
pub tool_response: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ArchState {
|
||||
ToolCall(Vec<ToolCallState>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub usage: Option<Usage>,
|
||||
pub choices: Vec<Choice>,
|
||||
pub model: String,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl ChatCompletionsResponse {
|
||||
pub fn new(message: String) -> Self {
|
||||
ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: Some(message),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
index: 0,
|
||||
finish_reason: "done".to_string(),
|
||||
}],
|
||||
usage: None,
|
||||
model: ARCH_FC_MODEL_NAME.to_string(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub completion_tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionStreamResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
impl ChatCompletionStreamResponse {
|
||||
pub fn new(
|
||||
response: Option<String>,
|
||||
role: Option<String>,
|
||||
model: Option<String>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
) -> Self {
|
||||
ChatCompletionStreamResponse {
|
||||
model,
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role,
|
||||
content: response,
|
||||
tool_calls,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChatCompletionChunkResponseError {
|
||||
#[error("failed to deserialize")]
|
||||
Deserialization(#[from] serde_json::Error),
|
||||
#[error("empty content in data chunk")]
|
||||
EmptyContent,
|
||||
#[error("no chunks present")]
|
||||
NoChunks,
|
||||
}
|
||||
|
||||
pub struct ChatCompletionStreamResponseServerEvents {
|
||||
pub events: Vec<ChatCompletionStreamResponse>,
|
||||
}
|
||||
|
||||
impl Display for ChatCompletionStreamResponseServerEvents {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let tokens_str = self
|
||||
.events
|
||||
.iter()
|
||||
.map(|response_chunk| {
|
||||
if response_chunk.choices.is_empty() {
|
||||
return "".to_string();
|
||||
}
|
||||
response_chunk.choices[0]
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or("".to_string())
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("");
|
||||
|
||||
write!(f, "{}", tokens_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
|
||||
type Error = ChatCompletionChunkResponseError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
|
||||
.lines()
|
||||
.filter(|line| line.starts_with("data: "))
|
||||
.map(|line| line.get(6..).unwrap())
|
||||
.filter(|data_chunk| *data_chunk != "[DONE]")
|
||||
.map(serde_json::from_str::<ChatCompletionStreamResponse>)
|
||||
.collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
|
||||
|
||||
Ok(ChatCompletionStreamResponseServerEvents {
|
||||
events: response_chunks.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub delta: Delta,
|
||||
// TODO: could this be an enum?
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Delta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
|
||||
let mut response_str = String::new();
|
||||
for chunk in chunks.iter() {
|
||||
response_str.push_str("data: ");
|
||||
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
||||
response_str.push_str("\n\n");
|
||||
}
|
||||
response_str
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationRequest {
|
||||
pub input: String,
|
||||
pub labels: Vec<String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationResponse {
|
||||
pub predicted_class: String,
|
||||
pub predicted_class_score: f64,
|
||||
pub scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationRequest {
|
||||
pub prompt: String,
|
||||
pub parameters: HashMap<String, String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationResponse {
|
||||
pub params_scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PromptGuardTask {
|
||||
#[serde(rename = "jailbreak")]
|
||||
Jailbreak,
|
||||
#[serde(rename = "toxicity")]
|
||||
Toxicity,
|
||||
#[serde(rename = "both")]
|
||||
Both,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardRequest {
|
||||
pub input: String,
|
||||
pub task: PromptGuardTask,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardResponse {
|
||||
pub toxic_prob: Option<f64>,
|
||||
pub jailbreak_prob: Option<f64>,
|
||||
pub toxic_verdict: Option<bool>,
|
||||
pub jailbreak_verdict: Option<bool>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::common_types::open_ai::{ChatCompletionStreamResponseServerEvents, Message};
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
|
||||
const TOOL_SERIALIZED: &str = r#"{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What city do you want to know the weather for?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"description": "function to retrieve weather forecast",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"city"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"stream_options": {
|
||||
"include_usage": true
|
||||
}
|
||||
}"#;
|
||||
|
||||
#[test]
|
||||
fn test_tool_type_request() {
|
||||
use super::open_ai::{
|
||||
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
|
||||
};
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
let function_definition = FunctionDefinition {
|
||||
name: "weather_forecast".to_string(),
|
||||
description: "function to retrieve weather forecast".to_string(),
|
||||
parameters: super::open_ai::FunctionParameters { properties },
|
||||
};
|
||||
|
||||
let chat_completions_request = ChatCompletionsRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What city do you want to know the weather for?".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
tools: Some(vec![super::open_ai::ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: function_definition,
|
||||
}]),
|
||||
stream: true,
|
||||
stream_options: Some(super::open_ai::StreamOptions {
|
||||
include_usage: true,
|
||||
}),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
|
||||
println!("{}", serialized);
|
||||
assert_eq!(TOOL_SERIALIZED, serialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameter_types() {
|
||||
use super::open_ai::{FunctionParameter, ParameterType};
|
||||
|
||||
const PARAMETER_SERIALZIED: &str = r#"{
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let properties = HashMap::from([(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
},
|
||||
)]);
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&properties).unwrap();
|
||||
assert_eq!(PARAMETER_SERIALZIED, serialized);
|
||||
|
||||
// ensure that if type is missing it is set to string
|
||||
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
|
||||
{
|
||||
"city": {
|
||||
"description": "city for weather forecast"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let missing_type_deserialized: HashMap<String, FunctionParameter> =
|
||||
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
|
||||
println!("{:?}", missing_type_deserialized);
|
||||
assert_eq!(
|
||||
missing_type_deserialized
|
||||
.get("city")
|
||||
.unwrap()
|
||||
.parameter_type,
|
||||
ParameterType::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
"#;
|
||||
|
||||
let sever_events =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 5);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
""
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[1].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"Hello"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[2].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"!"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[3].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" How"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[4].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" can"
|
||||
);
|
||||
assert_eq!(sever_events.to_string(), "Hello! How can");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_done() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 6);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" I"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[1].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" assist"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[2].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" you"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[3].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" today"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[4].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"?"
|
||||
);
|
||||
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
|
||||
|
||||
assert_eq!(sever_events.to_string(), " I assist you today?");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_mistral() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 11);
|
||||
|
||||
assert_eq!(
|
||||
sever_events.to_string(),
|
||||
"Hello! How can I assist you today?"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,22 @@ use serde::{Deserialize, Serialize};
|
|||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
pub listener: Listener,
|
||||
pub endpoints: Option<HashMap<String, Endpoint>>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
pub prompt_targets: Option<Vec<PromptTarget>>,
|
||||
pub error_target: Option<ErrorTargetDetail>,
|
||||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Overrides {
|
||||
pub prompt_target_intent_matching_threshold: Option<f64>,
|
||||
|
|
@ -22,22 +38,6 @@ pub enum GatewayMode {
|
|||
Prompt,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
pub listener: Listener,
|
||||
pub endpoints: Option<HashMap<String, Endpoint>>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
pub prompt_targets: Option<Vec<PromptTarget>>,
|
||||
pub error_target: Option<ErrorTargetDetail>,
|
||||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ErrorTargetDetail {
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
|
|
@ -179,8 +179,6 @@ impl Display for LlmProvider {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Endpoint {
|
||||
pub endpoint: Option<String>,
|
||||
// pub connect_timeout: Option<DurationString>,
|
||||
// pub timeout: Option<DurationString>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -193,12 +191,33 @@ pub struct Parameter {
|
|||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
pub default: Option<String>,
|
||||
pub in_path: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
|
||||
pub enum HttpMethod {
|
||||
#[default]
|
||||
#[serde(rename = "GET")]
|
||||
Get,
|
||||
#[serde(rename = "POST")]
|
||||
Post,
|
||||
}
|
||||
|
||||
impl Display for HttpMethod {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
HttpMethod::Get => write!(f, "GET"),
|
||||
HttpMethod::Post => write!(f, "POST"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EndpointDetails {
|
||||
pub name: String,
|
||||
pub path: Option<String>,
|
||||
#[serde(rename = "http_method")]
|
||||
pub method: Option<HttpMethod>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use proxy_wasm::types::Status;
|
||||
|
||||
use crate::{common_types::open_ai::ChatCompletionChunkResponseError, ratelimit};
|
||||
use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ClientError {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
pub mod common_types;
|
||||
pub mod api;
|
||||
pub mod configuration;
|
||||
pub mod consts;
|
||||
pub mod embeddings;
|
||||
|
|
@ -11,3 +11,4 @@ pub mod routing;
|
|||
pub mod stats;
|
||||
pub mod tokenizer;
|
||||
pub mod tracing;
|
||||
pub mod path;
|
||||
|
|
|
|||
82
crates/common/src/path.rs
Normal file
82
crates/common/src/path.rs
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
|
||||
let mut result = String::new();
|
||||
let mut in_param = false;
|
||||
let mut current_param = String::new();
|
||||
|
||||
for c in path.chars() {
|
||||
if c == '{' {
|
||||
in_param = true;
|
||||
} else if c == '}' {
|
||||
in_param = false;
|
||||
let param_name = current_param.clone();
|
||||
if let Some(value) = params.get(¶m_name) {
|
||||
result.push_str(value);
|
||||
} else {
|
||||
return Err(format!("Missing value for parameter `{}`", param_name));
|
||||
}
|
||||
current_param.clear();
|
||||
} else {
|
||||
if in_param {
|
||||
current_param.push(c);
|
||||
} else {
|
||||
result.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
#[test]
|
||||
fn test_replace_path() {
|
||||
let path = "/cluster.open-cluster-management.io/v1/managedclusters/{cluster_name}";
|
||||
let params = vec![("cluster_name".to_string(), "test1".to_string())]
|
||||
.into_iter()
|
||||
.collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Ok("/cluster.open-cluster-management.io/v1/managedclusters/test1".to_string())
|
||||
);
|
||||
|
||||
let path = "/cluster.open-cluster-management.io/v1/managedclusters";
|
||||
let params = vec![].into_iter().collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Ok("/cluster.open-cluster-management.io/v1/managedclusters".to_string())
|
||||
);
|
||||
|
||||
let path = "/foo/{bar}/baz";
|
||||
let params = vec![("bar".to_string(), "qux".to_string())]
|
||||
.into_iter()
|
||||
.collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Ok("/foo/qux/baz".to_string())
|
||||
);
|
||||
|
||||
let path = "/foo/{bar}/baz/{qux}";
|
||||
let params = vec![
|
||||
("bar".to_string(), "qux".to_string()),
|
||||
("qux".to_string(), "quux".to_string()),
|
||||
]
|
||||
.into_iter()
|
||||
.collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Ok("/foo/qux/baz/quux".to_string())
|
||||
);
|
||||
|
||||
let path = "/foo/{bar}/baz/{qux}";
|
||||
let params = vec![("bar".to_string(), "qux".to_string())]
|
||||
.into_iter()
|
||||
.collect();
|
||||
assert_eq!(
|
||||
super::replace_params_in_path(path, ¶ms),
|
||||
Err("Missing value for parameter `qux`".to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::metrics::Metrics;
|
||||
use common::common_types::open_ai::{
|
||||
use common::api::open_ai::{
|
||||
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
|
||||
Message, StreamOptions,
|
||||
};
|
||||
|
|
|
|||
5
crates/prompt_gateway/src/embeddings.rs
Normal file
5
crates/prompt_gateway/src/embeddings.rs
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub enum EmbeddingType {
|
||||
Name,
|
||||
Description,
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
use crate::embeddings::EmbeddingType;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::stream_context::StreamContext;
|
||||
use common::common_types::EmbeddingType;
|
||||
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
|
||||
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use common::consts::DEFAULT_EMBEDDING_MODEL;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
|
||||
use common::{
|
||||
common_types::{
|
||||
open_ai::{
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||
},
|
||||
PromptGuardRequest, PromptGuardTask,
|
||||
api::{
|
||||
open_ai::{self, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest},
|
||||
prompt_guard::{PromptGuardRequest, PromptGuardTask},
|
||||
},
|
||||
consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
|
||||
|
|
@ -324,7 +322,7 @@ impl HttpContext for StreamContext {
|
|||
),
|
||||
];
|
||||
|
||||
let mut response_str = to_server_events(chunks);
|
||||
let mut response_str = open_ai::to_server_events(chunks);
|
||||
// append the original response from the model to the stream
|
||||
response_str.push_str(&body_utf8);
|
||||
self.set_http_response_body(0, body_size, response_str.as_bytes());
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ use proxy_wasm::traits::*;
|
|||
use proxy_wasm::types::*;
|
||||
|
||||
mod context;
|
||||
mod embeddings;
|
||||
mod filter_context;
|
||||
mod hallucination;
|
||||
mod http_context;
|
||||
mod metrics;
|
||||
mod stream_context;
|
||||
|
|
|
|||
|
|
@ -1,16 +1,18 @@
|
|||
use crate::embeddings::EmbeddingType;
|
||||
use crate::filter_context::EmbeddingsStore;
|
||||
use crate::hallucination::extract_messages_for_hallucination;
|
||||
use crate::metrics::Metrics;
|
||||
use acap::cos;
|
||||
use common::common_types::open_ai::{
|
||||
use common::api::hallucination::{
|
||||
extract_messages_for_hallucination, HallucinationClassificationRequest,
|
||||
HallucinationClassificationResponse,
|
||||
};
|
||||
use common::api::open_ai::{
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
|
||||
ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
|
||||
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
|
||||
};
|
||||
use common::common_types::{
|
||||
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
||||
PromptGuardResponse, ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
};
|
||||
use common::api::prompt_guard::PromptGuardResponse;
|
||||
use common::api::zero_shot::{ZeroShotClassificationRequest, ZeroShotClassificationResponse};
|
||||
use common::configuration::{Overrides, PromptGuards, PromptTarget, Tracing};
|
||||
use common::consts::{
|
||||
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
|
||||
|
|
@ -30,6 +32,7 @@ use derivative::Derivative;
|
|||
use http::StatusCode;
|
||||
use log::{debug, info, trace, warn};
|
||||
use proxy_wasm::traits::*;
|
||||
use serde_yaml::Value;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
|
@ -893,9 +896,37 @@ impl StreamContext {
|
|||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
let path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
// only add params that are of string, number and bool type
|
||||
let url_params = tool_params
|
||||
.iter()
|
||||
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
||||
.map(|(key, value)| match value {
|
||||
Value::Number(n) => (key.clone(), n.to_string()),
|
||||
Value::String(s) => (key.clone(), s.clone()),
|
||||
Value::Bool(b) => (key.clone(), b.to_string()),
|
||||
Value::Null => todo!(),
|
||||
Value::Sequence(_) => todo!(),
|
||||
Value::Mapping(_) => todo!(),
|
||||
Value::Tagged(_) => todo!(),
|
||||
})
|
||||
.collect::<HashMap<String, String>>();
|
||||
|
||||
let path = match common::path::replace_params_in_path(&path, &url_params) {
|
||||
Ok(path) => path,
|
||||
Err(e) => {
|
||||
return self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("error replacing params in path: {}", e),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let http_method = endpoint.method.unwrap_or_default().to_string();
|
||||
let mut headers = vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()),
|
||||
(":method", "POST"),
|
||||
(":method", &http_method),
|
||||
(":path", &path),
|
||||
(":authority", endpoint.name.as_str()),
|
||||
("content-type", "application/json"),
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
||||
use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
||||
use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
|
||||
use common::api::hallucination::HallucinationClassificationResponse;
|
||||
use common::api::open_ai::{
|
||||
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
|
||||
};
|
||||
use common::api::prompt_guard::PromptGuardResponse;
|
||||
use common::api::zero_shot::ZeroShotClassificationResponse;
|
||||
use common::configuration::Configuration;
|
||||
use common::embeddings::{
|
||||
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
|
||||
Embedding,
|
||||
};
|
||||
use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
|
||||
use http::StatusCode;
|
||||
use proxy_wasm_test_framework::tester::{self, Tester};
|
||||
use proxy_wasm_test_framework::types::{
|
||||
|
|
@ -360,6 +363,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: api_server
|
||||
path: /weather
|
||||
http_method: POST
|
||||
system_prompt: |
|
||||
You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
|
||||
- Use farenheight for temperature
|
||||
|
|
|
|||
1
demos/currency_exchange/README.md
Normal file
1
demos/currency_exchange/README.md
Normal file
|
|
@ -0,0 +1 @@
|
|||
This demo shows how you can use a publicly hosted rest api and interact it using arch gateway.
|
||||
52
demos/currency_exchange/arch_config.yaml
Normal file
52
demos/currency_exchange/arch_config.yaml
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
version: v0.1
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-4o
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
|
||||
|
||||
prompt_targets:
|
||||
- name: currency_exchange
|
||||
description: Get currency exchange rate from USD to other currencies
|
||||
parameters:
|
||||
- name: currency_symbol
|
||||
description: the currency that needs conversion
|
||||
required: true
|
||||
type: str
|
||||
in_path: true
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/latest?base=USD&symbols={currency_symbol}
|
||||
system_prompt: |
|
||||
You are a helpful assistant. Show me the currency symbol you want to convert from USD.
|
||||
|
||||
- name: get_supported_currencies
|
||||
description: Get list of supported currencies for conversion
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/currencies
|
||||
|
||||
endpoints:
|
||||
frankfurther_api:
|
||||
endpoint: api.frankfurter.dev:443
|
||||
protocol: https
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
trace_arch_internal: true
|
||||
21
demos/currency_exchange/docker-compose.yaml
Normal file
21
demos/currency_exchange/docker-compose.yaml
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
services:
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../shared/chatbot_ui
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
# this is only because we are running the sample app in the same docker container environemtn as archgw
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
|
||||
jaeger:
|
||||
build:
|
||||
context: ../shared/jaeger
|
||||
ports:
|
||||
- "16686:16686"
|
||||
- "4317:4317"
|
||||
- "4318:4318"
|
||||
|
|
@ -34,6 +34,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/workforce
|
||||
http_method: POST
|
||||
parameters:
|
||||
- name: staffing_type
|
||||
type: str
|
||||
|
|
@ -52,6 +53,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/slack_message
|
||||
http_method: POST
|
||||
description: sends a slack message on a channel
|
||||
parameters:
|
||||
- name: slack_message
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /policy/qa
|
||||
http_method: POST
|
||||
description: Handle general Q/A related to insurance.
|
||||
default: true
|
||||
|
||||
|
|
@ -37,6 +38,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /policy/coverage
|
||||
http_method: POST
|
||||
parameters:
|
||||
- name: policy_type
|
||||
type: str
|
||||
|
|
@ -48,6 +50,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /policy/initiate
|
||||
http_method: POST
|
||||
description: Start a policy coverage for car, boat, motorcycle or house.
|
||||
parameters:
|
||||
- name: policy_type
|
||||
|
|
@ -64,6 +67,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /policy/claim
|
||||
http_method: POST
|
||||
description: Update the notes on the claim
|
||||
parameters:
|
||||
- name: claim_id
|
||||
|
|
@ -79,6 +83,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /policy/deductible
|
||||
http_method: POST
|
||||
description: Update the deductible amount for a specific policy coverage.
|
||||
parameters:
|
||||
- name: policy_id
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_summary
|
||||
http_method: POST
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
|
|
@ -36,6 +37,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_reboot
|
||||
http_method: POST
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
|
|
|
|||
5
demos/shared/logfire/Dockerfile
Normal file
5
demos/shared/logfire/Dockerfile
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
FROM otel/opentelemetry-collector:latest
|
||||
|
||||
COPY otel-collector-config.yaml /etc/otel-collector-config.yaml
|
||||
|
||||
ENTRYPOINT ["/otelcol", "--config=/etc/otel-collector-config.yaml"]
|
||||
24
demos/shared/logfire/otel-collector-config.yaml
Normal file
24
demos/shared/logfire/otel-collector-config.yaml
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
receivers:
|
||||
otlp:
|
||||
protocols:
|
||||
grpc:
|
||||
endpoint: 0.0.0.0:4317
|
||||
http:
|
||||
endpoint: 0.0.0.0:4318
|
||||
|
||||
exporters:
|
||||
otlphttp:
|
||||
endpoint: "https://logfire-api.pydantic.dev"
|
||||
headers:
|
||||
Authorization: "${LOGFIRE_API_KEY}"
|
||||
|
||||
processors:
|
||||
batch:
|
||||
timeout: 5s
|
||||
|
||||
service:
|
||||
pipelines:
|
||||
traces:
|
||||
receivers: [otlp]
|
||||
processors: [batch]
|
||||
exporters: [otlphttp]
|
||||
|
|
@ -1,27 +1,63 @@
|
|||
# Function calling
|
||||
|
||||
This demo shows how you can use Arch's core function calling capabilites.
|
||||
|
||||
# Starting the demo
|
||||
|
||||
1. Please make sure the [pre-requisites](https://github.com/katanemo/arch/?tab=readme-ov-file#prerequisites) are installed correctly
|
||||
2. Start Arch
|
||||
|
||||
3.
|
||||
```sh
|
||||
3. ```sh
|
||||
sh run_demo.sh
|
||||
```
|
||||
4. Navigate to http://localhost:18080/
|
||||
5. You can type in queries like "how is the weather?"
|
||||
|
||||
# Observability
|
||||
|
||||
Arch gateway publishes stats endpoint at http://localhost:19901/stats. In this demo we are using prometheus to pull stats from arch and we are using grafana to visalize the stats in dashboard. To see grafana dashboard follow instructions below,
|
||||
|
||||
1. Start grafana and prometheus using following command
|
||||
```yaml
|
||||
docker compose --profile monitoring up
|
||||
```
|
||||
1. Navigate to http://localhost:3000/ to open grafana UI (use admin/grafana as credentials)
|
||||
1. From grafana left nav click on dashboards and select "Intelligent Gateway Overview" to view arch gateway stats
|
||||
|
||||
2. Navigate to http://localhost:3000/ to open grafana UI (use admin/grafana as credentials)
|
||||
3. From grafana left nav click on dashboards and select "Intelligent Gateway Overview" to view arch gateway stats
|
||||
|
||||
Here is a sample interaction,
|
||||
<img width="575" alt="image" src="https://github.com/user-attachments/assets/e0929490-3eb2-4130-ae87-a732aea4d059">
|
||||
|
||||
## Tracing
|
||||
|
||||
To see a tracing dashboard follow instructions below,
|
||||
|
||||
1. For Jaeger, you can either use the default run_demo.sh script or run the following command:
|
||||
|
||||
```sh
|
||||
sh run_demo.sh jaeger
|
||||
```
|
||||
|
||||
2. For Logfire, first make sure to add a LOGFIRE_API_KEY to the .env file. You can either use the default run_demo.sh script or run the following command:
|
||||
|
||||
```sh
|
||||
sh run_demo.sh logfire
|
||||
```
|
||||
|
||||
3. For Signoz, you can either use the default run_demo.sh script or run the following command:
|
||||
|
||||
```sh
|
||||
sh run_demo.sh signoz
|
||||
```
|
||||
|
||||
If using Jaeger, navigate to http://localhost:16686/ to open Jaeger UI
|
||||
|
||||
If using Signoz, navigate to http://localhost:3301/ to open Signoz UI
|
||||
|
||||
If using Logfire, navigate to your logfire dashboard that you got the write key from to view the dashboard
|
||||
|
||||
### Stopping Demo
|
||||
|
||||
1. To end the demo, run the following command:
|
||||
```sh
|
||||
sh run_demo.sh down
|
||||
```
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: weather_forecast_service
|
||||
path: /weather
|
||||
http_method: POST
|
||||
|
||||
- name: default_target
|
||||
default: true
|
||||
|
|
@ -67,6 +68,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: weather_forecast_service
|
||||
path: /default_target
|
||||
http_method: POST
|
||||
system_prompt: |
|
||||
You are a helpful assistant! Summarize the user's request and provide a helpful response.
|
||||
# if it is set to false arch will send response that it received from this prompt target to the user
|
||||
|
|
|
|||
41
demos/weather_forecast/docker-compose-jaeger.yaml
Normal file
41
demos/weather_forecast/docker-compose-jaeger.yaml
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
services:
|
||||
weather_forecast_service:
|
||||
build:
|
||||
context: ./
|
||||
environment:
|
||||
- OLTP_HOST=http://jaeger:4317
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
ports:
|
||||
- "18083:80"
|
||||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../shared/chatbot_ui
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
# this is only because we are running the sample app in the same docker container environemtn as archgw
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
|
||||
jaeger:
|
||||
build:
|
||||
context: ../shared/jaeger
|
||||
ports:
|
||||
- "16686:16686"
|
||||
- "4317:4317"
|
||||
- "4318:4318"
|
||||
|
||||
prometheus:
|
||||
build:
|
||||
context: ../shared/prometheus
|
||||
|
||||
grafana:
|
||||
build:
|
||||
context: ../shared/grafana
|
||||
ports:
|
||||
- "3000:3000"
|
||||
46
demos/weather_forecast/docker-compose-logfire.yaml
Normal file
46
demos/weather_forecast/docker-compose-logfire.yaml
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
services:
|
||||
weather_forecast_service:
|
||||
build:
|
||||
context: ./
|
||||
environment:
|
||||
- OLTP_HOST=http://otel-collector:4317
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
ports:
|
||||
- "18083:80"
|
||||
|
||||
chatbot_ui:
|
||||
build:
|
||||
context: ../shared/chatbot_ui
|
||||
ports:
|
||||
- "18080:8080"
|
||||
environment:
|
||||
# this is only because we are running the sample app in the same docker container environment as archgw
|
||||
- CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:10000/v1
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
volumes:
|
||||
- ./arch_config.yaml:/app/arch_config.yaml
|
||||
|
||||
otel-collector:
|
||||
build:
|
||||
context: ../shared/logfire/
|
||||
ports:
|
||||
- "4317:4317"
|
||||
- "4318:4318"
|
||||
volumes:
|
||||
- ../shared/logfire/otel-collector-config.yaml:/etc/otel-collector-config.yaml
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- LOGFIRE_API_KEY
|
||||
|
||||
prometheus:
|
||||
build:
|
||||
context: ../shared/prometheus
|
||||
|
||||
grafana:
|
||||
build:
|
||||
context: ../shared/grafana
|
||||
ports:
|
||||
- "3000:3000"
|
||||
|
|
@ -4,7 +4,7 @@ include:
|
|||
services:
|
||||
weather_forecast_service:
|
||||
build:
|
||||
context: ../shared/weather_forecast_service
|
||||
context: .
|
||||
environment:
|
||||
- OLTP_HOST=http://otel-collector:4317
|
||||
extra_hosts:
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
services:
|
||||
weather_forecast_service:
|
||||
build:
|
||||
context: ../shared/weather_forecast_service
|
||||
context: ./
|
||||
environment:
|
||||
- OLTP_HOST=http://jaeger:4317
|
||||
extra_hosts:
|
||||
|
|
|
|||
|
|
@ -1,47 +1,95 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Function to load environment variables from the .env file
|
||||
load_env() {
|
||||
if [ -f ".env" ]; then
|
||||
export $(grep -v '^#' .env | xargs)
|
||||
fi
|
||||
}
|
||||
|
||||
# Function to determine the docker-compose file based on the argument
|
||||
get_compose_file() {
|
||||
case "$1" in
|
||||
jaeger)
|
||||
echo "docker-compose-jaeger.yaml"
|
||||
;;
|
||||
logfire)
|
||||
echo "docker-compose-logfire.yaml"
|
||||
;;
|
||||
signoz)
|
||||
echo "docker-compose-signoz.yaml"
|
||||
;;
|
||||
*)
|
||||
echo "docker-compose.yaml"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Function to start the demo
|
||||
start_demo() {
|
||||
# Step 1: Check if .env file exists
|
||||
# Step 1: Determine the docker-compose file
|
||||
COMPOSE_FILE=$(get_compose_file "$1" 2>/dev/null)
|
||||
|
||||
# Step 2: Check if .env file exists
|
||||
if [ -f ".env" ]; then
|
||||
echo ".env file already exists. Skipping creation."
|
||||
else
|
||||
# Step 2: Create `.env` file and set OpenAI key
|
||||
# Step 3: Check for required environment variables
|
||||
if [ -z "$OPENAI_API_KEY" ]; then
|
||||
echo "Error: OPENAI_API_KEY environment variable is not set for the demo."
|
||||
exit 1
|
||||
fi
|
||||
if [ "$1" == "logfire" ] && [ -z "$LOGFIRE_API_KEY" ]; then
|
||||
echo "Error: LOGFIRE_API_KEY environment variable is required for Logfire."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Create .env file
|
||||
echo "Creating .env file..."
|
||||
echo "OPENAI_API_KEY=$OPENAI_API_KEY" > .env
|
||||
echo ".env file created with OPENAI_API_KEY."
|
||||
if [ "$1" == "logfire" ]; then
|
||||
echo "LOGFIRE_API_KEY=$LOGFIRE_API_KEY" >> .env
|
||||
fi
|
||||
echo ".env file created with required API keys."
|
||||
fi
|
||||
|
||||
# Step 3: Start Arch
|
||||
load_env
|
||||
|
||||
if [ "$1" == "logfire" ] && [ -z "$LOGFIRE_API_KEY" ]; then
|
||||
echo "Error: LOGFIRE_API_KEY environment variable is required for Logfire."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Step 4: Start Arch
|
||||
echo "Starting Arch with arch_config.yaml..."
|
||||
archgw up arch_config.yaml
|
||||
|
||||
# Step 4: Start Network Agent
|
||||
echo "Starting Network Agent using Docker Compose..."
|
||||
docker compose up -d # Run in detached mode
|
||||
# Step 5: Start Network Agent with the chosen Docker Compose file
|
||||
echo "Starting Network Agent with $COMPOSE_FILE..."
|
||||
docker compose -f "$COMPOSE_FILE" up -d # Run in detached mode
|
||||
}
|
||||
|
||||
# Function to stop the demo
|
||||
stop_demo() {
|
||||
# Step 1: Stop Docker Compose services
|
||||
echo "Stopping Network Agent using Docker Compose..."
|
||||
docker compose down
|
||||
echo "Stopping all Docker Compose services..."
|
||||
|
||||
# Step 2: Stop Arch
|
||||
# Stop all services by iterating through all configurations
|
||||
for compose_file in ./docker-compose*.yaml; do
|
||||
echo "Stopping services in $compose_file..."
|
||||
docker compose -f "$compose_file" down
|
||||
done
|
||||
|
||||
# Stop Arch
|
||||
echo "Stopping Arch..."
|
||||
archgw down
|
||||
}
|
||||
|
||||
# Main script logic
|
||||
if [ "$1" == "down" ]; then
|
||||
# Call stop_demo with the second argument as the demo to stop
|
||||
stop_demo
|
||||
else
|
||||
# Default action is to bring the demo up
|
||||
start_demo
|
||||
# Use the argument (jaeger, logfire, signoz) to determine the compose file
|
||||
start_demo "$1"
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -1,29 +0,0 @@
|
|||
# Function calling
|
||||
This demo shows how you can use Arch's core function calling capabilites.
|
||||
|
||||
# Starting the demo
|
||||
1. Please make sure the [pre-requisites](https://github.com/katanemo/arch/?tab=readme-ov-file#prerequisites) are installed correctly
|
||||
2. Start Arch
|
||||
|
||||
3.
|
||||
```sh
|
||||
sh run_demo.sh
|
||||
```
|
||||
4. Navigate to http://localhost:18080/
|
||||
5. You can type in queries like "how is the weather?"
|
||||
|
||||
# Observability
|
||||
Arch gateway publishes stats endpoint at http://localhost:19901/stats. In this demo we are using prometheus to pull stats from arch and we are using grafana to visalize the stats in dashboard. To see grafana dashboard follow instructions below,
|
||||
|
||||
1. Start grafana and prometheus using following command
|
||||
```yaml
|
||||
docker compose --profile monitoring up
|
||||
```
|
||||
1. Navigate to http://localhost:3000/ to open grafana UI (use admin/grafana as credentials)
|
||||
1. From grafana left nav click on dashboards and select "Intelligent Gateway Overview" to view arch gateway stats
|
||||
|
||||
|
||||
Here is a sample interaction,
|
||||
<img width="575" alt="image" src="https://github.com/user-attachments/assets/e0929490-3eb2-4130-ae87-a732aea4d059">
|
||||
|
||||
1. Signoz UI: http://localhost:3301
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
version: "0.1-beta"
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
endpoints:
|
||||
weather_forecast_service:
|
||||
endpoint: host.docker.internal:18083
|
||||
connect_timeout: 0.005s
|
||||
|
||||
overrides:
|
||||
# confidence threshold for prompt target intent matching
|
||||
prompt_target_intent_matching_threshold: 0.6
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-4o-mini
|
||||
default: true
|
||||
|
||||
- name: gpt-3.5-turbo-0125
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-3.5-turbo-0125
|
||||
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-4o
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_targets:
|
||||
- name: weather_forecast
|
||||
description: Check weather information for a given city.
|
||||
parameters:
|
||||
- name: city
|
||||
description: the name of the city
|
||||
required: true
|
||||
type: str
|
||||
- name: days
|
||||
description: the number of days
|
||||
type: int
|
||||
required: true
|
||||
- name: units
|
||||
description: the temperature unit, e.g., Celsius and Fahrenheit
|
||||
type: str
|
||||
default: Fahrenheit
|
||||
endpoint:
|
||||
name: weather_forecast_service
|
||||
path: /weather
|
||||
|
||||
- name: default_target
|
||||
default: true
|
||||
description: This is the default target for all unmatched prompts.
|
||||
endpoint:
|
||||
name: weather_forecast_service
|
||||
path: /default_target
|
||||
system_prompt: |
|
||||
You are a helpful assistant! Summarize the user's request and provide a helpful response.
|
||||
# if it is set to false arch will send response that it received from this prompt target to the user
|
||||
# if true arch will forward the response to the default LLM
|
||||
auto_llm_dispatch_on_response: false
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
# trace_arch: true
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
version: v0.1
|
||||
listener:
|
||||
address: 127.0.0.1
|
||||
port: 8080 #If you configure port 443, you'll need to update the listener with tls_certificates
|
||||
message_format: huggingface
|
||||
|
||||
# Centralized way to manage LLMs, manage keys, retry logic, failover and limits in a central way
|
||||
llm_providers:
|
||||
- name: OpenAI
|
||||
provider: openai
|
||||
access_key: $OPENAI_API_KEY
|
||||
model: gpt-3.5-turbo
|
||||
default: true
|
||||
|
||||
# default system prompt used by all prompt targets
|
||||
system_prompt: |
|
||||
You are a network assistant that helps operators with a better understanding of network traffic flow and perform actions on networking operations. No advice on manufacturers or purchasing decisions.
|
||||
|
||||
prompt_targets:
|
||||
- name: device_summary
|
||||
description: Retrieve network statistics for specific devices within a time range
|
||||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/device_summary
|
||||
parameters:
|
||||
- name: device_ids
|
||||
type: list
|
||||
description: A list of device identifiers (IDs) to retrieve statistics for.
|
||||
required: true # device_ids are required to get device statistics
|
||||
- name: days
|
||||
type: int
|
||||
description: The number of days for which to gather device statistics.
|
||||
default: "7"
|
||||
|
||||
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
|
||||
endpoints:
|
||||
app_server:
|
||||
# value could be ip address or a hostname with port
|
||||
# this could also be a list of endpoints for load balancing
|
||||
# for example endpoint: [ ip1:port, ip2:port ]
|
||||
endpoint: host.docker.internal:18083
|
||||
# max time to wait for a connection to be established
|
||||
connect_timeout: 0.005s
|
||||
|
|
@ -7,74 +7,255 @@ Follow this guide to learn how to quickly set up Arch and integrate it into your
|
|||
|
||||
|
||||
Prerequisites
|
||||
----------------------------
|
||||
-------------
|
||||
|
||||
Before you begin, ensure you have the following:
|
||||
|
||||
.. vale Vale.Spelling = NO
|
||||
1. `Docker System <https://docs.docker.com/get-started/get-docker/>`_ (v24)
|
||||
2. `Docker compose <https://docs.docker.com/compose/install/>`_ (v2.29)
|
||||
3. `Python <https://www.python.org/downloads/>`_ (v3.12)
|
||||
|
||||
- ``Docker`` & ``Python`` installed on your system
|
||||
- ``API Keys`` for LLM providers (if using external LLMs)
|
||||
|
||||
The fastest way to get started using Arch is to use `katanemo/archgw <https://hub.docker.com/r/katanemo/archgw>`_ pre-built binaries.
|
||||
You can also build it from source.
|
||||
|
||||
|
||||
Step 1: Install Arch
|
||||
----------------------------
|
||||
Arch's CLI allows you to manage and interact with the Arch gateway efficiently. To install the CLI, simply
|
||||
run the following command:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install archgw
|
||||
|
||||
This will install the archgw command-line tool globally on your system.
|
||||
Arch's CLI allows you to manage and interact with the Arch gateway efficiently. To install the CLI, simply run the following command:
|
||||
|
||||
.. tip::
|
||||
We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch.
|
||||
This ensures that `archgw` and its dependencies do not interfere with other packages on your system.
|
||||
|
||||
To create and activate a virtual environment, you can run the following commands:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m venv venv
|
||||
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
|
||||
$ pip install archgw
|
||||
|
||||
|
||||
Step 2: Config Arch
|
||||
-------------------
|
||||
|
||||
Arch operates based on a configuration file where you can define LLM providers, prompt targets, and guardrails, etc.
|
||||
Below is an example configuration to get you started, including:
|
||||
|
||||
.. vale Vale.Spelling = NO
|
||||
|
||||
- ``endpoints``: Specifies where Arch listens for incoming prompts.
|
||||
- ``system_prompts``: Defines predefined prompts to set the context for interactions.
|
||||
- ``llm_providers``: Lists the LLM providers Arch can route prompts to.
|
||||
- ``prompt_guards``: Sets up rules to detect and reject undesirable prompts.
|
||||
- ``prompt_targets``: Defines endpoints that handle specific types of prompts.
|
||||
- ``error_target``: Specifies where to route errors for handling.
|
||||
|
||||
.. literalinclude:: includes/quickstart.yaml
|
||||
:language: yaml
|
||||
|
||||
|
||||
Step 3: Start Arch Gateway
|
||||
--------------------------
|
||||
We recommend that developers create a new Python virtual environment to isolate dependencies before installing Arch. This ensures that ``archgw`` and its dependencies do not interfere with other packages on your system.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ archgw up [path_to_config]
|
||||
$ python -m venv venv
|
||||
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
|
||||
$ pip install archgw==0.1.6
|
||||
|
||||
For detailed usage please refer to the `Support <https://github.com/katanemo/arch/blob/main/arch/tools/README.md#setup-instructionsuser-archgw-cli>`
|
||||
|
||||
Build AI Agent with Arch Gateway
|
||||
--------------------------------
|
||||
|
||||
In the following quickstart, we will show you how easy it is to build an AI agent with the Arch gateway. We will build a currency exchange agent using the following simple steps. For this demo, we will use `https://api.frankfurter.dev/` to fetch the latest prices for currencies and assume USD as the base currency.
|
||||
|
||||
Step 1. Create arch config file
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Create ``arch_config.yaml`` file with the following content:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
version: v0.1
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-4o
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
||||
prompt_guards:
|
||||
input_guards:
|
||||
jailbreak:
|
||||
on_exception:
|
||||
message: Looks like you're curious about my abilities, but I can only provide assistance for currency exchange.
|
||||
|
||||
prompt_targets:
|
||||
- name: currency_exchange
|
||||
description: Get currency exchange rate from USD to other currencies
|
||||
parameters:
|
||||
- name: currency_symbol
|
||||
description: the currency that needs conversion
|
||||
required: true
|
||||
type: str
|
||||
in_path: true
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/latest?base=USD&symbols={currency_symbol}
|
||||
system_prompt: |
|
||||
You are a helpful assistant. Show me the currency symbol you want to convert from USD.
|
||||
|
||||
- name: get_supported_currencies
|
||||
description: Get list of supported currencies for conversion
|
||||
endpoint:
|
||||
name: frankfurther_api
|
||||
path: /v1/currencies
|
||||
|
||||
endpoints:
|
||||
frankfurther_api:
|
||||
endpoint: api.frankfurter.dev:443
|
||||
protocol: https
|
||||
|
||||
Step 2. Start arch gateway with currency conversion config
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. code-block:: sh
|
||||
|
||||
$ archgw up arch_config.yaml
|
||||
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.1.5
|
||||
...
|
||||
2024-12-05 16:56:28,485 - cli.utils - INFO - Schema validation successful!
|
||||
2024-12-05 16:56:28,485 - cli.main - INFO - Starting arch model server and arch gateway
|
||||
...
|
||||
2024-12-05 16:56:51,647 - cli.core - INFO - Container is healthy!
|
||||
|
||||
Once the gateway is up, you can start interacting with it at port 10000 using the OpenAI chat completion API.
|
||||
|
||||
Some sample queries you can ask include: ``what is currency rate for gbp?`` or ``show me list of currencies for conversion``.
|
||||
|
||||
Step 3. Interacting with gateway using curl command
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Here is a sample curl command you can use to interact:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ curl --header 'Content-Type: application/json' \
|
||||
--data '{"messages": [{"role": "user","content": "what is exchange rate for gbp"}]}' \
|
||||
http://localhost:10000/v1/chat/completions | jq ".choices[0].message.content"
|
||||
|
||||
"As of the date provided in your context, December 5, 2024, the exchange rate for GBP (British Pound) from USD (United States Dollar) is 0.78558. This means that 1 USD is equivalent to 0.78558 GBP."
|
||||
|
||||
And to get the list of supported currencies:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ curl --header 'Content-Type: application/json' \
|
||||
--data '{"messages": [{"role": "user","content": "show me list of currencies that are supported for conversion"}]}' \
|
||||
http://localhost:10000/v1/chat/completions | jq ".choices[0].message.content"
|
||||
|
||||
"Here is a list of the currencies that are supported for conversion from USD, along with their symbols:\n\n1. AUD - Australian Dollar\n2. BGN - Bulgarian Lev\n3. BRL - Brazilian Real\n4. CAD - Canadian Dollar\n5. CHF - Swiss Franc\n6. CNY - Chinese Renminbi Yuan\n7. CZK - Czech Koruna\n8. DKK - Danish Krone\n9. EUR - Euro\n10. GBP - British Pound\n11. HKD - Hong Kong Dollar\n12. HUF - Hungarian Forint\n13. IDR - Indonesian Rupiah\n14. ILS - Israeli New Sheqel\n15. INR - Indian Rupee\n16. ISK - Icelandic Króna\n17. JPY - Japanese Yen\n18. KRW - South Korean Won\n19. MXN - Mexican Peso\n20. MYR - Malaysian Ringgit\n21. NOK - Norwegian Krone\n22. NZD - New Zealand Dollar\n23. PHP - Philippine Peso\n24. PLN - Polish Złoty\n25. RON - Romanian Leu\n26. SEK - Swedish Krona\n27. SGD - Singapore Dollar\n28. THB - Thai Baht\n29. TRY - Turkish Lira\n30. USD - United States Dollar\n31. ZAR - South African Rand\n\nIf you want to convert USD to any of these currencies, you can select the one you are interested in."
|
||||
|
||||
|
||||
Use Arch Gateway as LLM Router
|
||||
------------------------------
|
||||
|
||||
Step 1. Create arch config file
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Arch operates based on a configuration file where you can define LLM providers, prompt targets, guardrails, etc. Below is an example configuration that defines OpenAI and Mistral LLM providers.
|
||||
|
||||
Create ``arch_config.yaml`` file with the following content:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
version: v0.1
|
||||
|
||||
listener:
|
||||
address: 0.0.0.0
|
||||
port: 10000
|
||||
message_format: huggingface
|
||||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-4o
|
||||
default: true
|
||||
|
||||
- name: ministral-3b
|
||||
access_key: $MISTRAL_API_KEY
|
||||
provider: mistral
|
||||
model: ministral-3b-latest
|
||||
|
||||
Step 2. Start arch gateway
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Once the config file is created, ensure that you have environment variables set up for ``MISTRAL_API_KEY`` and ``OPENAI_API_KEY`` (or these are defined in a ``.env`` file).
|
||||
|
||||
Start the Arch gateway:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ archgw up arch_config.yaml
|
||||
2024-12-05 11:24:51,288 - cli.main - INFO - Starting archgw cli version: 0.1.5
|
||||
2024-12-05 11:24:51,825 - cli.utils - INFO - Schema validation successful!
|
||||
2024-12-05 11:24:51,825 - cli.main - INFO - Starting arch model server and arch gateway
|
||||
...
|
||||
2024-12-05 11:25:16,131 - cli.core - INFO - Container is healthy!
|
||||
|
||||
Step 3: Interact with LLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Step 3.1: Using OpenAI Python client
|
||||
++++++++++++++++++++++++++++++++++++
|
||||
|
||||
Make outbound calls via the Arch gateway:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
# Use the OpenAI client as usual
|
||||
client = OpenAI(
|
||||
# No need to set a specific openai.api_key since it's configured in Arch's gateway
|
||||
api_key='--',
|
||||
# Set the OpenAI API base URL to the Arch gateway endpoint
|
||||
base_url="http://127.0.0.1:12000/v1"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
# we select model from arch_config file
|
||||
model="--",
|
||||
messages=[{"role": "user", "content": "What is the capital of France?"}],
|
||||
)
|
||||
|
||||
print("OpenAI Response:", response.choices[0].message.content)
|
||||
|
||||
Step 3.2: Using curl command
|
||||
++++++++++++++++++++++++++++
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ curl --header 'Content-Type: application/json' \
|
||||
--data '{"messages": [{"role": "user","content": "What is the capital of France?"}]}' \
|
||||
http://localhost:12000/v1/chat/completions
|
||||
|
||||
{
|
||||
...
|
||||
"model": "gpt-4o-2024-08-06",
|
||||
"choices": [
|
||||
{
|
||||
...
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "The capital of France is Paris.",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
You can override model selection using the ``x-arch-llm-provider-hint`` header. For example, to use Mistral, use the following curl command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ curl --header 'Content-Type: application/json' \
|
||||
--header 'x-arch-llm-provider-hint: ministral-3b' \
|
||||
--data '{"messages": [{"role": "user","content": "What is the capital of France?"}]}' \
|
||||
http://localhost:12000/v1/chat/completions
|
||||
|
||||
{
|
||||
...
|
||||
"model": "ministral-3b-latest",
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "The capital of France is Paris. It is the most populous city in France and is known for its iconic landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. Paris is also a major global center for art, fashion, gastronomy, and culture.",
|
||||
},
|
||||
...
|
||||
}
|
||||
],
|
||||
...
|
||||
}
|
||||
|
||||
|
||||
Next Steps
|
||||
-------------------
|
||||
==========
|
||||
|
||||
Congratulations! You've successfully set up Arch and made your first prompt-based request. To further enhance your GenAI applications, explore the following resources:
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/summary
|
||||
http_method: POST
|
||||
# Arch uses the default LLM and treats the response from the endpoint as the prompt to send to the LLM
|
||||
auto_llm_dispatch_on_response: true
|
||||
# override system prompt for this prompt target
|
||||
|
|
@ -41,6 +42,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/action
|
||||
http_method: POST
|
||||
parameters:
|
||||
- name: device_id
|
||||
type: str
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ prompt_targets:
|
|||
endpoint:
|
||||
name: app_server
|
||||
path: /agent/summary
|
||||
http_method: POST
|
||||
# Arch uses the default LLM and treats the response from the endpoint as the prompt to send to the LLM
|
||||
auto_llm_dispatch_on_response: true
|
||||
# override system prompt for this prompt target
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ cd -
|
|||
log building and installing archgw cli
|
||||
log ==================================
|
||||
cd ../arch/tools
|
||||
sh build_cli.sh
|
||||
poetry install
|
||||
cd -
|
||||
|
||||
log building docker image for arch gateway
|
||||
|
|
|
|||
|
|
@ -220,6 +220,12 @@ async def hallucination(req: HallucinationRequest, res: Response):
|
|||
if "messages" in req.parameters:
|
||||
req.parameters.pop("messages")
|
||||
|
||||
if not req.parameters or len(req.parameters) == 0:
|
||||
return {
|
||||
"params_scores": {},
|
||||
"model": req.model,
|
||||
}
|
||||
|
||||
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
|
||||
|
||||
predictions = classifier(
|
||||
|
|
|
|||
1444
model_server/poetry.lock
generated
1444
model_server/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "archgw_modelserver"
|
||||
version = "0.1.5"
|
||||
version = "0.1.6"
|
||||
description = "A model server for serving models"
|
||||
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
|
||||
license = "Apache 2.0"
|
||||
|
|
@ -29,7 +29,7 @@ openai = "1.50.2"
|
|||
tf-keras = "*"
|
||||
onnx = "1.17.0"
|
||||
onnxruntime = "1.19.2"
|
||||
httpx = "*"
|
||||
httpx = "0.27.2" # https://community.openai.com/t/typeerror-asyncclient-init-got-an-unexpected-keyword-argument-proxies/1040287
|
||||
pytest-asyncio = "*"
|
||||
pytest = "*"
|
||||
opentelemetry-api = "^1.28.0"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue