Merge branch 'main' into musa/www

This commit is contained in:
Musa 2025-12-18 14:53:56 -08:00
commit a6f9ca3594
189 changed files with 21252 additions and 14516 deletions

View file

@ -30,7 +30,7 @@ jobs:
- name: build arch docker image
run: |
cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.18 -t katanemo/archgw:latest
cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22 -t katanemo/archgw:latest
- name: start archgw
env:

View file

@ -1,46 +0,0 @@
name: e2e model server tests
on:
push:
branches:
- main
pull_request:
jobs:
e2e_model_server_tests:
runs-on: ubuntu-latest-m
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
defaults:
run:
working-directory: ./tests/modelserver
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: "pip" # auto-caches based on requirements files
- name: install poetry
run: |
export POETRY_VERSION=2.2.1
curl -sSL https://install.python-poetry.org | python3 -
export PATH="$HOME/.local/bin:$PATH"
- name: install model server and start it
run: |
cd ../../model_server/ && poetry install && poetry run archgw_modelserver start
- name: install test dependencies
run: |
poetry install
- name: run tests
run: |
poetry run pytest

View file

@ -24,7 +24,7 @@ jobs:
- name: build arch docker image
run: |
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.18
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22
- name: install poetry
run: |
@ -40,11 +40,10 @@ jobs:
curl --location --remote-name https://github.com/Orange-OpenSource/hurl/releases/download/4.0.0/hurl_4.0.0_amd64.deb
sudo dpkg -i hurl_4.0.0_amd64.deb
- name: install model server, arch gateway and test dependencies
- name: install arch gateway and test dependencies
run: |
source venv/bin/activate
cd model_server/ && echo "installing model server" && poetry install
cd ../arch/tools && echo "installing archgw cli" && poetry install
cd arch/tools && echo "installing archgw cli" && poetry install
cd ../../demos/shared/test_runner && echo "installing test dependencies" && poetry install
- name: run demo tests

View file

@ -24,7 +24,7 @@ jobs:
- name: build arch docker image
run: |
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.18
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22
- name: install poetry
run: |
@ -40,11 +40,10 @@ jobs:
curl --location --remote-name https://github.com/Orange-OpenSource/hurl/releases/download/4.0.0/hurl_4.0.0_amd64.deb
sudo dpkg -i hurl_4.0.0_amd64.deb
- name: install model server, arch gateway and test dependencies
- name: install arch gateway and test dependencies
run: |
source venv/bin/activate
cd model_server/ && echo "installing model server" && poetry install
cd ../arch/tools && echo "installing archgw cli" && poetry install
cd arch/tools && echo "installing archgw cli" && poetry install
cd ../../demos/shared/test_runner && echo "installing test dependencies" && poetry install
- name: run demo tests

View file

@ -14,6 +14,32 @@ jobs:
- name: Checkout code
uses: actions/checkout@v3
# --- Disk inspection & cleanup section (added to free space on GitHub runner) ---
- name: Check disk usage before cleanup
run: |
echo "=== Disk usage before cleanup ==="
df -h
echo "=== Repo size ==="
du -sh .
- name: Free disk space on runner
run: |
echo "=== Cleaning preinstalled SDKs and toolchains to free space ==="
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
# If you still hit disk issues, uncomment this to free more space.
# It just removes cached tool versions; setup-python will re-download what it needs.
# sudo rm -rf /opt/hostedtoolcache || true
echo "=== Docker cleanup (before our builds/compose) ==="
docker system prune -af || true
docker volume prune -f || true
echo "=== Disk usage after cleanup ==="
df -h
# --- End disk cleanup section ---
- name: Set up Python
uses: actions/setup-python@v4
with:
@ -33,6 +59,7 @@ jobs:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }}
GROK_API_KEY : ${{ secrets.GROK_API_KEY }}
run: |
python -mvenv venv
source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh

View file

@ -1,45 +0,0 @@
name: model server tests
on:
push:
branches:
- main # Run tests on pushes to the main branch
pull_request:
branches:
- main # Run tests on pull requests to the main branch
jobs:
test:
runs-on: ubuntu-latest
steps:
# Step 1: Check out the code from your repository
- name: Checkout code
uses: actions/checkout@v3
# Step 2: Set up Python (specify the version)
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.12"
# Step 3: Install Poetry
- name: Install Poetry
run: |
export POETRY_VERSION=2.2.1
curl -sSL https://install.python-poetry.org | python3 -
export PATH="$HOME/.local/bin:$PATH"
# Step 4: Install dependencies using Poetry
- name: Install dependencies
run: |
cd model_server
poetry install
# Step 5: Set PYTHONPATH and run tests
- name: Run model server tests with pytest
env:
PYTHONPATH: model_server # Ensure the app's path is available
run: |
cd model_server
poetry run pytest

View file

@ -29,3 +29,6 @@ jobs:
- name: Run unit tests
run: cargo test --lib
- name: Run trace integration tests
run: cargo test -p common --features trace-collection traces::tests::trace_integration_test

View file

@ -24,7 +24,7 @@ jobs:
- name: build arch docker image
run: |
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.18
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22
- name: validate arch config
run: |

5
.gitignore vendored
View file

@ -111,11 +111,6 @@ venv.bak/
arch/tools/config
arch/tools/build
# Archgw - model_server
model_server/venv_model_server
model_server/build
model_server/dist
# Archgw - Docs
docs/build/

View file

@ -24,7 +24,7 @@ _Arch is a models-native (edge and service) proxy server for agents._<br><br>
</div>
# About The Latest Release:
[0.3.18] [Preference-aware multi LLM routing for Claude Code 2.0](demos/use_cases/claude_code_router/README.md) <br><img src="docs/source/_static/img/claude_code_router.png" alt="high-level network architecture for ArchGW" width="50%">
[0.3.20] [Preference-aware multi LLM routing for Claude Code 2.0](demos/use_cases/claude_code_router/README.md) <br><img src="docs/source/_static/img/claude_code_router.png" alt="high-level network architecture for ArchGW" width="50%">
# Overview
@ -87,7 +87,7 @@ Arch's CLI allows you to manage and interact with the Arch gateway efficiently.
```console
$ python3.12 -m venv venv
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
$ pip install archgw==0.3.18
$ pip install archgw==0.3.22
```
### Use Arch as a LLM Router
@ -276,7 +276,7 @@ endpoints:
```sh
$ archgw up arch_config.yaml
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.3.18
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.3.22
2024-12-05 16:56:28,485 - cli.utils - INFO - Schema validation successful!
2024-12-05 16:56:28,485 - cli.main - INFO - Starting arch model server and arch gateway
2024-12-05 16:56:51,647 - cli.core - INFO - Container is healthy!

View file

@ -14,6 +14,38 @@ properties:
type: array
items:
type: object
properties:
id:
type: string
url:
type: string
additionalProperties: false
required:
- id
- url
filters:
type: array
items:
type: object
properties:
id:
type: string
url:
type: string
type:
type: string
enum:
- mcp
transport:
type: string
enum:
- streamable-http
tool:
type: string
additionalProperties: false
required:
- id
- url
listeners:
oneOf:
- type: array
@ -331,6 +363,31 @@ properties:
model:
type: string
additionalProperties: false
state_storage:
type: object
properties:
type:
type: string
enum:
- memory
- postgres
connection_string:
type: string
description: Required when type is postgres. Supports environment variable substitution using $VAR or ${VAR} syntax.
additionalProperties: false
required:
- type
# Note: connection_string is conditionally required based on type
# If type is 'postgres', connection_string must be provided
# If type is 'memory', connection_string is not needed
allOf:
- if:
properties:
type:
const: postgres
then:
required:
- connection_string
prompt_guards:
type: object
properties:

View file

@ -22,4 +22,3 @@ services:
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
- MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}
- OTEL_TRACING_HTTP_ENDPOINT=http://host.docker.internal:4318/v1/traces
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-51000}

View file

@ -51,11 +51,11 @@ static_resources:
envoy_grpc:
cluster_name: opentelemetry_collector
timeout: 0.250s
service_name: archgw(inbound)
service_name: plano(inbound)
random_sampling:
value: {{ arch_tracing.random_sampling }}
{% endif %}
stat_prefix: ingress_traffic
stat_prefix: plano(inbound)
codec_type: AUTO
scheme_header_transformation:
scheme_to_overwrite: https
@ -95,21 +95,6 @@ static_resources:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
generate_request_id: true
tracing:
provider:
name: envoy.tracers.opentelemetry
typed_config:
"@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
grpc_service:
envoy_grpc:
cluster_name: opentelemetry_collector
timeout: 0.250s
service_name: ingress_traffic
random_sampling:
value: {{ arch_tracing.random_sampling }}
{% endif %}
stat_prefix: ingress_traffic
codec_type: AUTO
scheme_header_transformation:
@ -221,7 +206,7 @@ static_resources:
- name: outbound_api_traffic
address:
socket_address:
address: 0.0.0.0
address: 127.0.0.1
port_value: 11000
traffic_direction: OUTBOUND
filter_chains:
@ -229,21 +214,21 @@ static_resources:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
generate_request_id: true
tracing:
provider:
name: envoy.tracers.opentelemetry
typed_config:
"@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
grpc_service:
envoy_grpc:
cluster_name: opentelemetry_collector
timeout: 0.250s
service_name: outbound_api_traffic
random_sampling:
value: {{ arch_tracing.random_sampling }}
{% endif %}
# {% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
# generate_request_id: true
# tracing:
# provider:
# name: envoy.tracers.opentelemetry
# typed_config:
# "@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
# grpc_service:
# envoy_grpc:
# cluster_name: opentelemetry_collector
# timeout: 0.250s
# service_name: tools
# random_sampling:
# value: {{ arch_tracing.random_sampling }}
# {% endif %}
stat_prefix: outbound_api_traffic
codec_type: AUTO
scheme_header_transformation:
@ -262,19 +247,16 @@ static_resources:
domains:
- "*"
routes:
{% for internal_cluster in ["arch_fc", "model_server"] %}
- match:
prefix: "/"
headers:
- name: "x-arch-upstream"
string_match:
exact: {{ internal_cluster }}
exact: bright_staff
route:
auto_host_rewrite: true
cluster: {{ internal_cluster }}
cluster: bright_staff
timeout: 300s
{% endfor %}
{% for cluster_name, cluster in arch_clusters.items() %}
- match:
prefix: "/"
@ -317,7 +299,7 @@ static_resources:
envoy_grpc:
cluster_name: opentelemetry_collector
timeout: 0.250s
service_name: arch_gateway
service_name: plano(inbound)
random_sampling:
value: {{ arch_tracing.random_sampling }}
{% endif %}
@ -416,7 +398,7 @@ static_resources:
envoy_grpc:
cluster_name: opentelemetry_collector
timeout: 0.250s
service_name: archgw(outbound)
service_name: plano(outbound)
random_sampling:
value: {{ arch_tracing.random_sampling }}
{% endif %}
@ -487,6 +469,50 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
- name: otel_collector_proxy
address:
socket_address:
address: 127.0.0.1
port_value: 9903
traffic_direction: OUTBOUND
filter_chains:
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: otel_proxy
codec_type: AUTO
access_log:
- name: envoy.access_loggers.file
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
path: "/var/log/access_otel.log"
format: |
[%START_TIME%] "%REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %PROTOCOL%" %RESPONSE_CODE% %RESPONSE_FLAGS% %BYTES_RECEIVED% %BYTES_SENT% %DURATION% %RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)% "%REQ(X-FORWARDED-FOR)%" "%REQ(USER-AGENT)%" "%REQ(X-REQUEST-ID)%" "%REQ(:AUTHORITY)%" "%UPSTREAM_HOST%" "%UPSTREAM_CLUSTER%"
route_config:
name: otel_route
virtual_hosts:
- name: otel_backend
domains: ["*"]
routes:
- match:
prefix: "/v1/traces"
route:
cluster: opentelemetry_collector_http
timeout: 5s
retry_policy:
retry_on: "5xx,connect-failure,refused-stream,reset"
num_retries: 3
per_try_timeout: 2s
host_selection_retry_max_attempts: 5
retriable_status_codes: [500, 502, 503, 504]
http_filters:
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
{% endif %}
- name: egress_traffic_llm
address:
socket_address:
@ -599,7 +625,7 @@ static_resources:
clusters:
- name: arch
connect_timeout: 0.5s
connect_timeout: 5s
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
@ -868,24 +894,6 @@ static_resources:
tls_params:
tls_minimum_protocol_version: TLSv1_2
tls_maximum_protocol_version: TLSv1_3
{% for internal_cluster in ["arch_fc", "model_server"] %}
- name: {{ internal_cluster }}
connect_timeout: 0.5s
type: STRICT_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: {{ internal_cluster }}
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: host.docker.internal
port_value: 51000
hostname: {{ internal_cluster }}
{% endfor %}
- name: mistral_7b_instruct
connect_timeout: 0.5s
type: STRICT_DNS
@ -1035,7 +1043,6 @@ static_resources:
port_value: 12001
hostname: arch_listener_llm
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
- name: opentelemetry_collector
type: STRICT_DNS
@ -1069,4 +1076,19 @@ static_resources:
socket_address:
address: host.docker.internal
port_value: 4318
# Circuit breaker configuration to prevent overwhelming OTEL collector
circuit_breakers:
thresholds:
- priority: DEFAULT
max_connections: 100
max_pending_requests: 100
max_requests: 100
max_retries: 3
# Health checking and outlier detection
outlier_detection:
consecutive_5xx: 5
interval: 10s
base_ejection_time: 30s
max_ejection_percent: 50
enforcing_consecutive_5xx: 100
{% endif %}

View file

@ -2,7 +2,7 @@
nodaemon=true
[program:brightstaff]
command=sh -c "RUST_LOG=info /app/brightstaff 2>&1 | tee /var/log/brightstaff.log | while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
command=sh -c "envsubst < /app/arch_config_rendered.yaml > /app/arch_config_rendered.env_sub.yaml && RUST_LOG=debug ARCH_CONFIG_PATH_RENDERED=/app/arch_config_rendered.env_sub.yaml /app/brightstaff 2>&1 | tee /var/log/brightstaff.log | while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
stdout_logfile=/dev/stdout
redirect_stderr=true
stdout_logfile_maxbytes=0

View file

@ -19,7 +19,7 @@ source venv/bin/activate
### Step 3: Run the build script
```bash
pip install archgw==0.3.18
pip install archgw==0.3.22
```
## Uninstall Instructions: archgw CLI
@ -56,15 +56,8 @@ poetry install
archgw build
```
### Step 5: download models
This will help download models so model_server can load faster. This should be done once.
```bash
archgw download-models
```
### Logs
`archgw` command can also view logs from gateway and model_server. Use following command to view logs,
`archgw` command can also view logs from the gateway. Use following command to view logs,
```bash
archgw logs --follow

View file

@ -13,10 +13,10 @@ SUPPORTED_PROVIDERS_WITH_BASE_URL = [
"ollama",
"qwen",
"amazon_bedrock",
"arch",
]
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
"arch",
"deepseek",
"groq",
"mistral",
@ -101,8 +101,17 @@ def validate_and_render_schema():
# Process agents section and convert to endpoints
agents = config_yaml.get("agents", [])
for agent in agents:
filters = config_yaml.get("filters", [])
agents_combined = agents + filters
agent_id_keys = set()
for agent in agents_combined:
agent_id = agent.get("id")
if agent_id in agent_id_keys:
raise Exception(
f"Duplicate agent id {agent_id}, please provide unique id for each agent"
)
agent_id_keys.add(agent_id)
agent_endpoint = agent.get("url")
if agent_id and agent_endpoint:
@ -304,6 +313,16 @@ def validate_and_render_schema():
}
)
# Always add arch-function model provider if not already defined
if "arch-function" not in model_provider_name_set:
updated_model_providers.append(
{
"name": "arch-function",
"provider_interface": "arch",
"model": "Arch-Function",
}
)
config_yaml["model_providers"] = deepcopy(updated_model_providers)
listeners_with_provider = 0

View file

@ -1,13 +1,5 @@
import os
KATANEMO_DOCKERHUB_REPO = "katanemo/archgw"
KATANEMO_LOCAL_MODEL_LIST = [
"katanemo/Arch-Guard",
]
SERVICE_NAME_ARCHGW = "archgw"
SERVICE_NAME_MODEL_SERVER = "model_server"
SERVICE_ALL = "all"
MODEL_SERVER_LOG_FILE = "~/archgw_logs/modelserver.log"
ARCHGW_DOCKER_NAME = "archgw"
ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.3.18")
ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.3.22")

View file

@ -9,9 +9,7 @@ from cli.utils import convert_legacy_listeners, getLogger
from cli.consts import (
ARCHGW_DOCKER_IMAGE,
ARCHGW_DOCKER_NAME,
KATANEMO_LOCAL_MODEL_LIST,
)
from huggingface_hub import snapshot_download
import subprocess
from cli.docker_cli import (
docker_container_status,
@ -144,49 +142,6 @@ def stop_docker_container(service=ARCHGW_DOCKER_NAME):
log.info(f"Failed to shut down services: {str(e)}")
def download_models_from_hf():
for model in KATANEMO_LOCAL_MODEL_LIST:
log.info(f"Downloading model: {model}")
snapshot_download(repo_id=model)
def start_arch_modelserver(foreground):
"""
Start the model server. This assumes that the archgw_modelserver package is installed locally
"""
try:
log.info("archgw_modelserver restart")
if foreground:
subprocess.run(
["archgw_modelserver", "start", "--foreground"],
check=True,
)
else:
subprocess.run(
["archgw_modelserver", "start"],
check=True,
)
except subprocess.CalledProcessError as e:
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1)
def stop_arch_modelserver():
"""
Stop the model server. This assumes that the archgw_modelserver package is installed locally
"""
try:
subprocess.run(
["archgw_modelserver", "stop"],
check=True,
)
except subprocess.CalledProcessError as e:
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1)
def start_cli_agent(arch_config_file=None, settings_json="{}"):
"""Start a CLI client connected to Arch."""

View file

@ -20,20 +20,14 @@ from cli.utils import (
find_config_file,
)
from cli.core import (
start_arch_modelserver,
stop_arch_modelserver,
start_arch,
stop_docker_container,
download_models_from_hf,
start_cli_agent,
)
from cli.consts import (
ARCHGW_DOCKER_IMAGE,
ARCHGW_DOCKER_NAME,
KATANEMO_DOCKERHUB_REPO,
SERVICE_NAME_ARCHGW,
SERVICE_NAME_MODEL_SERVER,
SERVICE_ALL,
)
log = getLogger(__name__)
@ -47,9 +41,8 @@ logo = r"""
"""
# Command to build archgw and model_server Docker images
# Command to build archgw Docker images
ARCHGW_DOCKERFILE = "./arch/Dockerfile"
MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
def get_version():
@ -60,18 +53,6 @@ def get_version():
return "version not found"
def verify_service_name(service):
"""Verify if the service name is valid."""
if service not in [
SERVICE_NAME_ARCHGW,
SERVICE_NAME_MODEL_SERVER,
SERVICE_ALL,
]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
return True
@click.group(invoke_without_command=True)
@click.option("--version", is_flag=True, help="Show the archgw cli version and exit.")
@click.pass_context
@ -89,17 +70,11 @@ def main(ctx, version):
@click.command()
@click.option(
"--service",
default=SERVICE_ALL,
help="Optional parameter to specify which service to build. Options are model_server, archgw",
)
def build(service):
def build():
"""Build Arch from source. Must be in root of cloned repo."""
verify_service_name(service)
# Check if /arch/Dockerfile exists
if service == SERVICE_NAME_ARCHGW or service == SERVICE_ALL:
if os.path.exists(ARCHGW_DOCKERFILE):
if os.path.exists(ARCHGW_DOCKERFILE):
click.echo("Building archgw image...")
try:
@ -110,8 +85,6 @@ def build(service):
"-f",
ARCHGW_DOCKERFILE,
"-t",
f"{KATANEMO_DOCKERHUB_REPO}:latest",
"-t",
f"{ARCHGW_DOCKER_IMAGE}",
".",
"--add-host=host.docker.internal:host-gateway",
@ -128,57 +101,20 @@ def build(service):
click.echo("archgw image built successfully.")
"""Install the model server dependencies using Poetry."""
if service == SERVICE_NAME_MODEL_SERVER or service == SERVICE_ALL:
# Check if pyproject.toml exists
if os.path.exists(MODEL_SERVER_BUILD_FILE):
click.echo("Installing model server dependencies with Poetry...")
try:
subprocess.run(
["poetry", "install", "--no-cache"],
cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE),
check=True,
)
click.echo("Model server dependencies installed successfully.")
except subprocess.CalledProcessError as e:
click.echo(f"Error installing model server dependencies: {e}")
sys.exit(1)
else:
click.echo(f"Error: pyproject.toml not found in {MODEL_SERVER_BUILD_FILE}")
sys.exit(1)
@click.command()
@click.argument("file", required=False) # Optional file argument
@click.option(
"--path", default=".", help="Path to the directory containing arch_config.yaml"
)
@click.option(
"--service",
default=SERVICE_ALL,
help="Service to start. Options are model_server, archgw.",
)
@click.option(
"--foreground",
default=False,
help="Run Arch in the foreground. Default is False",
is_flag=True,
)
def up(file, path, service, foreground):
def up(file, path, foreground):
"""Starts Arch."""
verify_service_name(service)
if service == SERVICE_ALL and foreground:
# foreground can only be specified when starting individual services
log.info("foreground flag is only supported for individual services. Exiting.")
sys.exit(1)
if service == SERVICE_NAME_MODEL_SERVER:
log.info("Download models from HuggingFace...")
download_models_from_hf()
start_arch_modelserver(foreground)
return
# Use the utility function to find config file
arch_config_file = find_config_file(path, file)
@ -202,7 +138,6 @@ def up(file, path, service, foreground):
# Set the ARCH_CONFIG_FILE environment variable
env_stage = {
"OTEL_TRACING_HTTP_ENDPOINT": "http://host.docker.internal:4318/v1/traces",
"MODEL_SERVER_PORT": os.getenv("MODEL_SERVER_PORT", "51000"),
}
env = os.environ.copy()
# Remove PATH variable if present
@ -242,40 +177,13 @@ def up(file, path, service, foreground):
env_stage[access_key] = env_file_dict[access_key]
env.update(env_stage)
if service == SERVICE_NAME_ARCHGW:
start_arch(arch_config_file, env, foreground=foreground)
else:
# Check if ingress_traffic listener is configured before starting model_server
if has_ingress_listener(arch_config_file):
download_models_from_hf()
start_arch_modelserver(foreground)
else:
log.info(
"Skipping model_server startup: no ingress_traffic listener configured in arch_config.yaml"
)
start_arch(arch_config_file, env, foreground=foreground)
start_arch(arch_config_file, env, foreground=foreground)
@click.command()
@click.option(
"--service",
default=SERVICE_ALL,
help="Service to down. Options are all, model_server, archgw. Default is all",
)
def down(service):
def down():
"""Stops Arch."""
verify_service_name(service)
if service == SERVICE_NAME_MODEL_SERVER:
stop_arch_modelserver()
elif service == SERVICE_NAME_ARCHGW:
stop_docker_container()
else:
stop_arch_modelserver()
stop_docker_container(SERVICE_NAME_ARCHGW)
stop_docker_container()
@click.command()
@ -303,7 +211,7 @@ def generate_prompt_targets(file):
@click.command()
@click.option(
"--debug",
help="For detailed debug logs to trace calls from archgw <> model_server <> api_server, etc",
help="For detailed debug logs to trace calls from archgw <> api_server, etc",
is_flag=True,
)
@click.option("--follow", help="Follow the logs", is_flag=True)

View file

@ -57,6 +57,10 @@ def convert_legacy_listeners(
"timeout": "30s",
}
# Handle None case
if listeners is None:
return [llm_gateway_listener], llm_gateway_listener, prompt_gateway_listener
if isinstance(listeners, dict):
# legacy listeners
# check if type is array or object
@ -148,6 +152,24 @@ def get_llm_provider_access_keys(arch_config_file):
if access_key is not None:
access_key_list.append(access_key)
# Extract environment variables from state_storage.connection_string
state_storage = arch_config_yaml.get("state_storage_v1_responses")
if state_storage:
connection_string = state_storage.get("connection_string")
if connection_string and isinstance(connection_string, str):
# Extract all $VAR and ${VAR} patterns from connection string
import re
# Match both $VAR and ${VAR} patterns
pattern = r"\$\{?([A-Z_][A-Z0-9_]*)\}?"
matches = re.findall(pattern, connection_string)
for var in matches:
access_key_list.append(f"${var}")
else:
raise ValueError(
"Invalid connection string received in state_storage_v1_responses"
)
return access_key_list

2481
arch/tools/poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,34 +1,28 @@
[project]
name = "archgw"
version = "0.3.18"
description = "Python-based CLI tool to manage Arch Gateway."
authors = [{ name = "Katanemo Labs, Inc." }]
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"archgw_modelserver==0.3.18",
"click>=8.1.7,<9.0.0",
"jinja2>=3.1.4,<4.0.0",
"jsonschema>=4.23.0,<5.0.0",
"pyyaml>=6.0.2,<7.0.0",
]
[project.scripts]
archgw = "cli.main:main"
[dependency-groups]
dev = [
"pytest>=8.4.1,<9.0.0",
]
[tool.poetry]
name = "archgw"
version = "0.3.22"
description = "Python-based CLI tool to manage Arch Gateway."
authors = ["Katanemo Labs, Inc."]
readme = "README.md"
packages = [{ include = "cli" }]
dependencies = { archgw_modelserver = { path = "../../model_server", develop = true } }
[tool.poetry.dependencies]
python = ">=3.10"
click = ">=8.1.7,<9.0.0"
jinja2 = ">=3.1.4,<4.0.0"
jsonschema = ">=4.23.0,<5.0.0"
pyyaml = ">=6.0.2,<7.0.0"
requests = ">=2.31.0,<3.0.0"
[tool.poetry.group.dev.dependencies]
pytest = ">=8.4.1,<9.0.0"
[tool.poetry.scripts]
archgw = "cli.main:main"
[build-system]
requires = ["poetry-core>=2.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
addopts = ["-v"]

View file

@ -94,21 +94,16 @@ def test_validate_and_render_happy_path_agent_config(monkeypatch):
version: v0.3.0
agents:
- name: query_rewriter
kind: openai
endpoint: http://localhost:10500
- name: context_builder
kind: openai
endpoint: http://localhost:10501
- name: response_generator
kind: openai
endpoint: http://localhost:10502
- name: research_agent
kind: openai
endpoint: http://localhost:10500
- name: input_guard_rails
kind: openai
endpoint: http://localhost:10503
- id: query_rewriter
url: http://localhost:10500
- id: context_builder
url: http://localhost:10501
- id: response_generator
url: http://localhost:10502
- id: research_agent
url: http://localhost:10500
- id: input_guard_rails
url: http://localhost:10503
listeners:
- name: tmobile
@ -156,7 +151,7 @@ listeners:
mock.mock_open().return_value, # ARCH_CONFIG_FILE_RENDERED (write)
]
with mock.patch("builtins.open", m_open):
with mock.patch("config_generator.Environment"):
with mock.patch("cli.config_generator.Environment"):
validate_and_render_schema()

View file

@ -12,10 +12,6 @@
"name": "archgw_cli",
"path": "arch/tools"
},
{
"name": "model_server",
"path": "model_server"
},
{
"name": "tests_e2e",
"path": "tests/e2e"
@ -24,10 +20,6 @@
"name": "tests_archgw",
"path": "tests/archgw"
},
{
"name": "tests_modelserver",
"path": "tests/modelserver"
},
{
"name": "chatbot_ui",
"path": "demos/shared/chatbot_ui"
@ -42,6 +34,7 @@
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
},
"rust-analyzer.cargo.features": ["trace-collection"]
},
"extensions": {
"recommendations": [

View file

@ -1 +1 @@
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.2
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22

590
crates/Cargo.lock generated
View file

@ -78,6 +78,43 @@ dependencies = [
"serde_json",
]
[[package]]
name = "async-openai"
version = "0.30.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bf39a15c8d613eb61892dc9a287c02277639ebead41ee611ad23aaa613f1a82"
dependencies = [
"async-openai-macros",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand 0.9.2",
"reqwest",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.12",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-openai-macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "async-trait"
version = "0.1.88"
@ -130,6 +167,75 @@ dependencies = [
"time",
]
[[package]]
name = "axum"
version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [
"async-trait",
"axum-core",
"bytes",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower 0.5.2",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "backoff"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
dependencies = [
"futures-core",
"getrandom 0.2.16",
"instant",
"pin-project-lite",
"rand 0.8.5",
"tokio",
]
[[package]]
name = "backtrace"
version = "0.3.75"
@ -201,10 +307,14 @@ dependencies = [
name = "brightstaff"
version = "0.1.0"
dependencies = [
"async-openai",
"async-trait",
"bytes",
"chrono",
"common",
"eventsource-client",
"eventsource-stream",
"flate2",
"futures",
"futures-util",
"hermesllm",
@ -219,6 +329,7 @@ dependencies = [
"opentelemetry-stdout",
"opentelemetry_sdk",
"pretty_assertions",
"rand 0.9.2",
"reqwest",
"serde",
"serde_json",
@ -227,10 +338,12 @@ dependencies = [
"thiserror 2.0.12",
"time",
"tokio",
"tokio-postgres",
"tokio-stream",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"uuid",
]
[[package]]
@ -250,6 +363,12 @@ version = "3.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.10.1"
@ -281,6 +400,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.41"
@ -289,8 +414,10 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link",
]
@ -307,6 +434,7 @@ dependencies = [
name = "common"
version = "0.1.0"
dependencies = [
"axum",
"derivative",
"duration-string",
"governor",
@ -316,12 +444,16 @@ dependencies = [
"pretty_assertions",
"proxy-wasm",
"rand 0.8.5",
"reqwest",
"serde",
"serde_json",
"serde_with",
"serde_yaml",
"serial_test",
"thiserror 1.0.69",
"tiktoken-rs",
"tokio",
"tracing",
"url",
"urlencoding",
]
@ -336,6 +468,16 @@ dependencies = [
"libc",
]
[[package]]
name = "core-foundation"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@ -426,6 +568,37 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "derive_builder"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.101",
]
[[package]]
name = "derive_builder_macro"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
dependencies = [
"derive_builder_core",
"syn 2.0.101",
]
[[package]]
name = "diff"
version = "0.1.13"
@ -440,6 +613,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
"subtle",
]
[[package]]
@ -527,6 +701,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "fallible-iterator"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
[[package]]
name = "fancy-regex"
version = "0.12.0"
@ -543,6 +723,16 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "flate2"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "fnv"
version = "1.0.7"
@ -650,6 +840,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
@ -685,8 +881,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
@ -696,9 +894,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
"wasm-bindgen",
]
[[package]]
@ -798,10 +998,12 @@ version = "0.1.0"
dependencies = [
"aws-smithy-eventstream",
"bytes",
"log",
"serde",
"serde_json",
"serde_with",
"thiserror 2.0.12",
"uuid",
]
[[package]]
@ -810,6 +1012,15 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "http"
version = "0.2.12"
@ -934,7 +1145,7 @@ dependencies = [
"hyper 0.14.32",
"log",
"rustls 0.21.12",
"rustls-native-certs",
"rustls-native-certs 0.6.3",
"tokio",
"tokio-rustls 0.24.1",
]
@ -949,6 +1160,7 @@ dependencies = [
"hyper 1.6.0",
"hyper-util",
"rustls 0.23.27",
"rustls-native-certs 0.8.2",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.26.2",
@ -1181,6 +1393,15 @@ dependencies = [
"serde",
]
[[package]]
name = "instant"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
dependencies = [
"cfg-if",
]
[[package]]
name = "ipnet"
version = "2.11.0"
@ -1234,6 +1455,17 @@ version = "0.2.172"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
[[package]]
name = "libredox"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb"
dependencies = [
"bitflags",
"libc",
"redox_syscall",
]
[[package]]
name = "linux-raw-sys"
version = "0.9.4"
@ -1285,6 +1517,12 @@ version = "0.4.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "lru-slab"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "matchers"
version = "0.1.0"
@ -1294,6 +1532,22 @@ dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "matchit"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "md-5"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf"
dependencies = [
"cfg-if",
"digest",
]
[[package]]
name = "md5"
version = "0.7.0"
@ -1312,6 +1566,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -1325,6 +1589,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
@ -1354,7 +1619,7 @@ dependencies = [
"hyper 1.6.0",
"hyper-util",
"log",
"rand 0.9.1",
"rand 0.9.2",
"regex",
"serde_json",
"serde_urlencoded",
@ -1374,7 +1639,7 @@ dependencies = [
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework 2.11.1",
"security-framework-sys",
"tempfile",
]
@ -1581,7 +1846,7 @@ dependencies = [
"glob",
"opentelemetry",
"percent-encoding",
"rand 0.9.1",
"rand 0.9.2",
"serde_json",
"thiserror 2.0.12",
"tracing",
@ -1628,6 +1893,24 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "phf"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
dependencies = [
"phf_shared",
]
[[package]]
name = "phf_shared"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
dependencies = [
"siphasher",
]
[[package]]
name = "pin-project"
version = "1.1.10"
@ -1672,6 +1955,37 @@ version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
[[package]]
name = "postgres-protocol"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbef655056b916eb868048276cfd5d6a7dea4f81560dfd047f97c8c6fe3fcfd4"
dependencies = [
"base64 0.22.1",
"byteorder",
"bytes",
"fallible-iterator",
"hmac",
"md-5",
"memchr",
"rand 0.9.2",
"sha2",
"stringprep",
]
[[package]]
name = "postgres-types"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48"
dependencies = [
"bytes",
"fallible-iterator",
"postgres-protocol",
"serde",
"serde_json",
]
[[package]]
name = "potential_utf"
version = "0.1.2"
@ -1770,6 +2084,61 @@ dependencies = [
"log",
]
[[package]]
name = "quinn"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
dependencies = [
"bytes",
"cfg_aliases",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash 2.1.1",
"rustls 0.23.27",
"socket2",
"thiserror 2.0.12",
"tokio",
"tracing",
"web-time",
]
[[package]]
name = "quinn-proto"
version = "0.11.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
dependencies = [
"bytes",
"getrandom 0.3.3",
"lru-slab",
"rand 0.9.2",
"ring",
"rustc-hash 2.1.1",
"rustls 0.23.27",
"rustls-pki-types",
"slab",
"thiserror 2.0.12",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
]
[[package]]
name = "quote"
version = "1.0.40"
@ -1798,9 +2167,9 @@ dependencies = [
[[package]]
name = "rand"
version = "0.9.1"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
@ -1846,9 +2215,9 @@ dependencies = [
[[package]]
name = "redox_syscall"
version = "0.5.12"
version = "0.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af"
checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
dependencies = [
"bitflags",
]
@ -1941,10 +2310,14 @@ dependencies = [
"js-sys",
"log",
"mime",
"mime_guess",
"native-tls",
"once_cell",
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls 0.23.27",
"rustls-native-certs 0.8.2",
"rustls-pki-types",
"serde",
"serde_json",
@ -1952,6 +2325,7 @@ dependencies = [
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-rustls 0.26.2",
"tokio-util",
"tower 0.5.2",
"tower-http",
@ -1963,6 +2337,22 @@ dependencies = [
"web-sys",
]
[[package]]
name = "reqwest-eventsource"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
dependencies = [
"eventsource-stream",
"futures-core",
"futures-timer",
"mime",
"nom",
"pin-project-lite",
"reqwest",
"thiserror 1.0.69",
]
[[package]]
name = "ring"
version = "0.17.14"
@ -1989,6 +2379,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustix"
version = "1.0.7"
@ -2021,6 +2417,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321"
dependencies = [
"once_cell",
"ring",
"rustls-pki-types",
"rustls-webpki 0.103.3",
"subtle",
@ -2036,7 +2433,19 @@ dependencies = [
"openssl-probe",
"rustls-pemfile",
"schannel",
"security-framework",
"security-framework 2.11.1",
]
[[package]]
name = "rustls-native-certs"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9980d917ebb0c0536119ba501e90834767bffc3d60641457fd84a1f3fd337923"
dependencies = [
"openssl-probe",
"rustls-pki-types",
"schannel",
"security-framework 3.5.1",
]
[[package]]
@ -2054,6 +2463,7 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79"
dependencies = [
"web-time",
"zeroize",
]
@ -2142,6 +2552,16 @@ version = "3.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21"
[[package]]
name = "secrecy"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
dependencies = [
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.11.1"
@ -2149,7 +2569,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework"
version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef"
dependencies = [
"bitflags",
"core-foundation 0.10.1",
"core-foundation-sys",
"libc",
"security-framework-sys",
@ -2191,12 +2624,23 @@ version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
dependencies = [
"indexmap 2.9.0",
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
@ -2313,12 +2757,24 @@ dependencies = [
"libc",
]
[[package]]
name = "simd-adler32"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2"
[[package]]
name = "similar"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
[[package]]
name = "siphasher"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
[[package]]
name = "slab"
version = "0.4.9"
@ -2359,6 +2815,17 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "stringprep"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1"
dependencies = [
"unicode-bidi",
"unicode-normalization",
"unicode-properties",
]
[[package]]
name = "strsim"
version = "0.11.1"
@ -2420,7 +2887,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation 0.9.4",
"system-configuration-sys",
]
@ -2509,7 +2976,7 @@ dependencies = [
"fancy-regex",
"lazy_static",
"parking_lot",
"rustc-hash",
"rustc-hash 1.1.0",
]
[[package]]
@ -2553,6 +3020,21 @@ dependencies = [
"zerovec",
]
[[package]]
name = "tinyvec"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.45.1"
@ -2602,6 +3084,32 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-postgres"
version = "0.7.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0"
dependencies = [
"async-trait",
"byteorder",
"bytes",
"fallible-iterator",
"futures-channel",
"futures-util",
"log",
"parking_lot",
"percent-encoding",
"phf",
"pin-project-lite",
"postgres-protocol",
"postgres-types",
"rand 0.9.2",
"socket2",
"tokio",
"tokio-util",
"whoami",
]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
@ -2705,6 +3213,7 @@ dependencies = [
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
@ -2743,6 +3252,7 @@ version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
@ -2829,12 +3339,39 @@ version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-bidi"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5"
[[package]]
name = "unicode-ident"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "unicode-normalization"
version = "0.1.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8"
dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-properties"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
@ -2870,6 +3407,18 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "uuid"
version = "1.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2"
dependencies = [
"getrandom 0.3.3",
"js-sys",
"serde",
"wasm-bindgen",
]
[[package]]
name = "valuable"
version = "0.1.1"
@ -2918,6 +3467,12 @@ dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasite"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
[[package]]
name = "wasm-bindgen"
version = "0.2.100"
@ -3022,6 +3577,17 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "whoami"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d"
dependencies = [
"libredox",
"wasite",
"web-sys",
]
[[package]]
name = "winapi"
version = "0.3.9"

View file

@ -1,3 +1,7 @@
[workspace]
resolver = "2"
members = ["llm_gateway", "prompt_gateway", "common", "brightstaff", "hermesllm"]
[workspace.metadata.rust-analyzer]
# Enable features for better IDE support
cargo.features = ["trace-collection"]

View file

@ -4,10 +4,14 @@ version = "0.1.0"
edition = "2021"
[dependencies]
async-openai = "0.30.1"
async-trait = "0.1"
bytes = "1.10.1"
common = { version = "0.1.0", path = "../common" }
chrono = "0.4"
common = { version = "0.1.0", path = "../common", features = ["trace-collection"] }
eventsource-client = "0.15.0"
eventsource-stream = "0.2.3"
flate2 = "1.0"
futures = "0.3.31"
futures-util = "0.3.31"
hermesllm = { version = "0.1.0", path = "../hermesllm" }
@ -21,6 +25,7 @@ opentelemetry-otlp = {version="0.29.0", features=["trace", "tonic", "grpc-tonic"
opentelemetry-stdout = "0.29.0"
opentelemetry_sdk = "0.29.0"
pretty_assertions = "1.4.1"
rand = "0.9.2"
reqwest = { version = "0.12.15", features = ["stream"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
@ -28,10 +33,12 @@ serde_with = "3.13.0"
serde_yaml = "0.9.34"
thiserror = "2.0.12"
tokio = { version = "1.44.2", features = ["full"] }
tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] }
tokio-stream = "0.1"
time = { version = "0.3", features = ["formatting", "macros"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.0", features = ["v4", "serde"] }
[dev-dependencies]
mockito = "1.0"

View file

@ -1,16 +1,24 @@
use std::sync::Arc;
use std::time::{Instant, SystemTime};
use bytes::Bytes;
use hermesllm::apis::openai::ChatCompletionsRequest;
use common::consts::TRACE_PARENT_HEADER;
use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id};
use hermesllm::apis::OpenAIMessage;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::providers::request::ProviderRequest;
use hermesllm::ProviderRequestType;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::{Request, Response};
use serde::ser::Error as SerError;
use tracing::{debug, info, warn};
use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
use crate::tracing::{OperationNameBuilder, operation_component, http};
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
@ -33,11 +41,51 @@ pub async fn agent_chat(
_: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
trace_collector: Arc<common::traces::TraceCollector>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
match handle_agent_chat(request, router_service, agents_list, listeners).await {
match handle_agent_chat(
request,
router_service,
agents_list,
listeners,
trace_collector,
)
.await
{
Ok(response) => Ok(response),
Err(err) => {
// Print detailed error information with full error chain
// Check if this is a client error from the pipeline that should be cascaded
if let AgentFilterChainError::Pipeline(PipelineError::ClientError {
agent,
status,
body,
}) = &err
{
warn!(
"Client error from agent '{}' (HTTP {}): {}",
agent, status, body
);
// Create error response with the original status code and body
let error_json = serde_json::json!({
"error": "ClientError",
"agent": agent,
"status": status,
"agent_response": body
});
let json_string = error_json.to_string();
let mut response = Response::new(ResponseHandler::create_full_body(json_string));
*response.status_mut() = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
return Ok(response);
}
// Print detailed error information with full error chain for other errors
let mut error_chain = Vec::new();
let mut current_error: &dyn std::error::Error = &err;
@ -78,10 +126,11 @@ async fn handle_agent_chat(
router_service: Arc<RouterService>,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
trace_collector: Arc<common::traces::TraceCollector>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
// Initialize services
let agent_selector = AgentSelector::new(router_service);
let pipeline_processor = PipelineProcessor::default();
let mut pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
// Extract listener name from headers
@ -101,6 +150,13 @@ async fn handle_agent_chat(
info!("Handling request for listener: {}", listener.name);
// Parse request body
let request_path = request
.uri()
.path()
.to_string()
.strip_prefix("/agents")
.unwrap()
.to_string();
let request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
@ -109,61 +165,141 @@ async fn handle_agent_chat(
String::from_utf8_lossy(&chat_request_bytes)
);
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
warn!(
"Failed to parse request body as ChatCompletionsRequest: {}",
err
);
AgentFilterChainError::RequestParsing(err)
// Determine the API type from the endpoint
let api_type =
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
let err_msg = format!("Unsupported endpoint: {}", request_path);
warn!("{}", err_msg);
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
})?;
let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
Ok(request) => request,
Err(err) => {
warn!("Failed to parse request as ProviderRequestType: {}", err);
let err_msg = format!("Failed to parse request: {}", err);
return Err(AgentFilterChainError::RequestParsing(
serde_json::Error::custom(err_msg),
));
}
};
let message: Vec<OpenAIMessage> = client_request.get_messages();
// let chat_completions_request: ChatCompletionsRequest =
// serde_json::from_slice(&chat_request_bytes).map_err(|err| {
// warn!(
// "Failed to parse request body as ChatCompletionsRequest: {}",
// err
// );
// AgentFilterChainError::RequestParsing(err)
// })?;
// Extract trace parent for routing
let trace_parent = request_headers
.iter()
.find(|(key, _)| key.as_str() == "traceparent")
.find(|(key, _)| key.as_str() == TRACE_PARENT_HEADER)
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
// Select appropriate agent using arch router llm model
let selected_agent = agent_selector
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
.await?;
debug!("Processing agent pipeline: {}", selected_agent.id);
// Create agent map for pipeline processing
// Create agent map for pipeline processing and agent selection
let agent_map = {
let agents = agents_list.read().await;
let agents = agents.as_ref().unwrap();
agent_selector.create_agent_map(agents)
};
// Parse trace parent to get trace_id and parent_span_id
let (trace_id, parent_span_id) = if let Some(ref tp) = trace_parent {
parse_traceparent(tp)
} else {
(String::new(), None)
};
// Select appropriate agent using arch router llm model
let selected_agent = agent_selector
.select_agent(&message, &listener, trace_parent.clone())
.await?;
debug!("Processing agent pipeline: {}", selected_agent.id);
// Record the start time for agent span
let agent_start_time = SystemTime::now();
let agent_start_instant = Instant::now();
// let (span_id, trace_id) = trace_collector.start_span(
// trace_parent.clone(),
// operation_component::AGENT,
// &format!("/agents{}", request_path),
// &selected_agent.id,
// );
let span_id = generate_random_span_id();
// Process the filter chain
let processed_messages = pipeline_processor
let chat_history = pipeline_processor
.process_filter_chain(
&chat_completions_request,
&message,
&selected_agent,
&agent_map,
&request_headers,
Some(&trace_collector),
trace_id.clone(),
span_id.clone(),
)
.await?;
// Get terminal agent and send final response
let terminal_agent_name = selected_agent.id;
let terminal_agent_name = selected_agent.id.clone();
let terminal_agent = agent_map.get(&terminal_agent_name).unwrap();
debug!("Processing terminal agent: {}", terminal_agent_name);
debug!("Terminal agent details: {:?}", terminal_agent);
let llm_response = pipeline_processor
.invoke_upstream_agent(
&processed_messages,
&chat_completions_request,
.invoke_agent(
&chat_history,
client_request,
terminal_agent,
&request_headers,
trace_id.clone(),
span_id.clone(),
)
.await?;
// Record agent span after processing is complete
let agent_end_time = SystemTime::now();
let agent_elapsed = agent_start_instant.elapsed();
// Build full path with /agents prefix
let full_path = format!("/agents{}", request_path);
// Build operation name: POST {full_path} {agent_name}
let operation_name = OperationNameBuilder::new()
.with_method("POST")
.with_path(&full_path)
.with_target(&terminal_agent_name)
.build();
let mut span_builder = SpanBuilder::new(&operation_name)
.with_span_id(span_id)
.with_kind(SpanKind::Internal)
.with_start_time(agent_start_time)
.with_end_time(agent_end_time)
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, full_path)
.with_attribute("agent.name", terminal_agent_name.clone())
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
if !trace_id.is_empty() {
span_builder = span_builder.with_trace_id(trace_id);
}
if let Some(parent_id) = parent_span_id {
span_builder = span_builder.with_parent_span_id(parent_id);
}
let span = span_builder.build();
// Use plano(agent) as service name for the agent processing span
trace_collector.record_span(operation_component::AGENT, span);
// Create streaming response
response_handler
.create_streaming_response(llm_response)

View file

@ -20,6 +20,8 @@ pub enum AgentSelectionError {
RoutingError(String),
#[error("Default agent not found for listener: {0}")]
DefaultAgentNotFound(String),
#[error("MCP client error: {0}")]
McpError(String),
}
/// Service for selecting agents based on routing preferences and listener configuration
@ -29,7 +31,9 @@ pub struct AgentSelector {
impl AgentSelector {
pub fn new(router_service: Arc<RouterService>) -> Self {
Self { router_service }
Self {
router_service,
}
}
/// Find listener by name from the request headers
@ -77,7 +81,9 @@ impl AgentSelector {
return Ok(agents[0].clone());
}
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
let usage_preferences = self
.convert_agent_description_to_routing_preferences(agents)
.await;
debug!(
"Agents usage preferences for agent routing str: {}",
serde_json::to_string(&usage_preferences).unwrap_or_default()
@ -131,20 +137,23 @@ impl AgentSelector {
}
/// Convert agent descriptions to routing preferences
fn convert_agent_description_to_routing_preferences(
async fn convert_agent_description_to_routing_preferences(
&self,
agents: &[AgentFilterChain],
) -> Vec<ModelUsagePreference> {
agents
.iter()
.map(|agent| ModelUsagePreference {
model: agent.id.clone(),
let mut preferences = Vec::new();
for agent_chain in agents {
preferences.push(ModelUsagePreference {
model: agent_chain.id.clone(),
routing_preferences: vec![RoutingPreference {
name: agent.id.clone(),
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
name: agent_chain.id.clone(),
description: agent_chain.description.clone().unwrap_or_default(),
}],
})
.collect()
});
}
preferences
}
}
@ -183,8 +192,10 @@ mod tests {
fn create_test_agent_struct(name: &str) -> Agent {
Agent {
id: name.to_string(),
kind: Some("test".to_string()),
agent_type: Some("test".to_string()),
url: "http://localhost:8080".to_string(),
tool: None,
transport: None,
}
}
@ -240,8 +251,8 @@ mod tests {
assert!(agent_map.contains_key("agent2"));
}
#[test]
fn test_convert_agent_description_to_routing_preferences() {
#[tokio::test]
async fn test_convert_agent_description_to_routing_preferences() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
@ -250,7 +261,9 @@ mod tests {
create_test_agent("agent2", "Second agent description", false),
];
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
let preferences = selector
.convert_agent_description_to_routing_preferences(&agents)
.await;
assert_eq!(preferences.len(), 2);
assert_eq!(preferences[0].model, "agent1");

View file

@ -1,276 +0,0 @@
use bytes::Bytes;
use common::configuration::{ModelAlias, ModelUsagePreference};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::clients::SupportedAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
/// Wraps any value convertible into `Bytes` in a boxed, single-chunk HTTP body.
///
/// `Full` cannot fail, so its error is mapped through an empty `match` to
/// satisfy the `hyper::Error` error type of the returned `BoxBody`.
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
    let payload: Bytes = chunk.into();
    let body = Full::new(payload).map_err(|impossible| match impossible {});
    body.boxed()
}
/// Proxies a chat request to the upstream LLM provider and streams the reply back.
///
/// Steps, in order: read the whole request body, parse it according to the
/// request path, apply model-alias resolution, strip `archgw_preference_config`
/// from the metadata (a clone is kept for routing), serialize the request for
/// upstream, convert it to a `ChatCompletionsRequest` for the router, ask the
/// router for a model, then POST the serialized bytes to
/// `full_qualified_llm_provider_url` and relay the response chunks.
///
/// Returns `Ok` with a 400 response when the body cannot be parsed or converted,
/// a 500 response when routing, the upstream call or response construction fails,
/// and otherwise the upstream status, headers and streamed body. `Err` is only
/// produced when collecting the incoming body fails.
pub async fn chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!(
"Received request body (raw utf8): {}",
String::from_utf8_lossy(&chat_request_bytes)
);
// NOTE(review): the `.unwrap()` below panics if `from_endpoint` yields no API
// for this path — confirm callers only route supported paths here.
let mut client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!("Failed to parse request as ProviderRequestType: {}", err);
let err_msg = format!("Failed to parse request: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
// Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model
let model_from_request = client_request.model().to_string();
let is_streaming_request = client_request.is_streaming();
let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
if let Some(model_alias) = model_aliases.get(&model_from_request) {
debug!(
"Model Alias: 'From {}' -> 'To{}'",
model_from_request, model_alias.target
);
model_alias.target.clone()
} else {
model_from_request.clone()
}
} else {
model_from_request.clone()
};
client_request.set_model(resolved_model.clone());
// Clone metadata for routing and remove archgw_preference_config from original
let routing_metadata = client_request.metadata().clone();
if client_request.remove_metadata_key("archgw_preference_config") {
debug!("Removed archgw_preference_config from metadata");
}
// Serialized before the conversion below, which takes `client_request` by value.
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
match ProviderRequestType::try_from((
client_request,
&SupportedUpstreamAPIs::OpenAIChatCompletions(
hermesllm::apis::OpenAIApi::ChatCompletions,
),
)) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok(
ProviderRequestType::MessagesRequest(_)
| ProviderRequestType::BedrockConverse(_)
| ProviderRequestType::BedrockConverseStream(_),
) => {
// This should not happen after conversion to OpenAI format
warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
let err_msg = "Request conversion failed".to_string();
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
Err(err) => {
warn!(
"Failed to convert request to ChatCompletionsRequest: {}",
err
);
let err_msg = format!("Failed to convert request: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
debug!(
"[ARCH_ROUTER REQ]: {}",
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
);
// Pick up the incoming `traceparent` header, if any, so it can be passed to the router.
let trace_parent = request_headers
.iter()
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
// Per-request usage preferences travel in the metadata as a string and are parsed as YAML;
// a parse failure is silently treated as "no preferences".
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("archgw_preference_config")
.map(|value| value.to_string())
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
// Build a short, single-line preview of the last message for the info log.
let latest_message_for_log = chat_completions_request_for_arch_router
.messages
.last()
.map_or("None".to_string(), |msg| {
msg.content.to_string().replace('\n', "\\n")
});
const MAX_MESSAGE_LENGTH: usize = 50;
// Truncation counts `char`s rather than bytes, so multi-byte text is never split mid-character.
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
let truncated: String = latest_message_for_log
.chars()
.take(MAX_MESSAGE_LENGTH)
.collect();
format!("{}...", truncated)
} else {
latest_message_for_log
};
info!(
"request received, request type: chat_completion, usage preferences from request: {}, request path: {}, latest message: {}",
usage_preferences.is_some(),
request_path,
latest_message_for_log
);
debug!("usage preferences from request: {:?}", usage_preferences);
// Ask the router for a model; when it returns no route, fall back to the model in the request.
let model_name = match router_service
.determine_route(
&chat_completions_request_for_arch_router.messages,
trace_parent.clone(),
usage_preferences,
)
.await
{
Ok(route) => match route {
Some((_, model_name)) => model_name,
None => {
info!(
"No route determined, using default model from request: {}",
chat_completions_request_for_arch_router.model
);
chat_completions_request_for_arch_router.model.clone()
}
},
Err(err) => {
let err_msg = format!("Failed to determine route: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
debug!(
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
full_qualified_llm_provider_url, model_name
);
// NOTE(review): `from_str(..).unwrap()` panics if the model name is not a valid header value.
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&model_name).unwrap(),
);
request_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
);
if let Some(trace_parent) = trace_parent {
request_headers.insert(
header::HeaderName::from_static("traceparent"),
header::HeaderValue::from_str(&trace_parent).unwrap(),
);
}
// remove content-length header if it exists
request_headers.remove(header::CONTENT_LENGTH);
let llm_response = match reqwest::Client::new()
.post(full_qualified_llm_provider_url)
.headers(request_headers)
.body(client_request_bytes_for_upstream)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
// copy over the headers and status code from the original response
let response_headers = llm_response.headers().clone();
let upstream_status = llm_response.status();
let mut response = Response::builder().status(upstream_status);
let headers = response.headers_mut().unwrap();
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
// channel to create async stream
let (tx, rx) = mpsc::channel::<Bytes>(16);
// Spawn a task to send data as it becomes available
tokio::spawn(async move {
let mut byte_stream = llm_response.bytes_stream();
while let Some(item) = byte_stream.next().await {
let item = match item {
Ok(item) => item,
Err(err) => {
warn!("Error receiving chunk: {:?}", err);
break;
}
};
// A send error means the receiving side is gone, so stop relaying.
if tx.send(item).await.is_err() {
warn!("Receiver dropped");
break;
}
}
});
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
match response.body(stream_body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -42,19 +42,23 @@ mod integration_tests {
// Setup services
let router_service = create_test_router_service();
let agent_selector = AgentSelector::new(router_service);
let pipeline_processor = PipelineProcessor::default();
let mut pipeline_processor = PipelineProcessor::default();
// Create test data
let agents = vec![
Agent {
id: "filter-agent".to_string(),
kind: Some("filter".to_string()),
agent_type: Some("filter".to_string()),
url: "http://localhost:8081".to_string(),
tool: None,
transport: None,
},
Agent {
id: "terminal-agent".to_string(),
kind: Some("terminal".to_string()),
agent_type: Some("terminal".to_string()),
url: "http://localhost:8082".to_string(),
tool: None,
transport: None,
},
];
@ -107,7 +111,15 @@ mod integration_tests {
let headers = HeaderMap::new();
let result = pipeline_processor
.process_filter_chain(&request, &test_pipeline, &agent_map, &headers)
.process_filter_chain(
&request.messages,
&test_pipeline,
&agent_map,
&headers,
None,
String::new(),
String::new(),
)
.await;
println!("Pipeline processing result: {:?}", result);

View file

@ -0,0 +1,49 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Value of the `jsonrpc` member carried by every JSON-RPC 2.0 message.
pub const JSON_RPC_VERSION: &str = "2.0";
/// MCP method name used to invoke a tool.
pub const TOOL_CALL_METHOD: &str = "tools/call";
/// MCP method name of the request that opens a session.
pub const MCP_INITIALIZE: &str = "initialize";
/// MCP method name of the notification sent once `initialize` has succeeded.
/// The MCP lifecycle specification names this method `notifications/initialized`.
pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized";
/// Identifier that ties a JSON-RPC response to its request.
///
/// The enum is `untagged`, so a value is (de)serialized as a bare JSON string
/// or a bare unsigned number rather than as a tagged object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum JsonRpcId {
// Textual identifier.
String(String),
// Unsigned numeric identifier.
Number(u64),
}
/// A JSON-RPC request message: it carries an `id`, unlike `JsonRpcNotification`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
// Protocol version string.
pub jsonrpc: String,
// Identifier for this request.
pub id: JsonRpcId,
// Name of the method to invoke.
pub method: String,
// Named parameters; the field is left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
}
/// A JSON-RPC notification message: same shape as a request but without an `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcNotification {
// Protocol version string.
pub jsonrpc: String,
// Name of the notification method.
pub method: String,
// Named parameters; the field is left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
}
/// Error object carried in the `error` member of a `JsonRpcResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
// Numeric error code.
pub code: i32,
// Human-readable error description.
pub message: String,
// Optional extra error payload; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
/// A JSON-RPC response message.
///
/// `result` and `error` are both optional and each is left out of the
/// serialized form when `None`; the type itself does not enforce that exactly
/// one of them is present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
// Protocol version string.
pub jsonrpc: String,
// Identifier of the request this response belongs to.
pub id: JsonRpcId,
// Result payload as a JSON object.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<HashMap<String, serde_json::Value>>,
// Error details.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}

View file

@ -0,0 +1,462 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
use common::traces::TraceCollector;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
StateStorage, StateStorageError,
extract_input_items, retrieve_and_combine_input
};
use crate::tracing::operation_component;
/// Wraps any value convertible into `Bytes` in a boxed, single-chunk HTTP body.
///
/// `Full` cannot fail, so its error is mapped through an empty `match` to
/// satisfy the `hyper::Error` error type of the returned `BoxBody`.
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
    let payload: Bytes = chunk.into();
    let body = Full::new(payload).map_err(|impossible| match impossible {});
    body.boxed()
}
/// Handles one LLM request end to end.
///
/// The handler parses the client payload according to the request path,
/// resolves model aliases, optionally merges stored `v1/responses` conversation
/// state into the request, asks the router which upstream model to use,
/// forwards the request to `full_qualified_llm_provider_url` and streams the
/// upstream response back while recording an LLM span.
///
/// # Arguments
/// * `request` - incoming HTTP request; its path selects the client API.
/// * `router_service` - routing service passed to `router_chat_get_upstream_model`.
/// * `full_qualified_llm_provider_url` - URL the serialized request is POSTed to.
/// * `model_aliases` - optional map from alias to target model.
/// * `llm_providers` - provider configuration used to compute the upstream path.
/// * `trace_collector` - collector that receives the recorded spans.
/// * `state_storage` - optional conversation-state store; state handling is
///   skipped entirely when this is `None`.
///
/// # Returns
/// `Err` only when reading the request body fails. Otherwise `Ok` with:
/// * 404 when the request path does not map to a supported client API,
/// * 400 when the body cannot be parsed,
/// * 409 when `previous_response_id` is unknown to the state store,
/// * the router's status code when routing fails,
/// * 500 when the routed model name is not a valid header value, when the
///   upstream call fails, or when the response cannot be built,
/// * otherwise the upstream status, headers and streamed body.
pub async fn llm_chat(
    request: Request<hyper::body::Incoming>,
    router_service: Arc<RouterService>,
    full_qualified_llm_provider_url: String,
    model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
    llm_providers: Arc<RwLock<Vec<LlmProvider>>>,
    trace_collector: Arc<TraceCollector>,
    state_storage: Option<Arc<dyn StateStorage>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    let request_path = request.uri().path().to_string();
    let request_headers = request.headers().clone();
    let request_id = request_headers
        .get(REQUEST_ID_HEADER)
        .and_then(|h| h.to_str().ok())
        .map(|s| s.to_string())
        .unwrap_or_else(|| "unknown".to_string());

    // Extract or generate traceparent - this establishes the trace context for all spans
    let traceparent: String = request_headers
        .get(TRACE_PARENT_HEADER)
        .and_then(|h| h.to_str().ok())
        .map(|s| s.to_string())
        .unwrap_or_else(|| {
            use uuid::Uuid;
            let trace_id = Uuid::new_v4().to_string().replace("-", "");
            format!("00-{}-0000000000000000-01", trace_id)
        });

    let mut request_headers = request_headers;
    let chat_request_bytes = request.collect().await?.to_bytes();

    debug!(
        "[PLANO_REQ_ID:{}] | REQUEST_BODY (UTF8): {}",
        request_id,
        String::from_utf8_lossy(&chat_request_bytes)
    );

    // Resolve the client API exactly once. An unrecognised path is answered
    // with 404 instead of panicking inside the request handler.
    let client_api = match SupportedAPIsFromClient::from_endpoint(request_path.as_str()) {
        Some(api) => api,
        None => {
            warn!(
                "[PLANO_REQ_ID:{}] | FAILURE | Unsupported request path: {}",
                request_id, request_path
            );
            let err_msg = format!(
                "[PLANO_REQ_ID:{}] | FAILURE | Unsupported request path: {}",
                request_id, request_path
            );
            let mut not_found = Response::new(full(err_msg));
            *not_found.status_mut() = StatusCode::NOT_FOUND;
            return Ok(not_found);
        }
    };

    let mut client_request =
        match ProviderRequestType::try_from((&chat_request_bytes[..], &client_api)) {
            Ok(request) => request,
            Err(err) => {
                warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err);
                let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err);
                let mut bad_request = Response::new(full(err_msg));
                *bad_request.status_mut() = StatusCode::BAD_REQUEST;
                return Ok(bad_request);
            }
        };

    // === v1/responses state management: Extract input items early ===
    let mut original_input_items = Vec::new();
    let is_responses_api_client =
        matches!(client_api, SupportedAPIsFromClient::OpenAIResponsesAPI(_));

    // Model alias resolution: update model field in client_request immediately
    // This ensures all downstream objects use the resolved model
    let model_from_request = client_request.model().to_string();
    let temperature = client_request.get_temperature();
    let is_streaming_request = client_request.is_streaming();
    let resolved_model = resolve_model_alias(&model_from_request, &model_aliases);

    // Extract tool names and user message preview for span attributes
    let tool_names = client_request.get_tool_names();
    let user_message_preview = client_request
        .get_recent_user_message()
        .map(|msg| truncate_message(&msg, 50));

    client_request.set_model(resolved_model.clone());
    if client_request.remove_metadata_key("archgw_preference_config") {
        debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id);
    }

    // === v1/responses state management: Determine upstream API and combine input if needed ===
    // Do this BEFORE routing since routing consumes the request
    // Only process state if state_storage is configured
    let mut should_manage_state = false;
    if is_responses_api_client && state_storage.is_some() {
        if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
            // Extract original input once
            original_input_items = extract_input_items(&responses_req.input);

            // Get the upstream path and check if it's ResponsesAPI.
            // The routed model is not known yet, so the resolved model is used
            // for the provider lookup as well.
            let upstream_path = get_upstream_path(
                &llm_providers,
                &resolved_model,
                &request_path,
                &resolved_model,
                is_streaming_request,
            )
            .await;
            let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);

            // Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
            should_manage_state =
                !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));

            if should_manage_state {
                // Retrieve and combine conversation history if previous_response_id exists
                if let Some(ref prev_resp_id) = responses_req.previous_response_id {
                    match retrieve_and_combine_input(
                        state_storage.as_ref().unwrap().clone(),
                        prev_resp_id,
                        original_input_items, // Pass ownership instead of cloning
                    )
                    .await
                    {
                        Ok(combined_input) => {
                            // Update both the request and original_input_items
                            responses_req.input = InputParam::Items(combined_input.clone());
                            original_input_items = combined_input;
                            info!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Updated request with conversation history ({} items)", request_id, original_input_items.len());
                        }
                        Err(StateStorageError::NotFound(_)) => {
                            // Return 409 Conflict when previous_response_id not found
                            warn!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Previous response_id not found: {}", request_id, prev_resp_id);
                            let err_msg = format!(
                                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Conversation state not found for previous_response_id: {}",
                                request_id, prev_resp_id
                            );
                            let mut conflict_response = Response::new(full(err_msg));
                            *conflict_response.status_mut() = StatusCode::CONFLICT;
                            return Ok(conflict_response);
                        }
                        Err(e) => {
                            // Log warning but continue on other storage errors
                            warn!(
                                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to retrieve conversation state for {}: {}",
                                request_id, prev_resp_id, e
                            );
                            // Restore original_input_items since we passed ownership
                            original_input_items = extract_input_items(&responses_req.input);
                        }
                    }
                }
            } else {
                debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
            }
        }
    }

    // Serialize request for upstream BEFORE router consumes it
    let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();

    // Determine routing using the dedicated router_chat module
    let routing_result = match router_chat_get_upstream_model(
        router_service,
        client_request, // Pass the original request - router_chat will convert it
        &request_headers,
        trace_collector.clone(),
        &traceparent,
        &request_path,
    )
    .await
    {
        Ok(result) => result,
        Err(err) => {
            let mut internal_error = Response::new(full(err.message));
            *internal_error.status_mut() = err.status_code;
            return Ok(internal_error);
        }
    };

    let model_name = routing_result.model_name;

    debug!(
        "[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Resolved Model: {}",
        request_id, full_qualified_llm_provider_url, model_name
    );

    // A model name that is not a legal header value must not crash the handler.
    let provider_hint = match header::HeaderValue::from_str(&model_name) {
        Ok(value) => value,
        Err(err) => {
            warn!(
                "[PLANO_REQ_ID:{}] | FAILURE | Model name '{}' is not a valid header value: {}",
                request_id, model_name, err
            );
            let err_msg = format!(
                "[PLANO_REQ_ID:{}] | FAILURE | Model name '{}' is not a valid header value: {}",
                request_id, model_name, err
            );
            let mut internal_error = Response::new(full(err_msg));
            *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
            return Ok(internal_error);
        }
    };
    request_headers.insert(ARCH_PROVIDER_HINT_HEADER, provider_hint);

    // "true"/"false" are static, always-valid header values, so no fallible parse is needed.
    request_headers.insert(
        header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
        header::HeaderValue::from_static(if is_streaming_request { "true" } else { "false" }),
    );

    // remove content-length header if it exists
    request_headers.remove(header::CONTENT_LENGTH);

    // Capture start time right before sending request to upstream
    let request_start_time = std::time::Instant::now();
    let request_start_system_time = std::time::SystemTime::now();

    let llm_response = match reqwest::Client::new()
        .post(full_qualified_llm_provider_url)
        .headers(request_headers)
        .body(client_request_bytes_for_upstream)
        .send()
        .await
    {
        Ok(res) => res,
        Err(err) => {
            let err_msg = format!("Failed to send request: {}", err);
            let mut internal_error = Response::new(full(err_msg));
            *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
            return Ok(internal_error);
        }
    };

    // copy over the headers and status code from the original response
    let response_headers = llm_response.headers().clone();
    let upstream_status = llm_response.status();
    let mut response = Response::builder().status(upstream_status);
    let headers = response.headers_mut().unwrap();
    for (header_name, header_value) in response_headers.iter() {
        headers.insert(header_name, header_value.clone());
    }

    // Build LLM span with actual status code using constants
    let byte_stream = llm_response.bytes_stream();

    // Build the LLM span (will be finalized after streaming completes)
    let llm_span = build_llm_span(
        &traceparent,
        &request_path,
        &resolved_model,
        &model_name,
        upstream_status.as_u16(),
        is_streaming_request,
        request_start_system_time,
        tool_names,
        user_message_preview,
        temperature,
        &llm_providers,
    )
    .await;

    // Create base processor for metrics and tracing
    let base_processor = ObservableStreamProcessor::new(
        trace_collector,
        operation_component::LLM,
        llm_span,
        request_start_time,
    );

    // === v1/responses state management: Wrap with ResponsesStateProcessor ===
    // Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
    let streaming_response = if should_manage_state
        && !original_input_items.is_empty()
        && state_storage.is_some()
    {
        // Extract Content-Encoding header to handle decompression for state parsing
        let content_encoding = response_headers
            .get("content-encoding")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());

        // Wrap with state management processor to store state after response completes
        let state_processor = ResponsesStateProcessor::new(
            base_processor,
            state_storage.unwrap(),
            original_input_items,
            resolved_model.clone(),
            model_name.clone(),
            is_streaming_request,
            false, // Not OpenAI upstream since should_manage_state is true
            content_encoding,
            request_id.clone(),
        );
        create_streaming_response(byte_stream, state_processor, 16)
    } else {
        // Use base processor without state management
        create_streaming_response(byte_stream, base_processor, 16)
    };

    match response.body(streaming_response.body) {
        Ok(response) => Ok(response),
        Err(err) => {
            let err_msg = format!("Failed to create response: {}", err);
            let mut internal_error = Response::new(full(err_msg));
            *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
            Ok(internal_error)
        }
    }
}
/// Maps a requested model name to its alias target.
///
/// Returns the alias `target` when `model_aliases` holds an entry for
/// `model_from_request`; otherwise returns the requested name unchanged. This
/// includes the case where no alias map is configured at all.
fn resolve_model_alias(
    model_from_request: &str,
    model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
) -> String {
    let alias_table: &Option<HashMap<String, ModelAlias>> = model_aliases;
    match alias_table
        .as_ref()
        .and_then(|table| table.get(model_from_request))
    {
        Some(alias) => {
            debug!(
                "Model Alias: 'From {}' -> 'To {}'",
                model_from_request, alias.target
            );
            alias.target.clone()
        }
        None => model_from_request.to_string(),
    }
}
/// Assembles the client span that describes one upstream LLM call.
///
/// The span always carries the HTTP method, status code, client path, upstream
/// path, resolved model and streaming flag. Parent span id, temperature, tool
/// list and user-message preview are attached only when they are present.
/// The path shown in the operation name is `"{request} >> {upstream}"` when the
/// upstream path differs from the client path, and the client path otherwise.
async fn build_llm_span(
    traceparent: &str,
    request_path: &str,
    resolved_model: &str,
    model_name: &str,
    status_code: u16,
    is_streaming: bool,
    start_time: std::time::SystemTime,
    tool_names: Option<Vec<String>>,
    user_message_preview: Option<String>,
    temperature: Option<f32>,
    llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
) -> common::traces::Span {
    use crate::tracing::{http, llm, OperationNameBuilder};
    use common::traces::{parse_traceparent, SpanBuilder, SpanKind};

    // The upstream path depends on the provider configured for the routed model.
    let upstream_path = get_upstream_path(
        llm_providers,
        model_name,
        request_path,
        resolved_model,
        is_streaming,
    )
    .await;

    // Show the path rewrite in the operation name only when one took place.
    let displayed_path = if request_path == upstream_path {
        request_path.to_string()
    } else {
        format!("{} >> {}", request_path, upstream_path)
    };
    let span_name = OperationNameBuilder::new()
        .with_method("POST")
        .with_path(&displayed_path)
        .with_target(resolved_model)
        .build();

    let (trace_id, parent_span_id) = parse_traceparent(traceparent);

    let mut builder = SpanBuilder::new(&span_name)
        .with_trace_id(&trace_id)
        .with_kind(SpanKind::Client)
        .with_start_time(start_time)
        .with_attribute(http::METHOD, "POST")
        .with_attribute(http::STATUS_CODE, status_code.to_string())
        .with_attribute(http::TARGET, request_path.to_string())
        .with_attribute(http::UPSTREAM_TARGET, upstream_path)
        .with_attribute(llm::MODEL_NAME, resolved_model.to_string())
        .with_attribute(llm::IS_STREAMING, is_streaming.to_string());

    // A missing parent id means this is a root span.
    if let Some(parent) = parent_span_id {
        builder = builder.with_parent_span_id(&parent);
    }

    if let Some(temp) = temperature {
        builder = builder.with_attribute(llm::TEMPERATURE, temp.to_string());
    }

    if let Some(tools) = tool_names {
        // One `name(...)` entry per line.
        let mut rendered = Vec::with_capacity(tools.len());
        for tool in &tools {
            rendered.push(format!("{}(...)", tool));
        }
        builder = builder.with_attribute(llm::TOOLS, rendered.join("\n"));
    }

    if let Some(preview) = user_message_preview {
        builder = builder.with_attribute(llm::USER_MESSAGE_PREVIEW, preview);
    }

    builder.build()
}
/// Computes the path the request will have on the upstream provider.
///
/// The provider id and optional base-URL path prefix are looked up for
/// `model_name`, then the client API derived from `request_path` translates the
/// path via `target_endpoint_for_provider`.
///
/// # Panics
/// Panics when `request_path` does not map to a supported client API.
async fn get_upstream_path(
    llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
    model_name: &str,
    request_path: &str,
    resolved_model: &str,
    is_streaming: bool,
) -> String {
    let provider_info = get_provider_info(llm_providers, model_name).await;
    let path_prefix: Option<&str> = provider_info.1.as_deref();

    SupportedAPIsFromClient::from_endpoint(request_path)
        .expect("Should have valid API endpoint")
        .target_endpoint_for_provider(
            &provider_info.0,
            request_path,
            resolved_model,
            is_streaming,
            path_prefix,
        )
}
/// Looks up the provider id and base-URL path prefix to use for `model_name`.
///
/// Lookup order:
/// 1. the first provider whose `model` or `name` equals `model_name`;
/// 2. the first provider flagged as `default`;
/// 3. `ProviderId::OpenAI` with no prefix, after logging a warning.
async fn get_provider_info(
    llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
    model_name: &str,
) -> (hermesllm::ProviderId, Option<String>) {
    let providers = llm_providers.read().await;

    let selected = providers
        .iter()
        .find(|candidate| {
            candidate
                .model
                .as_ref()
                .map_or(false, |model| model == model_name)
                || candidate.name == model_name
        })
        .or_else(|| {
            providers
                .iter()
                .find(|candidate| candidate.default.unwrap_or(false))
        });

    match selected {
        Some(provider) => (
            provider.provider_interface.to_provider_id(),
            provider.base_url_path_prefix.clone(),
        ),
        None => {
            // Last resort: use OpenAI as hardcoded fallback
            warn!("No default provider found, falling back to OpenAI");
            (hermesllm::ProviderId::OpenAI, None)
        }
    }
}

View file

@ -1,9 +1,13 @@
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod chat_completions;
pub mod llm;
pub mod router_chat;
pub mod models;
pub mod function_calling;
pub mod pipeline_processor;
pub mod response_handler;
pub mod utils;
pub mod jsonrpc;
#[cfg(test)]
mod integration_tests;

View file

@ -1,10 +1,24 @@
use std::collections::HashMap;
use common::configuration::{Agent, AgentFilterChain};
use common::consts::{ARCH_UPSTREAM_HOST_HEADER, ENVOY_RETRY_HEADER};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use common::consts::{
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
};
use common::traces::{SpanBuilder, SpanKind, generate_random_span_id};
use hermesllm::apis::openai::Message;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::header::HeaderMap;
use tracing::{debug, warn};
use std::time::{Instant, SystemTime};
use tracing::{debug, info, warn};
use crate::tracing::operation_component::{self};
use crate::tracing::{http, OperationNameBuilder};
use crate::handlers::jsonrpc::{
JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JSON_RPC_VERSION,
MCP_INITIALIZE, MCP_INITIALIZE_NOTIFICATION, TOOL_CALL_METHOD,
};
use uuid::Uuid;
/// Errors that can occur during pipeline processing
#[derive(Debug, thiserror::Error)]
@ -19,19 +33,41 @@ pub enum PipelineError {
NoChoicesInResponse(String),
#[error("No content in response from agent '{0}'")]
NoContentInResponse(String),
#[error("No result in response from agent '{0}'")]
NoResultInResponse(String),
#[error("No structured content in response from agent '{0}'")]
NoStructuredContentInResponse(String),
#[error("No messages in response from agent '{0}'")]
NoMessagesInResponse(String),
#[error("Client error from agent '{agent}' (HTTP {status}): {body}")]
ClientError {
agent: String,
status: u16,
body: String,
},
#[error("Server error from agent '{agent}' (HTTP {status}): {body}")]
ServerError {
agent: String,
status: u16,
body: String,
},
}
/// Service for processing agent pipelines
pub struct PipelineProcessor {
// HTTP client owned by the processor.
client: reqwest::Client,
// Address used for outgoing requests; supplied by the constructor.
url: String,
// Starts out empty. NOTE(review): presumably maps an agent id to its session id — confirm against usage.
agent_id_session_map: HashMap<String, String>,
}
/// Address used as the `url` of a `PipelineProcessor` built with `Default`.
const ENVOY_API_ROUTER_ADDRESS: &str = "http://localhost:11000";
impl Default for PipelineProcessor {
fn default() -> Self {
Self {
client: reqwest::Client::new(),
url: "http://localhost:11000/v1/chat/completions".to_string(),
url: ENVOY_API_ROUTER_ADDRESS.to_string(),
agent_id_session_map: HashMap::new(),
}
}
}
@ -41,18 +77,128 @@ impl PipelineProcessor {
Self {
client: reqwest::Client::new(),
url,
agent_id_session_map: HashMap::new(),
}
}
/// Records a client span for one filter execution and returns its span id.
///
/// The operation name is built from method `POST`, the generic path
/// `/agents/*` (the concrete endpoint is not available here) and `agent_name`.
/// `trace_id` and `parent_span_id` are attached only when non-empty. The span
/// is recorded under the `AGENT_FILTER` operation component.
///
/// Returns the `span_id` that was passed in.
fn record_filter_span(
    &self,
    collector: &std::sync::Arc<common::traces::TraceCollector>,
    agent_name: &str,
    tool_name: &str,
    start_time: SystemTime,
    end_time: SystemTime,
    elapsed: std::time::Duration,
    trace_id: String,
    parent_span_id: String,
    span_id: String,
) -> String {
    let duration_ms = format!("{:.2}", elapsed.as_secs_f64() * 1000.0);

    let name = OperationNameBuilder::new()
        .with_method("POST")
        .with_path("/agents/*")
        .with_target(agent_name)
        .build();

    let base = SpanBuilder::new(&name)
        .with_span_id(span_id.clone())
        .with_kind(SpanKind::Client)
        .with_start_time(start_time)
        .with_end_time(end_time)
        .with_attribute(http::METHOD, "POST")
        .with_attribute(http::TARGET, "/agents/*")
        .with_attribute("filter.name", agent_name.to_string())
        .with_attribute("filter.tool_name", tool_name.to_string())
        .with_attribute("duration_ms", duration_ms);

    // Trace context is optional: empty strings mean "not provided".
    let with_trace = if trace_id.is_empty() {
        base
    } else {
        base.with_trace_id(trace_id)
    };
    let with_parent = if parent_span_id.is_empty() {
        with_trace
    } else {
        with_trace.with_parent_span_id(parent_span_id)
    };

    // Filter execution spans are reported under the filter component.
    collector.record_span(operation_component::AGENT_FILTER, with_parent.build());

    span_id
}
/// Record a span for MCP protocol interactions
fn record_mcp_span(
&self,
collector: &std::sync::Arc<common::traces::TraceCollector>,
operation: &str,
agent_id: &str,
start_time: SystemTime,
end_time: SystemTime,
elapsed: std::time::Duration,
additional_attrs: Option<HashMap<&str, String>>,
trace_id: String,
parent_span_id: String,
span_id: Option<String>,
) {
// let (trace_id, parent_span_id) = self.extract_trace_context();
// Build operation name: POST /mcp {agent_id}
let operation_name = OperationNameBuilder::new()
.with_method("POST")
.with_path("/mcp")
.with_operation(operation)
.with_target(agent_id)
.build();
let mut span_builder = SpanBuilder::new(&operation_name)
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id()))
.with_kind(SpanKind::Client)
.with_start_time(start_time)
.with_end_time(end_time)
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string()))
.with_attribute("mcp.operation", operation.to_string())
.with_attribute("mcp.agent_id", agent_id.to_string())
.with_attribute(
"duration_ms",
format!("{:.2}", elapsed.as_secs_f64() * 1000.0),
);
if let Some(attrs) = additional_attrs {
for (key, value) in attrs {
span_builder = span_builder.with_attribute(key, value);
}
}
if !trace_id.is_empty() {
span_builder = span_builder.with_trace_id(trace_id);
}
if !parent_span_id.is_empty() {
span_builder = span_builder.with_parent_span_id(parent_span_id);
}
let span = span_builder.build();
// MCP spans also use plano(filter) service name as they are part of filter operations
collector.record_span(operation_component::AGENT_FILTER, span);
}
/// Process the filter chain of agents (all except the terminal agent)
pub async fn process_filter_chain(
&self,
initial_request: &ChatCompletionsRequest,
&mut self,
chat_history: &[Message],
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
trace_id: String,
parent_span_id: String,
) -> Result<Vec<Message>, PipelineError> {
let mut chat_completions_history = initial_request.messages.clone();
let mut chat_history_updated = chat_history.to_vec();
for agent_name in &agent_filter_chain.filter_chain {
debug!("Processing filter agent: {}", agent_name);
@ -61,101 +207,490 @@ impl PipelineProcessor {
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
debug!("Agent details: {:?}", agent);
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
let response_content = self
.send_agent_filter_chain_request(
&chat_completions_history,
initial_request,
info!(
"executing filter: {}/{}, url: {}, conversation length: {}",
agent_name,
tool_name,
agent.url,
chat_history.len()
);
let start_time = SystemTime::now();
let start_instant = Instant::now();
// Generate filter span ID before execution so MCP spans can use it as parent
let filter_span_id = generate_random_span_id();
chat_history_updated = self
.execute_filter(
&chat_history_updated,
agent,
request_headers,
trace_collector,
trace_id.clone(),
filter_span_id.clone(),
)
.await?;
debug!("Received response from filter agent {}", agent_name);
let end_time = SystemTime::now();
let elapsed = start_instant.elapsed();
// Parse the response content as new message history
chat_completions_history =
serde_json::from_str(&response_content).inspect_err(|err| {
warn!(
"Failed to parse response from agent {}, err: {}, response: {}",
agent_name, err, response_content
)
})?;
info!(
"Filter '{}' completed in {:.2}ms, updated conversation length: {}",
agent_name,
elapsed.as_secs_f64() * 1000.0,
chat_history_updated.len()
);
// Record span for this filter execution
if let Some(collector) = trace_collector {
self.record_filter_span(
collector,
agent_name,
tool_name,
start_time,
end_time,
elapsed,
trace_id.clone(),
parent_span_id.clone(),
filter_span_id,
);
}
}
Ok(chat_completions_history)
Ok(chat_history_updated)
}
/// Send request to a specific agent and return the response content
async fn send_agent_filter_chain_request(
/// Build common MCP headers for requests
fn build_mcp_headers(
&self,
messages: &[Message],
original_request: &ChatCompletionsRequest,
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<String, PipelineError> {
let mut request = original_request.clone();
request.messages = messages.to_vec();
agent_id: &str,
session_id: Option<&str>,
trace_id: String,
parent_span_id: String,
) -> Result<HeaderMap, PipelineError> {
let trace_parent = format!("00-{}-{}-01", trace_id, parent_span_id);
let mut headers = request_headers.clone();
headers.remove(hyper::header::CONTENT_LENGTH);
let request_body = serde_json::to_string(&request)?;
debug!("Sending request to agent {}", agent.id);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
headers.remove(TRACE_PARENT_HEADER);
headers.insert(
TRACE_PARENT_HEADER,
hyper::header::HeaderValue::from_str(&trace_parent).unwrap(),
);
agent_headers.insert(
headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(agent_id)
.map_err(|_| PipelineError::AgentNotFound(agent_id.to_string()))?,
);
headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
headers.insert(
"Accept",
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
);
headers.insert(
"Content-Type",
hyper::header::HeaderValue::from_static("application/json"),
);
if let Some(sid) = session_id {
headers.insert(
"mcp-session-id",
hyper::header::HeaderValue::from_str(sid).unwrap(),
);
}
Ok(headers)
}
/// Parse an SSE formatted response and extract the JSON-RPC payload.
///
/// The body must start with an `event` field whose value is `message` and
/// must contain exactly one `data` field; that field's value is returned.
/// Per the SSE format, the single space after the field's colon is optional,
/// so both `data: {...}` and `data:{...}` are accepted.
///
/// # Arguments
/// * `response_bytes` - raw response body (invalid UTF-8 is replaced lossily)
/// * `agent_id` - used only in log and error messages
///
/// # Errors
/// Returns `PipelineError::NoContentInResponse` when the first line is not
/// the expected event line, or when the number of `data` lines is not one.
fn parse_sse_response(
    &self,
    response_bytes: &[u8],
    agent_id: &str,
) -> Result<String, PipelineError> {
    /// Returns the value of `line` if it is the SSE field `name`, with the
    /// optional single leading space removed.
    fn sse_field_value<'a>(line: &'a str, name: &str) -> Option<&'a str> {
        let rest = line.strip_prefix(name)?.strip_prefix(':')?;
        Some(rest.strip_prefix(' ').unwrap_or(rest))
    }

    let response_str = String::from_utf8_lossy(response_bytes);
    let lines: Vec<&str> = response_str.lines().collect();

    // Validate SSE format: first line should be "event: message"
    let starts_with_message_event = lines
        .first()
        .and_then(|line| sse_field_value(line, "event"))
        == Some("message");
    if !starts_with_message_event {
        warn!(
            "Invalid SSE response format from agent {}: expected 'event: message' as first line, got: {:?}",
            agent_id,
            lines.first()
        );
        return Err(PipelineError::NoContentInResponse(format!(
            "Invalid SSE response format from agent {}: expected 'event: message' as first line",
            agent_id
        )));
    }

    // Collect the value of every data field, already stripped of its prefix.
    let data_values: Vec<&str> = lines
        .iter()
        .copied()
        .filter_map(|line| sse_field_value(line, "data"))
        .collect();

    match data_values.as_slice() {
        [payload] => Ok((*payload).to_string()),
        _ => {
            warn!(
                "Expected exactly one 'data:' line from agent {}, found {}",
                agent_id,
                data_values.len()
            );
            Err(PipelineError::NoContentInResponse(format!(
                "Expected exactly one 'data:' line from agent {}, found {}",
                agent_id,
                data_values.len()
            )))
        }
    }
}
/// Send an MCP request and return the raw HTTP response.
///
/// The JSON-RPC request is serialised and POSTed to `{self.url}/mcp` with the
/// supplied headers. The response status is not inspected here; callers
/// decide how to treat non-success statuses.
///
/// # Arguments
/// * `json_rpc_request` - request to serialise as the body
/// * `headers` - complete header set for the request
/// * `agent_id` - used only for the debug log line
///
/// # Errors
/// Propagates serialisation failures and transport errors from the HTTP
/// client through the `?` conversions into `PipelineError`.
async fn send_mcp_request(
    &self,
    json_rpc_request: &JsonRpcRequest,
    headers: HeaderMap,
    agent_id: &str,
) -> Result<reqwest::Response, PipelineError> {
    let request_body = serde_json::to_string(json_rpc_request)?;
    debug!(
        "Sending MCP request to agent {}: {}",
        agent_id, request_body
    );

    // Exactly one request is built: POST to the `/mcp` route under the base
    // URL, carrying the headers passed in by the caller.
    let response = self
        .client
        .post(format!("{}/mcp", self.url))
        .headers(headers)
        .body(request_body)
        .send()
        .await?;

    Ok(response)
}
/// Assemble the JSON-RPC request that invokes a tool with the conversation.
///
/// The resulting params carry the tool `name` and an `arguments` object whose
/// single `messages` entry is the serialised message list. The request ID is
/// a freshly generated UUID string.
///
/// # Errors
/// Returns an error when any of the values cannot be serialised to JSON.
fn build_tool_call_request(
    &self,
    tool_name: &str,
    messages: &[Message],
) -> Result<JsonRpcRequest, PipelineError> {
    // Inner object first: { "messages": [...] }
    let tool_arguments: HashMap<String, serde_json::Value> =
        HashMap::from([("messages".to_string(), serde_json::to_value(messages)?)]);

    // Outer params: { "name": <tool>, "arguments": { ... } }
    let call_params = HashMap::from([
        ("name".to_string(), serde_json::to_value(tool_name)?),
        (
            "arguments".to_string(),
            serde_json::to_value(tool_arguments)?,
        ),
    ]);

    let request = JsonRpcRequest {
        jsonrpc: JSON_RPC_VERSION.to_string(),
        id: JsonRpcId::String(Uuid::new_v4().to_string()),
        method: TOOL_CALL_METHOD.to_string(),
        params: Some(call_params),
    };
    Ok(request)
}
/// Send request to a specific agent and return the response content
async fn execute_filter(
&mut self,
messages: &[Message],
agent: &Agent,
request_headers: &HeaderMap,
trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
trace_id: String,
filter_span_id: String,
) -> Result<Vec<Message>, PipelineError> {
// Get or create MCP session
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
session_id.clone()
} else {
let session_id = self
.get_new_session_id(
&agent.id,
trace_id.clone(),
filter_span_id.clone(),
)
.await;
self.agent_id_session_map
.insert(agent.id.clone(), session_id.clone());
session_id
};
info!(
"Using MCP session ID {} for agent {}",
mcp_session_id, agent.id
);
// Build JSON-RPC request
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
let json_rpc_request = self.build_tool_call_request(tool_name, messages)?;
// Generate span ID for this MCP tool call (child of filter span)
let mcp_span_id = generate_random_span_id();
// Build headers
let agent_headers =
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id), trace_id.clone(), mcp_span_id.clone())?;
// Send request with tracing
let start_time = SystemTime::now();
let start_instant = Instant::now();
let response = self
.send_mcp_request(
&json_rpc_request,
agent_headers,
&agent.id,
)
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
// Parse the response as JSON to extract the content
let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?;
let end_time = SystemTime::now();
let elapsed = start_instant.elapsed();
let content = response_json
.get("choices")
.and_then(|choices| choices.as_array())
.and_then(|choices| choices.first())
.and_then(|choice| choice.get("message"))
.and_then(|message| message.get("content"))
.and_then(|content| content.as_str())
.ok_or_else(|| PipelineError::NoContentInResponse(agent.id.clone()))?
// Record MCP tool call span
if let Some(collector) = trace_collector {
let mut attrs = HashMap::new();
attrs.insert("mcp.method", "tools/call".to_string());
attrs.insert("mcp.tool_name", tool_name.to_string());
attrs.insert("mcp.session_id", mcp_session_id.clone());
attrs.insert("http.status_code", http_status.as_u16().to_string());
self.record_mcp_span(
collector,
"tool_call",
&agent.id,
start_time,
end_time,
elapsed,
Some(attrs),
trace_id.clone(),
filter_span_id.clone(),
Some(mcp_span_id),
);
}
// Handle HTTP errors
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
PipelineError::ClientError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
} else {
PipelineError::ServerError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
});
}
info!(
"Response from agent {}: {}",
agent.id,
String::from_utf8_lossy(&response_bytes)
);
// Parse SSE response
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
let response_result = response
.result
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
// Check if error field is set in response result
if response_result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
{
let error_message = response_result
.get("content")
.and_then(|v| v.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.get("text"))
.and_then(|v| v.as_str())
.unwrap_or("unknown_error")
.to_string();
return Err(PipelineError::ClientError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_message,
});
}
// Extract structured content and parse messages
let response_json = response_result
.get("structuredContent")
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
let messages: Vec<Message> = response_json
.get("result")
.and_then(|v| v.as_array())
.ok_or_else(|| PipelineError::NoMessagesInResponse(agent.id.clone()))?
.iter()
.map(|msg_value| serde_json::from_value(msg_value.clone()))
.collect::<Result<Vec<Message>, _>>()
.map_err(PipelineError::ParseError)?;
Ok(messages)
}
/// Assemble the JSON-RPC `initialize` request used to open an MCP session.
///
/// Params contain the protocol version `2024-11-05`, an empty `capabilities`
/// object and a `clientInfo` object naming this service at version `1.0.0`.
/// The request ID is a freshly generated UUID string.
fn build_initialize_request(&self) -> JsonRpcRequest {
    let client_info = serde_json::json!({
        "name": BRIGHT_STAFF_SERVICE_NAME,
        "version": "1.0.0"
    });

    let init_params = HashMap::from([
        (
            "protocolVersion".to_string(),
            serde_json::Value::String("2024-11-05".to_string()),
        ),
        ("capabilities".to_string(), serde_json::json!({})),
        ("clientInfo".to_string(), client_info),
    ]);

    JsonRpcRequest {
        jsonrpc: JSON_RPC_VERSION.to_string(),
        id: JsonRpcId::String(Uuid::new_v4().to_string()),
        method: MCP_INITIALIZE.to_string(),
        params: Some(init_params),
    }
}
/// Send initialized notification after session creation
async fn send_initialized_notification(
&self,
agent_id: &str,
session_id: &str,
trace_id: String,
parent_span_id: String,
) -> Result<(), PipelineError> {
let initialized_notification = JsonRpcNotification {
jsonrpc: JSON_RPC_VERSION.to_string(),
method: MCP_INITIALIZE_NOTIFICATION.to_string(),
params: None,
};
let notification_body = serde_json::to_string(&initialized_notification)?;
debug!("Sending initialized notification for agent {}", agent_id);
let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, Some(session_id), trace_id.clone(), parent_span_id.clone())?;
let response = self
.client
.post(format!("{}/mcp", self.url))
.headers(headers)
.body(notification_body)
.send()
.await?;
info!(
"Initialized notification response status: {}",
response.status()
);
Ok(())
}
/// Open a new MCP session with an agent and return its session ID.
///
/// Sends an `initialize` request, reads the session ID from the
/// `mcp-session-id` response header, then sends the `initialized`
/// notification for that session.
///
/// # Arguments
/// * `agent_id` - agent to open the session with
/// * `trace_id` / `parent_span_id` - trace context for both requests
///
/// # Panics
/// Panics when the headers cannot be built, the initialize request fails,
/// the response has no readable `mcp-session-id` header, or the
/// `initialized` notification cannot be sent.
async fn get_new_session_id(
    &self,
    agent_id: &str,
    trace_id: String,
    parent_span_id: String,
) -> String {
    info!("Initializing MCP session for agent {}", agent_id);

    let initialize_request = self.build_initialize_request();
    // The IDs are needed again for the notification below, so clone here.
    let headers = self
        .build_mcp_headers(
            &HeaderMap::new(),
            agent_id,
            None,
            trace_id.clone(),
            parent_span_id.clone(),
        )
        .expect("Failed to build headers for initialization");

    let response = self
        .send_mcp_request(&initialize_request, headers, agent_id)
        .await
        .expect("Failed to initialize MCP session");

    info!("Initialize response status: {}", response.status());

    let session_id = response
        .headers()
        .get("mcp-session-id")
        .and_then(|v| v.to_str().ok())
        .expect("No mcp-session-id in response")
        .to_string();

    info!(
        "Created new MCP session for agent {}: {}",
        agent_id, session_id
    );

    // Send initialized notification; last use of the IDs, so move them.
    self.send_initialized_notification(agent_id, &session_id, trace_id, parent_span_id)
        .await
        .expect("Failed to send initialized notification");

    session_id
}
/// Send request to terminal agent and return the raw response for streaming
pub async fn invoke_upstream_agent(
pub async fn invoke_agent(
&self,
messages: &[Message],
original_request: &ChatCompletionsRequest,
mut original_request: ProviderRequestType,
terminal_agent: &Agent,
request_headers: &HeaderMap,
trace_id: String,
agent_span_id: String,
) -> Result<reqwest::Response, PipelineError> {
let mut request = original_request.clone();
request.messages = messages.to_vec();
// let mut request = original_request.clone();
original_request.set_messages(messages);
let request_body = serde_json::to_string(&request)?;
let request_body = ProviderRequestType::to_bytes(&original_request).unwrap();
// let request_body = serde_json::to_string(&request)?;
debug!("Sending request to terminal agent {}", terminal_agent.id);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
// Set traceparent header to make the egress span a child of the agent span
if !trace_id.is_empty() && !agent_span_id.is_empty() {
let trace_parent = format!("00-{}-{}-01", trace_id, agent_span_id);
agent_headers.remove(TRACE_PARENT_HEADER);
agent_headers.insert(
TRACE_PARENT_HEADER,
hyper::header::HeaderValue::from_str(&trace_parent).unwrap(),
);
}
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&terminal_agent.id)
@ -169,7 +704,7 @@ impl PipelineProcessor {
let response = self
.client
.post(&self.url)
.post(format!("{}/v1/chat/completions", self.url))
.headers(agent_headers)
.body(request_body)
.send()
@ -183,6 +718,7 @@ impl PipelineProcessor {
mod tests {
use super::*;
use hermesllm::apis::openai::{Message, MessageContent, Role};
use mockito::Server;
use std::collections::HashMap;
fn create_test_message(role: Role, content: &str) -> Message {
@ -206,23 +742,149 @@ mod tests {
#[tokio::test]
async fn test_agent_not_found_error() {
let processor = PipelineProcessor::default();
let mut processor = PipelineProcessor::default();
let agent_map = HashMap::new();
let request_headers = HeaderMap::new();
let initial_request = ChatCompletionsRequest {
messages: vec![create_test_message(Role::User, "Hello")],
model: "test-model".to_string(),
..Default::default()
};
let messages = vec![create_test_message(Role::User, "Hello")];
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
let result = processor
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers, None, String::new(), String::new())
.await;
assert!(result.is_err());
matches!(result.unwrap_err(), PipelineError::AgentNotFound(_));
}
/// A 500 reply from the MCP endpoint must surface as
/// `PipelineError::ServerError` carrying the status and body.
#[tokio::test]
async fn test_execute_filter_http_status_error() {
    let mut mock_server = Server::new_async().await;
    let _mock = mock_server
        .mock("POST", "/mcp")
        .with_status(500)
        .with_body("boom")
        .create();

    let base_url = mock_server.url();
    let mut processor = PipelineProcessor::new(base_url.clone());
    // Pre-seed the session so no initialize handshake is attempted.
    processor
        .agent_id_session_map
        .insert("agent-1".to_string(), "session-1".to_string());

    let agent = Agent {
        id: "agent-1".to_string(),
        transport: None,
        tool: None,
        url: base_url,
        agent_type: None,
    };
    let conversation = vec![create_test_message(Role::User, "Hello")];
    let headers = HeaderMap::new();

    let outcome = processor
        .execute_filter(
            &conversation,
            &agent,
            &headers,
            None,
            "trace-123".to_string(),
            "span-123".to_string(),
        )
        .await;

    if let Err(PipelineError::ServerError { status, body, .. }) = outcome {
        assert_eq!(status, 500);
        assert_eq!(body, "boom");
    } else {
        panic!("Expected server error for 500 status");
    }
}
/// A 400 reply from the MCP endpoint must surface as
/// `PipelineError::ClientError` carrying the status and body.
#[tokio::test]
async fn test_execute_filter_http_client_error() {
    let mut mock_server = Server::new_async().await;
    let _mock = mock_server
        .mock("POST", "/mcp")
        .with_status(400)
        .with_body("bad request")
        .create();

    let base_url = mock_server.url();
    let mut processor = PipelineProcessor::new(base_url.clone());
    // Pre-seed the session so no initialize handshake is attempted.
    processor
        .agent_id_session_map
        .insert("agent-3".to_string(), "session-3".to_string());

    let agent = Agent {
        id: "agent-3".to_string(),
        transport: None,
        tool: None,
        url: base_url,
        agent_type: None,
    };
    let conversation = vec![create_test_message(Role::User, "Ping")];
    let headers = HeaderMap::new();

    let outcome = processor
        .execute_filter(
            &conversation,
            &agent,
            &headers,
            None,
            "trace-456".to_string(),
            "span-456".to_string(),
        )
        .await;

    if let Err(PipelineError::ClientError { status, body, .. }) = outcome {
        assert_eq!(status, 400);
        assert_eq!(body, "bad request");
    } else {
        panic!("Expected client error for 400 status");
    }
}
/// A 200 reply whose JSON-RPC result has `isError: true` must surface as
/// `PipelineError::ClientError` with the first content item's text as body.
#[tokio::test]
async fn test_execute_filter_mcp_error_flag() {
    let error_result = serde_json::json!({
        "jsonrpc": JSON_RPC_VERSION,
        "id": "1",
        "result": {
            "isError": true,
            "content": [
                { "text": "bad tool call" }
            ]
        }
    });
    // Wrap the JSON-RPC payload in a single SSE message event.
    let sse_payload = format!("event: message\ndata: {}\n\n", error_result);

    let mut mock_server = Server::new_async().await;
    let _mock = mock_server
        .mock("POST", "/mcp")
        .with_status(200)
        .with_body(sse_payload)
        .create();

    let base_url = mock_server.url();
    let mut processor = PipelineProcessor::new(base_url.clone());
    // Pre-seed the session so no initialize handshake is attempted.
    processor
        .agent_id_session_map
        .insert("agent-2".to_string(), "session-2".to_string());

    let agent = Agent {
        id: "agent-2".to_string(),
        transport: None,
        tool: None,
        url: base_url,
        agent_type: None,
    };
    let conversation = vec![create_test_message(Role::User, "Hi")];
    let headers = HeaderMap::new();

    let outcome = processor
        .execute_filter(
            &conversation,
            &agent,
            &headers,
            None,
            "trace-789".to_string(),
            "span-789".to_string(),
        )
        .await;

    if let Err(PipelineError::ClientError { status, body, .. }) = outcome {
        assert_eq!(status, 200);
        assert_eq!(body, "bad tool call");
    } else {
        panic!("Expected client error when isError flag is set");
    }
}
}

View file

@ -0,0 +1,249 @@
use common::configuration::ModelUsagePreference;
use common::consts::{REQUEST_ID_HEADER};
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::StatusCode;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::tracing::{OperationNameBuilder, operation_component, http, routing};
/// Outcome of a successful routing decision.
pub struct RoutingResult {
/// Name of the model the request should be sent to.
pub model_name: String
}
/// Failure produced while determining a route.
pub struct RoutingError {
/// Human-readable description of what went wrong.
pub message: String,
/// HTTP status associated with the failure.
pub status_code: StatusCode,
}
impl RoutingError {
pub fn internal_error(message: String) -> Self {
Self {
message,
status_code: StatusCode::INTERNAL_SERVER_ERROR
}
}
}
/// Determines which upstream model should serve a client request.
///
/// The request is converted to the OpenAI Chat Completions format, handed to
/// the router service together with any usage preferences found under the
/// `archgw_preference_config` metadata key, and a routing span is recorded
/// for every router outcome. When the router returns no route, the model
/// named in the request itself is used.
///
/// # Arguments
/// * `router_service` - service that performs the route determination
/// * `client_request` - incoming request in any supported provider format
/// * `request_headers` - inbound headers; the request ID and `traceparent`
///   are read from them
/// * `trace_collector` - collector the routing span is recorded to
/// * `traceparent` - trace context used as the parent of the routing span
/// * `request_path` - inbound path, used for logging only
///
/// # Returns
/// * `Ok(RoutingResult)` - contains the selected model name
/// * `Err(RoutingError)` - contains the error message and HTTP status when
///   the request cannot be converted or the router fails
pub async fn router_chat_get_upstream_model(
    router_service: Arc<RouterService>,
    client_request: ProviderRequestType,
    request_headers: &hyper::HeaderMap,
    trace_collector: Arc<TraceCollector>,
    traceparent: &str,
    request_path: &str,
) -> Result<RoutingResult, RoutingError> {
    // Clone metadata for routing before converting (which consumes client_request)
    let routing_metadata = client_request.metadata().clone();

    let request_id = request_headers
        .get(REQUEST_ID_HEADER)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("unknown");

    // Convert to ChatCompletionsRequest for routing (regardless of input type)
    let chat_request = match ProviderRequestType::try_from((
        client_request,
        &SupportedUpstreamAPIs::OpenAIChatCompletions(
            hermesllm::apis::OpenAIApi::ChatCompletions,
        ),
    )) {
        Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
        Ok(
            ProviderRequestType::MessagesRequest(_)
            | ProviderRequestType::BedrockConverse(_)
            | ProviderRequestType::BedrockConverseStream(_)
            | ProviderRequestType::ResponsesAPIRequest(_),
        ) => {
            warn!("Unexpected: got non-ChatCompletions request after converting to OpenAI format");
            return Err(RoutingError::internal_error(
                "Request conversion failed".to_string(),
            ));
        }
        Err(err) => {
            warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
            return Err(RoutingError::internal_error(format!(
                "Failed to convert request: {}",
                err
            )));
        }
    };

    // A serialisation failure must never take the request down just for a
    // debug log line, so fall back to an empty string instead of panicking.
    debug!(
        "[PLANO_REQ_ID: {}]: ROUTER_REQ: {}",
        request_id,
        serde_json::to_string(&chat_request).unwrap_or_default()
    );

    // Extract trace_parent from headers (non-UTF-8 values become "")
    let trace_parent = request_headers
        .get("traceparent")
        .map(|value| value.to_str().unwrap_or_default().to_string());

    // Extract usage preferences from metadata; unparsable values are ignored
    let usage_preferences: Option<Vec<ModelUsagePreference>> = routing_metadata
        .as_ref()
        .and_then(|metadata| metadata.get("archgw_preference_config"))
        .map(|value| value.to_string())
        .and_then(|raw| serde_yaml::from_str(&raw).ok());

    // Prepare log message with latest message from chat request
    let latest_message_for_log = chat_request
        .messages
        .last()
        .map_or("None".to_string(), |msg| {
            msg.content.to_string().replace('\n', "\\n")
        });

    // Truncate on character boundaries so multi-byte text is never split.
    const MAX_MESSAGE_LENGTH: usize = 50;
    let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
        let truncated: String = latest_message_for_log
            .chars()
            .take(MAX_MESSAGE_LENGTH)
            .collect();
        format!("{}...", truncated)
    } else {
        latest_message_for_log
    };

    info!(
        "[PLANO_REQ_ID: {}] | ROUTER_REQ | Usage preferences from request: {}, request_path: {}, latest message: {}",
        request_id,
        usage_preferences.is_some(),
        request_path,
        latest_message_for_log
    );

    // Capture start time for routing span
    let routing_start_time = std::time::Instant::now();
    let routing_start_system_time = std::time::SystemTime::now();

    // Attempt to determine route using the router service
    let routing_result = router_service
        .determine_route(&chat_request.messages, trace_parent, usage_preferences)
        .await;

    // Each outcome fills in the span attributes and the value to return;
    // the span itself is then recorded in exactly one place below.
    let mut attrs: HashMap<String, String> = HashMap::new();
    let outcome = match routing_result {
        Ok(Some((_, model_name))) => {
            attrs.insert("route.selected_model".to_string(), model_name.clone());
            Ok(RoutingResult { model_name })
        }
        Ok(None) => {
            // No route determined, use default model from request
            info!(
                "[PLANO_REQ_ID: {}] | ROUTER_REQ | No route determined, using default model from request: {}",
                request_id,
                chat_request.model
            );
            let default_model = chat_request.model.clone();
            attrs.insert("route.selected_model".to_string(), default_model.clone());
            Ok(RoutingResult {
                model_name: default_model,
            })
        }
        Err(err) => {
            attrs.insert("route.selected_model".to_string(), "unknown".to_string());
            attrs.insert("error.message".to_string(), err.to_string());
            Err(RoutingError::internal_error(format!(
                "Failed to determine route: {}",
                err
            )))
        }
    };

    record_routing_span(
        trace_collector,
        traceparent,
        routing_start_time,
        routing_start_system_time,
        attrs,
    )
    .await;

    outcome
}
/// Builds the span describing one routing decision and records it.
///
/// Shared by every routing outcome so the span layout is defined once.
/// The span's trace ID and optional parent come from `traceparent`; its end
/// time is taken when the span is built, and the elapsed time since
/// `start_time` is stored as the route-determination attribute.
async fn record_routing_span(
    trace_collector: Arc<TraceCollector>,
    traceparent: &str,
    start_time: std::time::Instant,
    start_system_time: std::time::SystemTime,
    attrs: HashMap<String, String>,
) {
    // Routing always speaks the OpenAI Chat Completions format internally,
    // so that is the API path reported on the span.
    const ROUTING_API_PATH: &str = "/v1/chat/completions";

    let operation_name = OperationNameBuilder::new()
        .with_method("POST")
        .with_path(ROUTING_API_PATH)
        .with_target("Arch-Router-1.5B")
        .build();

    let (trace_id, parent_span_id) = parse_traceparent(traceparent);

    let base_builder = SpanBuilder::new(&operation_name)
        .with_trace_id(&trace_id)
        .with_kind(SpanKind::Client)
        .with_start_time(start_system_time)
        .with_end_time(std::time::SystemTime::now())
        .with_attribute(http::METHOD, "POST")
        .with_attribute(http::TARGET, ROUTING_API_PATH.to_string())
        .with_attribute(
            routing::ROUTE_DETERMINATION_MS,
            start_time.elapsed().as_millis().to_string(),
        );

    // A missing parent means this is a root span; leave the builder untouched.
    let parented_builder = match parent_span_id {
        Some(parent) => base_builder.with_parent_span_id(&parent),
        None => base_builder,
    };

    // Layer the caller's attributes on top, then finish the span.
    let span = attrs
        .into_iter()
        .fold(parented_builder, |builder, (key, value)| {
            builder.with_attribute(key, value)
        })
        .build();

    trace_collector.record_span(operation_component::ROUTING, span);
}

View file

@ -0,0 +1,259 @@
use bytes::Bytes;
use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event};
use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody;
use hyper::body::Frame;
use std::sync::Arc;
use std::time::{Instant, SystemTime};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::warn;
// Import tracing constants
use crate::tracing::{llm, error};
/// Trait for processing streaming chunks.
///
/// Implementors can inject custom logic during streaming (e.g., hallucination
/// detection, logging). Only `process_chunk` is required; the three hooks
/// default to doing nothing.
pub trait StreamProcessor: Send + 'static {
/// Process an incoming chunk of bytes.
///
/// Returns `Ok(Some(bytes))` with the chunk to pass on, `Ok(None)` when the
/// chunk yields no output, or `Err(message)` when processing failed.
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String>;
/// Called when the first bytes are received (for time-to-first-token tracking)
fn on_first_bytes(&mut self) {}
/// Called when streaming completes successfully
fn on_complete(&mut self) {}
/// Called when streaming encounters an error; `_error` is the error message.
fn on_error(&mut self, _error: &str) {}
}
/// A processor that tracks streaming metrics and finalizes the span
pub struct ObservableStreamProcessor {
collector: Arc<TraceCollector>,
service_name: String,
span: Span,
total_bytes: usize,
chunk_count: usize,
start_time: Instant,
time_to_first_token: Option<u128>,
}
impl ObservableStreamProcessor {
/// Create a new passthrough processor
///
/// # Arguments
/// * `collector` - The trace collector to record the span to
/// * `service_name` - The service name for this span (e.g., "archgw(llm)")
/// * `span` - The span to finalize after streaming completes
/// * `start_time` - When the request started (for duration calculation)
pub fn new(
collector: Arc<TraceCollector>,
service_name: impl Into<String>,
span: Span,
start_time: Instant,
) -> Self {
Self {
collector,
service_name: service_name.into(),
span,
total_bytes: 0,
chunk_count: 0,
start_time,
time_to_first_token: None,
}
}
}
impl StreamProcessor for ObservableStreamProcessor {
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
self.total_bytes += chunk.len();
self.chunk_count += 1;
Ok(Some(chunk))
}
fn on_first_bytes(&mut self) {
// Record time to first token (only for streaming)
if self.time_to_first_token.is_none() {
self.time_to_first_token = Some(self.start_time.elapsed().as_millis());
}
}
fn on_complete(&mut self) {
// Update span with streaming metrics and end time
let end_time_nanos = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos();
self.span.end_time_unix_nano = format!("{}", end_time_nanos);
// Add streaming metrics as attributes using constants
self.span.attributes.push(Attribute {
key: llm::RESPONSE_BYTES.to_string(),
value: AttributeValue {
string_value: Some(self.total_bytes.to_string()),
},
});
self.span.attributes.push(Attribute {
key: llm::DURATION_MS.to_string(),
value: AttributeValue {
string_value: Some(self.start_time.elapsed().as_millis().to_string()),
},
});
// Add time to first token if available (streaming only)
if let Some(ttft) = self.time_to_first_token {
self.span.attributes.push(Attribute {
key: llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
value: AttributeValue {
string_value: Some(ttft.to_string()),
},
});
// Add time to first token as a span event
// Calculate the timestamp by adding ttft duration to span start time
if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() {
// Convert ttft from milliseconds to nanoseconds and add to start time
let event_timestamp = start_time_nanos + (ttft * 1_000_000);
let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
event.add_attribute(
llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
ttft.to_string(),
);
// Initialize events vector if needed
if self.span.events.is_none() {
self.span.events = Some(Vec::new());
}
if let Some(ref mut events) = self.span.events {
events.push(event);
}
}
}
// Record the finalized span
self.collector.record_span(&self.service_name, self.span.clone());
}
/// Marks the span as failed, stamps its end time and duration, and records it.
fn on_error(&mut self, error_msg: &str) {
    warn!("Stream error in PassthroughProcessor: {}", error_msg);

    // Close the span at the current wall-clock time (nanoseconds since the Unix epoch).
    let failed_at_nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    self.span.end_time_unix_nano = failed_at_nanos.to_string();

    // Error flag, error message, then elapsed duration - in that order.
    let error_flag = Attribute {
        key: error::ERROR.to_string(),
        value: AttributeValue {
            string_value: Some("true".to_string()),
        },
    };
    self.span.attributes.push(error_flag);

    let error_message = Attribute {
        key: error::MESSAGE.to_string(),
        value: AttributeValue {
            string_value: Some(error_msg.to_string()),
        },
    };
    self.span.attributes.push(error_message);

    let duration_ms = Attribute {
        key: llm::DURATION_MS.to_string(),
        value: AttributeValue {
            string_value: Some(self.start_time.elapsed().as_millis().to_string()),
        },
    };
    self.span.attributes.push(duration_ms);

    // Hand the error span to the collector.
    self.collector
        .record_span(&self.service_name, self.span.clone());
}
}
/// Result of creating a streaming response
pub struct StreamingResponse {
/// HTTP body that yields each forwarded chunk as a data frame.
pub body: BoxBody<Bytes, hyper::Error>,
/// Handle to the spawned task that receives, processes and forwards chunks.
pub processor_handle: tokio::task::JoinHandle<()>,
}
/// Spawns a task that pulls chunks from `byte_stream`, passes each one through
/// `processor`, and forwards the results over a bounded channel backing the
/// returned HTTP body.
///
/// * `byte_stream` - upstream stream of chunks (or receive errors).
/// * `processor` - hooks invoked for first bytes, every chunk, completion and error.
/// * `buffer_size` - capacity of the channel between the task and the body.
///
/// Hook lifecycle:
/// * `on_first_bytes` runs once, before the first successfully received chunk is processed.
/// * `process_chunk` may return `Ok(None)` to drop a chunk without forwarding it.
/// * Exactly one terminal hook runs: `on_error` when receiving or processing a
///   chunk fails, otherwise `on_complete` (including when the body's receiver is
///   dropped early). Previously `on_complete` also ran after `on_error`, which made
///   processors that finalize in both hooks record their result twice.
pub fn create_streaming_response<S, P>(
    mut byte_stream: S,
    mut processor: P,
    buffer_size: usize,
) -> StreamingResponse
where
    S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
    P: StreamProcessor,
{
    let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);

    // Spawn a task to process and forward chunks
    let processor_handle = tokio::spawn(async move {
        let mut is_first_chunk = true;
        // Set once `on_error` has been delivered so `on_complete` is not called as well.
        let mut failed = false;

        while let Some(item) = byte_stream.next().await {
            let chunk = match item {
                Ok(chunk) => chunk,
                Err(err) => {
                    let err_msg = format!("Error receiving chunk: {:?}", err);
                    warn!("{}", err_msg);
                    processor.on_error(&err_msg);
                    failed = true;
                    break;
                }
            };

            // Call on_first_bytes for the first chunk
            if is_first_chunk {
                processor.on_first_bytes();
                is_first_chunk = false;
            }

            // Process the chunk
            match processor.process_chunk(chunk) {
                Ok(Some(processed_chunk)) => {
                    if tx.send(processed_chunk).await.is_err() {
                        // The body was dropped; nothing left to forward to.
                        warn!("Receiver dropped");
                        break;
                    }
                }
                Ok(None) => {
                    // Skip this chunk
                    continue;
                }
                Err(err) => {
                    warn!("Processor error: {}", err);
                    processor.on_error(&err);
                    failed = true;
                    break;
                }
            }
        }

        if !failed {
            processor.on_complete();
        }
    });

    // Convert channel receiver to HTTP stream
    let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
    let stream_body = BoxBody::new(StreamBody::new(stream));

    StreamingResponse {
        body: stream_body,
        processor_handle,
    }
}
/// Shortens `message` to at most `max_length` characters (Unicode scalar values,
/// not bytes), appending "..." only when something was cut off.
pub fn truncate_message(message: &str, max_length: usize) -> String {
    // The byte offset of the character at index `max_length` exists only when the
    // message is longer than the limit; that offset is exactly where to cut.
    match message.char_indices().nth(max_length) {
        Some((cut_at, _)) => format!("{}...", &message[..cut_at]),
        None => message.to_string(),
    }
}

View file

@ -1,3 +1,5 @@
pub mod handlers;
pub mod router;
pub mod state;
pub mod tracing;
pub mod utils;

View file

@ -1,11 +1,16 @@
use brightstaff::handlers::agent_chat_completions::agent_chat;
use brightstaff::handlers::chat_completions::chat;
use brightstaff::handlers::function_calling::function_calling_chat_handler;
use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
use brightstaff::state::StateStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::Configuration;
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH};
use common::configuration::{Agent, Configuration};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
use common::traces::TraceCollector;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use hyper::body::Incoming;
use hyper::server::conn::http1;
@ -45,10 +50,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let _tracer_provider = init_tracer();
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
info!(
"current working directory: {}",
env::current_dir().unwrap().display()
);
// loading arch_config.yaml file
let arch_config_path = env::var("ARCH_CONFIG_PATH_RENDERED")
.unwrap_or_else(|_| "./arch_config_rendered.yaml".to_string());
@ -62,22 +63,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let arch_config = Arc::new(config);
// combine agents and filters into a single list of agents
let all_agents: Vec<Agent> = arch_config
.agents
.as_deref()
.unwrap_or_default()
.iter()
.chain(arch_config.filters.as_deref().unwrap_or_default())
.cloned()
.collect();
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
debug!(
"arch_config: {:?}",
&serde_json::to_string(arch_config.as_ref()).unwrap()
);
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
info!("llm provider url: {}", llm_provider_url);
info!("listening on http://{}", bind_address);
let listener = TcpListener::bind(bind_address).await?;
let routing_model_name: String = arch_config
.routing
.as_ref()
@ -99,18 +101,69 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_aliases = Arc::new(arch_config.model_aliases.clone());
// Initialize trace collector and start background flusher
// Tracing is enabled if the tracing config is present in arch_config.yaml
// Pass Some(true/false) to override, or None to use env var OTEL_TRACING_ENABLED
let tracing_enabled = if arch_config.tracing.is_some() {
info!("Tracing configuration found in arch_config.yaml");
Some(true)
} else {
info!(
"No tracing configuration in arch_config.yaml, will check OTEL_TRACING_ENABLED env var"
);
None
};
let trace_collector = Arc::new(TraceCollector::new(tracing_enabled));
let _flusher_handle = trace_collector.clone().start_background_flusher();
// Initialize conversation state storage for v1/responses
// Configurable via arch_config.yaml state_storage section
// If not configured, state management is disabled
// Environment variables are substituted by envsubst before config is read
let state_storage: Option<Arc<dyn StateStorage>> = if let Some(storage_config) = &arch_config.state_storage {
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
common::configuration::StateStorageType::Memory => {
info!("Initialized conversation state storage: Memory");
Arc::new(MemoryConversationalStorage::new())
}
common::configuration::StateStorageType::Postgres => {
let connection_string = storage_config
.connection_string
.as_ref()
.expect("connection_string is required for postgres state_storage");
debug!("Postgres connection string (full): {}", connection_string);
info!("Initializing conversation state storage: Postgres");
Arc::new(
PostgreSQLConversationStorage::new(connection_string.clone())
.await
.expect("Failed to initialize Postgres state storage"),
)
}
};
Some(storage)
} else {
info!("No state_storage configured - conversation state management disabled");
None
};
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let router_service: Arc<RouterService> = Arc::clone(&router_service);
let model_aliases = Arc::clone(&model_aliases);
let model_aliases: Arc<
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
> = Arc::clone(&model_aliases);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let agents_list = agents_list.clone();
let agents_list = combined_agents_filters_list.clone();
let listeners = listeners.clone();
let trace_collector = trace_collector.clone();
let state_storage = state_storage.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
@ -119,28 +172,46 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let listeners = listeners.clone();
let trace_collector = trace_collector.clone();
let state_storage = state_storage.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url, model_aliases)
.with_context(parent_cx)
.await
}
(&Method::POST, "/agents/v1/chat/completions") => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
agent_chat(
let path = req.uri().path();
// Check if path starts with /agents
if path.starts_with("/agents") {
// Check if it matches one of the agent API paths
let stripped_path = path.strip_prefix("/agents").unwrap();
if matches!(
stripped_path,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path);
return agent_chat(
req,
router_service,
fully_qualified_url,
agents_list,
listeners,
trace_collector,
)
.with_context(parent_cx)
.await
.await;
}
}
match (req.method(), path) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, path);
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
.with_context(parent_cx)
.await
}
(&Method::POST, "/function_calling") => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, "/v1/chat/completions");
function_calling_chat_handler(req, fully_qualified_url)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
Ok(list_models(llm_providers).await)

View file

@ -1,3 +1,5 @@
pub mod llm_router;
pub mod orchestrator_model;
pub mod orchestrator_model_v1;
pub mod router_model;
pub mod router_model_v1;

View file

@ -0,0 +1,30 @@
use common::configuration::AgentUsagePreference;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use thiserror::Error;
/// Errors reported by `OrchestratorModel` operations.
#[derive(Debug, Error)]
pub enum OrchestratorModelError {
/// A `serde_json::Error`; `#[from]` generates the `From` conversion so `?` can be used.
#[error("Failed to parse JSON: {0}")]
JsonError(#[from] serde_json::Error),
}
/// Result alias whose error type is fixed to `OrchestratorModelError`.
pub type Result<T> = std::result::Result<T, OrchestratorModelError>;
/// OrchestratorModel trait for handling orchestration requests.
/// Unlike RouterModel which returns a single route, OrchestratorModel
/// can return multiple routes as the model output format is:
/// {"route": ["route_name_1", "route_name_2", ...]}
pub trait OrchestratorModel: Send + Sync {
/// Builds a `ChatCompletionsRequest` from the conversation `messages` and the
/// optional agent usage preferences.
fn generate_request(
&self,
messages: &[Message],
usage_preferences: &Option<Vec<AgentUsagePreference>>,
) -> ChatCompletionsRequest;
/// Returns a vector of (route_name, model_name) tuples for all matched routes.
/// NOTE(review): `Ok(None)` presumably means no route matched - confirm with implementors.
fn parse_response(
&self,
content: &str,
usage_preferences: &Option<Vec<AgentUsagePreference>>,
) -> Result<Option<Vec<(String, String)>>>;
/// Returns a model name as an owned `String`.
/// NOTE(review): presumably the name of the orchestrator model itself - confirm with implementors.
fn get_model_name(&self) -> String;
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,611 @@
use super::{OpenAIConversationState, StateStorage, StateStorageError};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// In-memory storage backend for conversation state
/// Uses a HashMap wrapped in Arc<RwLock<>> for thread-safe access
/// Clones share the same underlying map because only the `Arc` is cloned.
#[derive(Clone)]
pub struct MemoryConversationalStorage {
/// Conversation states keyed by response ID.
storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
}
impl MemoryConversationalStorage {
    /// Creates a store with no conversation states in it.
    pub fn new() -> Self {
        let storage = Arc::new(RwLock::new(HashMap::new()));
        Self { storage }
    }
}
impl Default for MemoryConversationalStorage {
fn default() -> Self {
Self::new()
}
}
/// `StateStorage` backed by the in-process map. All log lines share the prefix
/// `[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:<id> | ...` so they can be
/// filtered uniformly.
#[async_trait]
impl StateStorage for MemoryConversationalStorage {
    /// Stores `state` under its `response_id`, replacing any existing entry.
    async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
        let response_id = state.response_id.clone();

        // Log before taking the write lock so the lock is held only for the insert.
        debug!(
            "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Storing conversation state: model={}, provider={}, input_items={}",
            response_id, state.model, state.provider, state.input_items.len()
        );

        let mut storage = self.storage.write().await;
        storage.insert(response_id, state);
        Ok(())
    }

    /// Returns a clone of the state stored under `response_id`.
    ///
    /// Fails with `StateStorageError::NotFound` when no entry exists.
    async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
        let storage = self.storage.read().await;

        match storage.get(response_id) {
            Some(state) => {
                debug!(
                    "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Retrieved conversation state: input_items={}",
                    response_id, state.input_items.len()
                );
                Ok(state.clone())
            }
            None => {
                warn!(
                    "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Conversation state not found",
                    response_id
                );
                Err(StateStorageError::NotFound(response_id.to_string()))
            }
        }
    }

    /// Reports whether an entry exists for `response_id`; never fails for this backend.
    async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
        let storage = self.storage.read().await;
        Ok(storage.contains_key(response_id))
    }

    /// Removes the entry for `response_id`.
    ///
    /// Fails with `StateStorageError::NotFound` when there was nothing to remove.
    async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
        let mut storage = self.storage.write().await;

        if storage.remove(response_id).is_some() {
            debug!(
                "[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Deleted conversation state",
                response_id
            );
            Ok(())
        } else {
            Err(StateStorageError::NotFound(response_id.to_string()))
        }
    }
}
#[cfg(test)]
mod tests {
    //! Tests for the in-memory backend: CRUD behaviour, concurrent writes, and the
    //! default `merge` implementation (previous items first, then current input).

    use super::*;
    use hermesllm::apis::openai_responses::{
        InputContent, InputItem, InputMessage, MessageContent, MessageRole,
    };

    /// Builds a message item holding a single `InputText` content entry.
    fn text_message(role: MessageRole, text: &str) -> InputItem {
        InputItem::Message(InputMessage {
            role,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: text.to_string(),
            }]),
        })
    }

    /// Shorthand for a user message with one text entry.
    fn user_message(text: &str) -> InputItem {
        text_message(MessageRole::User, text)
    }

    /// Shorthand for an assistant message with one text entry.
    fn assistant_message(text: &str) -> InputItem {
        text_message(MessageRole::Assistant, text)
    }

    /// Unwraps an item that must be a message; panics otherwise.
    fn expect_message(item: &InputItem) -> &InputMessage {
        let InputItem::Message(msg) = item else {
            panic!("Expected Message")
        };
        msg
    }

    /// Returns the text of the first content entry of a message item; panics when
    /// the item does not have the expected shape.
    fn first_text(item: &InputItem) -> &str {
        match &expect_message(item).content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => text.as_str(),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }
    }

    /// Builds a state with `num_messages` items named "Message {i}", alternating
    /// user (even index) and assistant (odd index) roles.
    fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
        let input_items = (0..num_messages)
            .map(|i| {
                let role = if i % 2 == 0 {
                    MessageRole::User
                } else {
                    MessageRole::Assistant
                };
                text_message(role, &format!("Message {}", i))
            })
            .collect();

        OpenAIConversationState {
            response_id: response_id.to_string(),
            input_items,
            created_at: 1234567890,
            model: "claude-3".to_string(),
            provider: "anthropic".to_string(),
        }
    }

    #[tokio::test]
    async fn test_put_and_get_success() {
        let storage = MemoryConversationalStorage::new();
        let state: OpenAIConversationState = create_test_state("resp_001", 3);

        // Store
        storage.put(state.clone()).await.unwrap();

        // Retrieve
        let retrieved = storage.get("resp_001").await.unwrap();
        assert_eq!(retrieved.response_id, state.response_id);
        assert_eq!(retrieved.model, state.model);
        assert_eq!(retrieved.provider, state.provider);
        assert_eq!(retrieved.input_items.len(), 3);
        assert_eq!(retrieved.created_at, state.created_at);
    }

    #[tokio::test]
    async fn test_put_overwrites_existing() {
        let storage = MemoryConversationalStorage::new();

        // First state
        let state1 = create_test_state("resp_002", 2);
        storage.put(state1).await.unwrap();

        // Overwrite with new state
        let state2 = OpenAIConversationState {
            response_id: "resp_002".to_string(),
            input_items: vec![],
            created_at: 9999999999,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        };
        storage.put(state2.clone()).await.unwrap();

        // Should retrieve the new state
        let retrieved = storage.get("resp_002").await.unwrap();
        assert_eq!(retrieved.model, "gpt-4");
        assert_eq!(retrieved.provider, "openai");
        assert_eq!(retrieved.input_items.len(), 0);
        assert_eq!(retrieved.created_at, 9999999999);
    }

    #[tokio::test]
    async fn test_get_not_found() {
        let storage = MemoryConversationalStorage::new();

        let result = storage.get("nonexistent").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            StateStorageError::NotFound(id) => {
                assert_eq!(id, "nonexistent");
            }
            _ => panic!("Expected NotFound error"),
        }
    }

    #[tokio::test]
    async fn test_exists_returns_false_for_nonexistent() {
        let storage = MemoryConversationalStorage::new();
        assert!(!storage.exists("resp_003").await.unwrap());
    }

    #[tokio::test]
    async fn test_exists_returns_true_after_put() {
        let storage = MemoryConversationalStorage::new();
        let state = create_test_state("resp_004", 1);

        assert!(!storage.exists("resp_004").await.unwrap());
        storage.put(state).await.unwrap();
        assert!(storage.exists("resp_004").await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_success() {
        let storage = MemoryConversationalStorage::new();
        let state = create_test_state("resp_005", 2);

        storage.put(state).await.unwrap();
        assert!(storage.exists("resp_005").await.unwrap());

        // Delete
        storage.delete("resp_005").await.unwrap();

        // Should no longer exist
        assert!(!storage.exists("resp_005").await.unwrap());
        assert!(storage.get("resp_005").await.is_err());
    }

    #[tokio::test]
    async fn test_delete_not_found() {
        let storage = MemoryConversationalStorage::new();

        let result = storage.delete("nonexistent").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            StateStorageError::NotFound(id) => {
                assert_eq!(id, "nonexistent");
            }
            _ => panic!("Expected NotFound error"),
        }
    }

    #[tokio::test]
    async fn test_merge_combines_inputs() {
        let storage = MemoryConversationalStorage::new();

        // Create a previous state with 2 messages
        let prev_state = create_test_state("resp_006", 2);

        // Create current input with 1 message
        let current_input = vec![user_message("New message")];

        // Merge
        let merged = storage.merge(&prev_state, current_input);

        // Should have 3 messages total (2 from prev + 1 current)
        assert_eq!(merged.len(), 3);
    }

    #[tokio::test]
    async fn test_merge_preserves_order() {
        let storage = MemoryConversationalStorage::new();

        // Previous state has messages 0 and 1
        let prev_state = create_test_state("resp_007", 2);

        // Current input has message 2
        let current_input = vec![user_message("Message 2")];

        let merged = storage.merge(&prev_state, current_input);

        // Verify order: prev messages first, then current
        assert_eq!(first_text(&merged[0]), "Message 0");
        assert_eq!(first_text(&merged[2]), "Message 2");
    }

    #[tokio::test]
    async fn test_merge_with_empty_current_input() {
        let storage = MemoryConversationalStorage::new();
        let prev_state = create_test_state("resp_008", 3);

        let merged = storage.merge(&prev_state, vec![]);

        // Should just have the previous state's items
        assert_eq!(merged.len(), 3);
    }

    #[tokio::test]
    async fn test_merge_with_empty_previous_state() {
        let storage = MemoryConversationalStorage::new();
        let prev_state = OpenAIConversationState {
            response_id: "resp_009".to_string(),
            input_items: vec![],
            created_at: 1234567890,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        };

        let current_input = vec![user_message("Only message")];

        let merged = storage.merge(&prev_state, current_input);

        // Should just have the current input
        assert_eq!(merged.len(), 1);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let storage = MemoryConversationalStorage::new();

        // Spawn multiple tasks that write concurrently
        let mut handles = vec![];
        for i in 0..10 {
            let storage_clone = storage.clone();
            let handle = tokio::spawn(async move {
                let state = create_test_state(&format!("resp_{}", i), i % 3);
                storage_clone.put(state).await.unwrap();
            });
            handles.push(handle);
        }

        // Wait for all tasks
        for handle in handles {
            handle.await.unwrap();
        }

        // Verify all states were stored
        for i in 0..10 {
            assert!(storage.exists(&format!("resp_{}", i)).await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_multiple_operations_on_same_id() {
        let storage = MemoryConversationalStorage::new();
        let state = create_test_state("resp_010", 1);

        // Put
        storage.put(state.clone()).await.unwrap();

        // Get
        let retrieved = storage.get("resp_010").await.unwrap();
        assert_eq!(retrieved.response_id, "resp_010");

        // Exists
        assert!(storage.exists("resp_010").await.unwrap());

        // Put again (overwrite)
        let new_state = create_test_state("resp_010", 5);
        storage.put(new_state).await.unwrap();

        // Get updated
        let updated = storage.get("resp_010").await.unwrap();
        assert_eq!(updated.input_items.len(), 5);

        // Delete
        storage.delete("resp_010").await.unwrap();

        // Should not exist
        assert!(!storage.exists("resp_010").await.unwrap());
    }

    #[tokio::test]
    async fn test_merge_with_tool_call_flow() {
        // This test simulates a realistic tool call conversation flow:
        // 1. User sends message: "What's the weather?"
        // 2. Model responds with function call (converted to assistant message)
        // 3. User sends function call output in next request with previous_response_id
        // The merge should combine: user message + assistant function call + function output
        let storage = MemoryConversationalStorage::new();

        // Step 1: Previous state contains the initial exchange
        let prev_state = OpenAIConversationState {
            response_id: "resp_tool_001".to_string(),
            input_items: vec![
                // Original user message
                user_message("What's the weather in San Francisco?"),
                // Assistant's function call (converted from OutputItem::FunctionCall)
                assistant_message(
                    "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}",
                ),
            ],
            created_at: 1234567890,
            model: "claude-3".to_string(),
            provider: "anthropic".to_string(),
        };

        // Step 2: Current request includes function call output
        let current_input = vec![user_message(
            "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}",
        )];

        // Step 3: Merge should combine all conversation history
        let merged = storage.merge(&prev_state, current_input);

        // Should have 3 items: user question + assistant function call + function output
        assert_eq!(merged.len(), 3);

        // Verify the order and content
        assert!(matches!(expect_message(&merged[0]).role, MessageRole::User));
        assert!(first_text(&merged[0]).contains("weather in San Francisco"));

        assert!(matches!(
            expect_message(&merged[1]).role,
            MessageRole::Assistant
        ));
        assert!(first_text(&merged[1]).contains("get_weather"));

        assert!(matches!(expect_message(&merged[2]).role, MessageRole::User));
        assert!(first_text(&merged[2]).contains("Function result"));
        assert!(first_text(&merged[2]).contains("temperature"));
    }

    #[tokio::test]
    async fn test_merge_with_multiple_tool_calls() {
        // Test a more complex scenario with multiple tool calls
        let storage = MemoryConversationalStorage::new();

        // Previous state has: user message + 2 function calls from assistant
        let prev_state = OpenAIConversationState {
            response_id: "resp_tool_002".to_string(),
            input_items: vec![
                user_message("What's the weather and time in SF?"),
                assistant_message(
                    "Called function: get_weather with arguments: {\"location\":\"SF\"}",
                ),
                assistant_message(
                    "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}",
                ),
            ],
            created_at: 1234567890,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        };

        // Current input: function outputs for both calls
        let current_input = vec![
            user_message("Weather result: {\"temp\": 68}"),
            user_message("Time result: {\"time\": \"14:30\"}"),
        ];

        let merged = storage.merge(&prev_state, current_input);

        // Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs
        assert_eq!(merged.len(), 5);

        // Verify first item is original user message
        assert!(matches!(expect_message(&merged[0]).role, MessageRole::User));

        // Verify last two are function outputs
        assert!(matches!(expect_message(&merged[3]).role, MessageRole::User));
        assert!(first_text(&merged[3]).contains("Weather result"));

        assert!(matches!(expect_message(&merged[4]).role, MessageRole::User));
        assert!(first_text(&merged[4]).contains("Time result"));
    }

    #[tokio::test]
    async fn test_merge_preserves_conversation_context_for_multi_turn() {
        // Simulate a multi-turn conversation with tool calls
        let storage = MemoryConversationalStorage::new();

        // Previous state: full conversation history up to this point
        let prev_state = OpenAIConversationState {
            response_id: "resp_tool_003".to_string(),
            input_items: vec![
                // Turn 1: User asks about weather
                user_message("What's the weather?"),
                // Turn 1: Assistant calls get_weather
                assistant_message("Called function: get_weather"),
                // Turn 2: User provides function output
                user_message("Weather: sunny, 72°F"),
                // Turn 2: Assistant responds with text
                assistant_message("It's sunny and 72°F in San Francisco today!"),
            ],
            created_at: 1234567890,
            model: "claude-3".to_string(),
            provider: "anthropic".to_string(),
        };

        // Turn 3: User asks follow-up question
        let current_input = vec![user_message("Should I bring an umbrella?")];

        let merged = storage.merge(&prev_state, current_input);

        // Should have all 5 messages in order
        assert_eq!(merged.len(), 5);

        // Verify the entire conversation flow is preserved
        assert!(first_text(&merged[0]).contains("What's the weather"));
        assert!(first_text(&merged[4]).contains("umbrella"));
    }
}

View file

@ -0,0 +1,147 @@
use async_trait::async_trait;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
use std::sync::Arc;
use tracing::{debug};
pub mod memory;
pub mod response_state_processor;
pub mod postgresql;
/// Represents the conversational state for a v1/responses request
/// Contains the complete input/output history that can be restored
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConversationState {
/// The response ID this state is associated with
/// (the same ID that `StateStorage::get` is called with to look the state up)
pub response_id: String,
/// The complete input history (original input + accumulated outputs)
/// This is what gets prepended to new requests via previous_response_id
pub input_items: Vec<InputItem>,
/// Timestamp when this state was created
/// NOTE(review): the unit (seconds vs. milliseconds) is not visible here - confirm with the code that fills it in
pub created_at: i64,
/// Model used for this response
pub model: String,
/// Provider that generated this response (e.g., "anthropic", "openai")
pub provider: String,
}
/// Error types for state storage operations
/// Each variant carries a `String` that the `Display` impl appends to a fixed prefix.
#[derive(Debug)]
pub enum StateStorageError {
/// State not found for given response_id
/// (the payload is the response_id that was looked up)
NotFound(String),
/// Storage backend error (network, database, etc.)
/// (the payload is a free-form message)
StorageError(String),
/// Serialization/deserialization error
/// (the payload is a free-form message)
SerializationError(String),
}
impl fmt::Display for StateStorageError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id),
StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg),
StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
}
}
}
impl Error for StateStorageError {}
/// Contract implemented by every conversation-state storage backend.
#[async_trait]
pub trait StateStorage: Send + Sync {
    /// Store conversation state for a response
    async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError>;

    /// Retrieve conversation state by response_id
    async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError>;

    /// Check if state exists for a response_id
    async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError>;

    /// Delete state for a response_id (optional, for cleanup)
    async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>;

    /// Combines stored history with new input.
    ///
    /// Default behaviour: the previous state's items come first, followed by
    /// `current_input`, with nothing dropped or de-duplicated. The merged list is
    /// also written to the debug log as JSON.
    fn merge(
        &self,
        prev_state: &OpenAIConversationState,
        current_input: Vec<InputItem>,
    ) -> Vec<InputItem> {
        // Capture the sizes before `current_input` is consumed by the chain below.
        let prev_count = prev_state.input_items.len();
        let current_count = current_input.len();

        let combined_input: Vec<InputItem> = prev_state
            .input_items
            .iter()
            .cloned()
            .chain(current_input)
            .collect();

        debug!(
            "PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}",
            prev_state.response_id,
            prev_count,
            current_count,
            combined_input.len(),
            serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
        );

        combined_input
    }
}
/// Identifies which conversation-state storage backend to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageBackend {
    /// In-process storage.
    Memory,
    /// Supabase/PostgreSQL-backed storage.
    Supabase,
}

impl StorageBackend {
    /// Parse a backend name, ignoring ASCII case and surrounding whitespace.
    ///
    /// Accepted names:
    /// - `memory` -> [`StorageBackend::Memory`]
    /// - `supabase`, `postgres`, `postgresql` -> [`StorageBackend::Supabase`]
    ///
    /// The `postgres` aliases are accepted because the Supabase backend is a
    /// plain PostgreSQL connection and the gateway configuration names the
    /// same backend `postgres`.
    ///
    /// Returns `None` for any other input, including the empty string.
    pub fn from_str(s: &str) -> Option<Self> {
        const NAMES: [(&str, StorageBackend); 4] = [
            ("memory", StorageBackend::Memory),
            ("supabase", StorageBackend::Supabase),
            ("postgres", StorageBackend::Supabase),
            ("postgresql", StorageBackend::Supabase),
        ];

        let wanted = s.trim();
        // `eq_ignore_ascii_case` avoids allocating a lowercased copy of the input.
        NAMES
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(wanted))
            .map(|&(_, backend)| backend)
    }
}
// === Utility functions for state management ===
/// Normalize an `InputParam` into a list of structured input items.
///
/// A plain-text input becomes a single user message holding one `InputText`
/// part; an item list is returned as a clone of the original items.
pub fn extract_input_items(input: &InputParam) -> Vec<InputItem> {
    let text = match input {
        // Already structured: hand back a copy.
        InputParam::Items(items) => return items.clone(),
        InputParam::Text(text) => text,
    };

    let content = MessageContent::Items(vec![InputContent::InputText { text: text.clone() }]);
    vec![InputItem::Message(InputMessage {
        role: MessageRole::User,
        content,
    })]
}
/// Load the state saved under `previous_response_id` and merge it with `current_input`.
///
/// On success the result is whatever `storage.merge` produces for the stored
/// state and `current_input`.
///
/// # Errors
/// Any error from `storage.get` (including `StateStorageError::NotFound`) is
/// returned unchanged. This function does NOT fall back to `current_input`;
/// callers that want such a fallback must handle the `Err` themselves.
pub async fn retrieve_and_combine_input(
    storage: Arc<dyn StateStorage>,
    previous_response_id: &str,
    current_input: Vec<InputItem>,
) -> Result<Vec<InputItem>, StateStorageError> {
    let prev_state = storage.get(previous_response_id).await?;
    Ok(storage.merge(&prev_state, current_input))
}

View file

@ -0,0 +1,432 @@
use super::{OpenAIConversationState, StateStorage, StateStorageError};
use async_trait::async_trait;
use serde_json;
use std::sync::Arc;
use tokio::sync::OnceCell;
use tokio_postgres::{Client, NoTls};
use tracing::{debug, info, warn};
/// Supabase/PostgreSQL storage backend for conversation state
#[derive(Clone)]
pub struct PostgreSQLConversationStorage {
    /// Client handle; clones of this storage share it through the `Arc`.
    client: Arc<Client>,
    /// Initialized once the `conversation_states` table has been verified.
    table_verified: Arc<OnceCell<()>>,
}

impl PostgreSQLConversationStorage {
    /// Connects to the database at `connection_string` and returns a storage handle.
    ///
    /// The connection is opened with `NoTls`, and its driver future is spawned
    /// onto the Tokio runtime; a driver failure is only logged at warn level.
    /// NOTE(review): hosted databases that enforce SSL would presumably reject
    /// a `NoTls` connection — confirm for the target deployment.
    ///
    /// # Errors
    /// Returns `StateStorageError::StorageError` if the connection cannot be established.
    pub async fn new(connection_string: String) -> Result<Self, StateStorageError> {
        let (client, connection) = tokio_postgres::connect(&connection_string, NoTls)
            .await
            .map_err(|e| {
                StateStorageError::StorageError(format!("Failed to connect to database: {}", e))
            })?;

        // The connection object performs the actual I/O, so it must be polled
        // for as long as the client is in use.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                warn!("Database connection error: {}", e);
            }
        });

        Ok(Self {
            client: Arc::new(client),
            table_verified: Arc::new(OnceCell::new()),
        })
    }

    /// Ensures the `conversation_states` table exists.
    ///
    /// A successful check is cached in `table_verified`; a failed check is not
    /// cached, so the next call queries the database again.
    ///
    /// # Errors
    /// Returns `StateStorageError::StorageError` if the lookup fails or the
    /// table is missing.
    async fn ensure_ready(&self) -> Result<(), StateStorageError> {
        self.table_verified
            .get_or_try_init(|| async {
                // Only schemas on the current search_path are considered: the
                // other queries reference the table unqualified, so a table of
                // the same name in an unrelated schema must not count.
                let row = self
                    .client
                    .query_one(
                        "SELECT EXISTS (
                            SELECT FROM pg_tables
                            WHERE tablename = 'conversation_states'
                              AND schemaname = ANY (current_schemas(false))
                        )",
                        &[],
                    )
                    .await
                    .map_err(|e| {
                        StateStorageError::StorageError(format!(
                            "Failed to verify table existence: {}",
                            e
                        ))
                    })?;

                let exists: bool = row.get(0);
                if !exists {
                    return Err(StateStorageError::StorageError(
                        "Table 'conversation_states' does not exist. \
                         Please run the setup SQL from docs/db_setup/conversation_states.sql"
                            .to_string(),
                    ));
                }

                info!("Conversation state storage table verified");
                Ok(())
            })
            .await
            .map(|_| ())
    }
}
#[async_trait]
impl StateStorage for PostgreSQLConversationStorage {
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
self.ensure_ready().await?;
// Serialize input_items to JSONB
let input_items_json = serde_json::to_value(&state.input_items).map_err(|e| {
StateStorageError::StorageError(format!("Failed to serialize input_items: {}", e))
})?;
// Upsert the conversation state
self.client
.execute(
r#"
INSERT INTO conversation_states
(response_id, input_items, created_at, model, provider, updated_at)
VALUES ($1, $2, $3, $4, $5, NOW())
ON CONFLICT (response_id)
DO UPDATE SET
input_items = EXCLUDED.input_items,
model = EXCLUDED.model,
provider = EXCLUDED.provider,
updated_at = NOW()
"#,
&[
&state.response_id,
&input_items_json,
&state.created_at,
&state.model,
&state.provider,
],
)
.await
.map_err(|e| {
StateStorageError::StorageError(format!(
"Failed to store conversation state for {}: {}",
state.response_id, e
))
})?;
debug!("Stored conversation state for {}", state.response_id);
Ok(())
}
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
self.ensure_ready().await?;
let row = self
.client
.query_opt(
r#"
SELECT response_id, input_items, created_at, model, provider
FROM conversation_states
WHERE response_id = $1
"#,
&[&response_id],
)
.await
.map_err(|e| {
StateStorageError::StorageError(format!(
"Failed to fetch conversation state for {}: {}",
response_id, e
))
})?;
match row {
Some(row) => {
let response_id: String = row.get("response_id");
let input_items_json: serde_json::Value = row.get("input_items");
let created_at: i64 = row.get("created_at");
let model: String = row.get("model");
let provider: String = row.get("provider");
// Deserialize input_items from JSONB
let input_items =
serde_json::from_value(input_items_json).map_err(|e| {
StateStorageError::StorageError(format!(
"Failed to deserialize input_items: {}",
e
))
})?;
Ok(OpenAIConversationState {
response_id,
input_items,
created_at,
model,
provider,
})
}
None => Err(StateStorageError::NotFound(format!(
"Conversation state not found for response_id: {}",
response_id
))),
}
}
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
self.ensure_ready().await?;
let row = self
.client
.query_one(
"SELECT EXISTS(SELECT 1 FROM conversation_states WHERE response_id = $1)",
&[&response_id],
)
.await
.map_err(|e| {
StateStorageError::StorageError(format!(
"Failed to check existence for {}: {}",
response_id, e
))
})?;
let exists: bool = row.get(0);
Ok(exists)
}
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
self.ensure_ready().await?;
let rows_affected = self
.client
.execute(
"DELETE FROM conversation_states WHERE response_id = $1",
&[&response_id],
)
.await
.map_err(|e| {
StateStorageError::StorageError(format!(
"Failed to delete conversation state for {}: {}",
response_id, e
))
})?;
if rows_affected == 0 {
return Err(StateStorageError::NotFound(format!(
"Conversation state not found for response_id: {}",
response_id
)));
}
debug!("Deleted conversation state for {}", response_id);
Ok(())
}
}
/*
PostgreSQL schema is maintained in docs/db_setup/conversation_states.sql
Run that SQL file against your database before using this storage backend.
*/
#[cfg(test)]
mod tests {
    use super::*;
    use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole};

    /// Builds a message item with the given role and a single text part.
    fn text_message(role: MessageRole, text: &str) -> InputItem {
        InputItem::Message(InputMessage {
            role,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: text.to_string(),
            }]),
        })
    }

    /// Builds a state with one user message and fixed metadata.
    fn create_test_state(response_id: &str) -> OpenAIConversationState {
        OpenAIConversationState {
            response_id: response_id.to_string(),
            input_items: vec![text_message(MessageRole::User, "Test message")],
            created_at: 1234567890,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        }
    }

    // Note: These tests require a running PostgreSQL database
    // Set TEST_DATABASE_URL environment variable to run integration tests
    // Example: TEST_DATABASE_URL=postgresql://user:pass@localhost/test_db
    //
    // Returns `None` (and prints why) when the variable is unset or the
    // connection fails; every test below then returns early.
    async fn get_test_storage() -> Option<PostgreSQLConversationStorage> {
        let Ok(db_url) = std::env::var("TEST_DATABASE_URL") else {
            eprintln!("TEST_DATABASE_URL not set, skipping Supabase integration tests");
            return None;
        };

        PostgreSQLConversationStorage::new(db_url)
            .await
            .map_err(|e| eprintln!("Failed to create test storage: {}", e))
            .ok()
    }

    #[tokio::test]
    async fn test_supabase_put_and_get_success() {
        let Some(storage) = get_test_storage().await else {
            return;
        };
        let id = "test_resp_001";

        storage.put(create_test_state(id)).await.unwrap();

        let fetched = storage.get(id).await.unwrap();
        assert_eq!(fetched.response_id, "test_resp_001");
        assert_eq!(fetched.input_items.len(), 1);
        assert_eq!(fetched.model, "gpt-4");
        assert_eq!(fetched.provider, "openai");

        // Cleanup
        let _ = storage.delete(id).await;
    }

    #[tokio::test]
    async fn test_supabase_put_overwrites_existing() {
        let Some(storage) = get_test_storage().await else {
            return;
        };
        let id = "test_resp_002";

        storage.put(create_test_state(id)).await.unwrap();

        // Second write under the same id: different model, one extra item.
        let mut replacement = create_test_state(id);
        replacement.model = "gpt-4-turbo".to_string();
        replacement
            .input_items
            .push(text_message(MessageRole::Assistant, "Response"));
        storage.put(replacement).await.unwrap();

        let fetched = storage.get(id).await.unwrap();
        assert_eq!(fetched.model, "gpt-4-turbo");
        assert_eq!(fetched.input_items.len(), 2);

        // Cleanup
        let _ = storage.delete(id).await;
    }

    #[tokio::test]
    async fn test_supabase_get_not_found() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let outcome = storage.get("nonexistent_id").await;
        assert!(matches!(outcome, Err(StateStorageError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_supabase_exists_returns_false() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        assert!(!storage.exists("nonexistent_id").await.unwrap());
    }

    #[tokio::test]
    async fn test_supabase_exists_returns_true_after_put() {
        let Some(storage) = get_test_storage().await else {
            return;
        };
        let id = "test_resp_003";

        storage.put(create_test_state(id)).await.unwrap();
        assert!(storage.exists(id).await.unwrap());

        // Cleanup
        let _ = storage.delete(id).await;
    }

    #[tokio::test]
    async fn test_supabase_delete_success() {
        let Some(storage) = get_test_storage().await else {
            return;
        };
        let id = "test_resp_004";

        storage.put(create_test_state(id)).await.unwrap();
        storage.delete(id).await.unwrap();

        assert!(!storage.exists(id).await.unwrap());
    }

    #[tokio::test]
    async fn test_supabase_delete_not_found() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let outcome = storage.delete("nonexistent_id").await;
        assert!(matches!(outcome, Err(StateStorageError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_supabase_merge_works() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let prev_state = create_test_state("test_resp_005");
        let current_input = vec![text_message(MessageRole::User, "New message")];

        // Should have 2 messages (1 from prev + 1 current)
        assert_eq!(storage.merge(&prev_state, current_input).len(), 2);
    }

    #[tokio::test]
    async fn test_supabase_table_verification() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        // This should trigger table verification
        let first = storage.ensure_ready().await;
        assert!(first.is_ok(), "Table verification should succeed");

        // Second call should use cached result
        let second = storage.ensure_ready().await;
        assert!(second.is_ok(), "Cached verification should succeed");
    }

    #[tokio::test]
    #[ignore] // Run manually with: cargo test test_verify_data_in_supabase -- --ignored
    async fn test_verify_data_in_supabase() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        // Create a test record that persists
        storage
            .put(create_test_state("manual_test_verification"))
            .await
            .unwrap();

        println!("✅ Data written to Supabase!");
        println!("Check your Supabase dashboard:");
        println!(" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';");
        println!("\nTo cleanup, run:");
        println!(" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';");
        // DON'T cleanup - leave it for manual verification
    }
}

View file

@ -0,0 +1,302 @@
use bytes::Bytes;
use flate2::read::GzDecoder;
use hermesllm::apis::openai_responses::{
InputItem, OutputItem, ResponsesAPIStreamEvent,
};
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
use std::io::Read;
use std::sync::Arc;
use tracing::{info, debug, warn};
use crate::handlers::utils::StreamProcessor;
use crate::state::{OpenAIConversationState, StateStorage};
/// Processor that wraps another processor and handles v1/responses state management
/// Captures response_id and output from the upstream response (streaming or
/// buffered) so the conversation state can be stored after completion
pub struct ResponsesStateProcessor<P: StreamProcessor> {
    /// The underlying processor (e.g., ObservableStreamProcessor for metrics)
    inner: P,
    /// State storage backend
    storage: Arc<dyn StateStorage>,
    /// Original input items from the request
    original_input: Vec<InputItem>,
    /// Model name
    model: String,
    /// Provider name
    provider: String,
    /// Whether this is a streaming request
    is_streaming: bool,
    /// Whether upstream is OpenAI (skip storage if true)
    is_openai_upstream: bool,
    /// Content-Encoding header value (e.g., "gzip", "br", None)
    content_encoding: Option<String>,
    /// Request ID for logging
    request_id: String,
    /// Buffer for accumulating chunks (needed for non-streaming compressed responses)
    chunk_buffer: Vec<u8>,
    /// Captured response_id from the response.completed event (streaming) or
    /// from the response body (non-streaming)
    response_id: Option<String>,
    /// Captured output items, from the same source as `response_id`
    output_items: Option<Vec<OutputItem>>,
}

impl<P: StreamProcessor> ResponsesStateProcessor<P> {
    /// Creates a processor that forwards to `inner` while capturing what is
    /// needed to store conversation state. Nothing is captured until chunks
    /// are processed.
    pub fn new(
        inner: P,
        storage: Arc<dyn StateStorage>,
        original_input: Vec<InputItem>,
        model: String,
        provider: String,
        is_streaming: bool,
        is_openai_upstream: bool,
        content_encoding: Option<String>,
        request_id: String,
    ) -> Self {
        Self {
            inner,
            storage,
            original_input,
            model,
            provider,
            is_streaming,
            is_openai_upstream,
            content_encoding,
            request_id,
            chunk_buffer: Vec::new(),
            response_id: None,
            output_items: None,
        }
    }

    /// Decompress accumulated buffer based on Content-Encoding header
    ///
    /// Content-coding names are compared case-insensitively and with
    /// surrounding whitespace ignored. `gzip`/`x-gzip` are decompressed; a
    /// missing, empty or `identity` encoding returns the buffer as-is. Any
    /// other encoding, or a gzip failure, is logged at warn level and the raw
    /// buffer is returned.
    fn decompress_buffer(&self) -> Vec<u8> {
        if self.chunk_buffer.is_empty() {
            return Vec::new();
        }

        let encoding = self
            .content_encoding
            .as_deref()
            .map(str::trim)
            .filter(|e| !e.is_empty() && !e.eq_ignore_ascii_case("identity"));
        let Some(encoding) = encoding else {
            return self.chunk_buffer.clone();
        };

        let is_gzip =
            encoding.eq_ignore_ascii_case("gzip") || encoding.eq_ignore_ascii_case("x-gzip");
        if !is_gzip {
            warn!(
                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Unsupported Content-Encoding: {}. Only gzip is currently supported.",
                self.request_id,
                encoding
            );
            return self.chunk_buffer.clone();
        }

        let mut decoder = GzDecoder::new(self.chunk_buffer.as_slice());
        let mut decompressed = Vec::new();
        match decoder.read_to_end(&mut decompressed) {
            Ok(_) => {
                debug!(
                    "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully decompressed {} bytes to {} bytes",
                    self.request_id,
                    self.chunk_buffer.len(),
                    decompressed.len()
                );
                decompressed
            }
            Err(e) => {
                warn!(
                    "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to decompress gzip buffer: {}",
                    self.request_id,
                    e
                );
                self.chunk_buffer.clone()
            }
        }
    }

    /// Parse response to extract response_id and output
    /// For streaming: parse SSE events looking for response.completed (per chunk)
    /// For non-streaming: buffer all chunks, then decompress and parse on completion
    ///
    /// Streaming chunks are parsed as-is, i.e. without decompression.
    /// NOTE(review): each chunk is parsed on its own, so a `response.completed`
    /// event split across two chunks would presumably be missed — confirm how
    /// `SseStreamIter` treats partial events.
    fn try_parse_response_chunk(&mut self, chunk: &[u8]) {
        if !self.is_streaming {
            // Non-streaming: Buffer chunks, will decompress and parse on completion
            self.chunk_buffer.extend_from_slice(chunk);
            return;
        }

        let sse_iter = match SseStreamIter::try_from(chunk) {
            Ok(iter) => iter,
            Err(_) => return, // Not valid SSE format, skip
        };

        for event in sse_iter {
            // Only data lines can carry the completed response.
            let Some(data_str) = &event.data else {
                continue;
            };
            let Ok(ResponsesAPIStreamEvent::ResponseCompleted { response, .. }) =
                serde_json::from_str::<ResponsesAPIStreamEvent>(data_str)
            else {
                continue;
            };

            info!(
                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
                self.request_id,
                response.id,
                response.output.len()
            );
            self.response_id = Some(response.id.clone());
            self.output_items = Some(response.output.clone());
            return; // Found what we need, exit early
        }
    }

    /// Parse buffered non-streaming response (called on completion)
    ///
    /// Does nothing for streaming requests or when no bytes were buffered.
    /// On a parse failure a preview of at most 200 bytes is logged.
    fn try_parse_buffered_response(&mut self) {
        const PREVIEW_MAX_BYTES: usize = 200;

        if self.is_streaming || self.chunk_buffer.is_empty() {
            return;
        }

        let decompressed = self.decompress_buffer();

        match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) {
            Ok(response) => {
                info!(
                    "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",
                    self.request_id,
                    response.id,
                    response.output.len()
                );
                self.response_id = Some(response.id.clone());
                self.output_items = Some(response.output.clone());
            }
            Err(e) => {
                let chunk_preview = String::from_utf8_lossy(&decompressed);
                // Back off to a char boundary: slicing a `str` in the middle of
                // a multi-byte character panics, and lossy conversion of binary
                // data produces plenty of 3-byte U+FFFD characters.
                let mut preview_len = chunk_preview.len().min(PREVIEW_MAX_BYTES);
                while !chunk_preview.is_char_boundary(preview_len) {
                    preview_len -= 1;
                }
                warn!(
                    "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to parse non-streaming ResponsesAPIResponse: {}. Decompressed preview (first {} bytes): {}",
                    self.request_id,
                    e,
                    preview_len,
                    &chunk_preview[..preview_len]
                );
            }
        }
    }
}
/// Forwards every callback to the wrapped processor and, for non-OpenAI
/// upstreams, stores the conversation state once the response completes.
impl<P: StreamProcessor> StreamProcessor for ResponsesStateProcessor<P> {
    /// Inspects (streaming) or buffers (non-streaming) the chunk, then hands
    /// it to the inner processor and returns that processor's result.
    fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
        // State is never stored for OpenAI upstreams (see `on_complete`), so
        // buffering or parsing the body for them would be wasted work.
        if !self.is_openai_upstream {
            self.try_parse_response_chunk(&chunk);
        }
        self.inner.process_chunk(chunk)
    }

    fn on_first_bytes(&mut self) {
        self.inner.on_first_bytes();
    }

    /// Completes the inner processor, then stores `original_input` plus the
    /// captured output (converted to input items) under the captured
    /// response_id. Storage runs in a spawned task; its outcome is only logged.
    fn on_complete(&mut self) {
        // Skip storage for OpenAI upstream
        if self.is_openai_upstream {
            self.inner.on_complete();
            debug!(
                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Skipping state storage for OpenAI upstream provider",
                self.request_id
            );
            return;
        }

        // For non-streaming, decompress and parse buffered response
        self.try_parse_buffered_response();
        self.inner.on_complete();

        let (Some(response_id), Some(output_items)) = (&self.response_id, &self.output_items)
        else {
            warn!(
                "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | No response_id captured from upstream response - cannot store conversation state. response_id present: {}, output present: {}",
                self.request_id,
                self.response_id.is_some(),
                self.output_items.is_some()
            );
            return;
        };

        // Convert output items to input items for next request
        let output_as_inputs = outputs_to_inputs(output_items);
        debug!(
            "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Converting outputs to inputs: output_items_count={}, converted_input_items_count={}",
            self.request_id, output_items.len(), output_as_inputs.len()
        );

        // Combine original input + output as new input history
        let mut combined_input = self.original_input.clone();
        combined_input.extend(output_as_inputs);
        debug!(
            "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Storing state: original_input_count={}, combined_input_count={}, combined_json={}",
            self.request_id,
            self.original_input.len(),
            combined_input.len(),
            serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
        );

        let state = OpenAIConversationState {
            response_id: response_id.clone(),
            input_items: combined_input,
            // Seconds since the Unix epoch; 0 if the clock reads before it.
            created_at: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs() as i64,
            model: self.model.clone(),
            provider: self.provider.clone(),
        };

        // Store asynchronously (fire and forget with logging)
        let storage = self.storage.clone();
        let response_id_clone = response_id.clone();
        let request_id = self.request_id.clone();
        let items_count = state.input_items.len();
        tokio::spawn(async move {
            match storage.put(state).await {
                Ok(()) => {
                    info!(
                        "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully stored conversation state for response_id: {}, items_count={}",
                        request_id,
                        response_id_clone,
                        items_count
                    );
                }
                Err(e) => {
                    warn!(
                        "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to store conversation state for response_id {}: {}",
                        request_id,
                        response_id_clone,
                        e
                    );
                }
            }
        });
    }

    fn on_error(&mut self, error: &str) {
        self.inner.on_error(error);
    }
}

View file

@ -0,0 +1,335 @@
//! OpenTelemetry Semantic Conventions
//!
//! This module defines span attribute keys modelled on OTEL semantic conventions,
//! together with custom LLM/routing keys and helpers for building operation names.
//! See: https://opentelemetry.io/docs/specs/semconv/
// =============================================================================
// Span Attributes - HTTP
// =============================================================================
/// Semantic conventions for HTTP-related span attributes
///
/// NOTE(review): these keys look like the older OTel HTTP attribute names
/// (newer semconv releases use e.g. `http.request.method` and
/// `http.response.status_code`) — confirm which version the trace backend expects.
pub mod http {
/// HTTP request method
/// Example: "GET", "POST", "PUT"
pub const METHOD: &str = "http.method";
/// HTTP response status code
/// Example: "200", "404", "500"
pub const STATUS_CODE: &str = "http.status_code";
/// Full HTTP request URL
pub const URL: &str = "http.url";
/// HTTP request target (path + query)
/// Example: "/v1/chat/completions?stream=true"
pub const TARGET: &str = "http.target";
/// Upstream target path after routing transformation
/// Example: "/api/paas/v4/chat/completions" (for Zhipu provider)
/// This key is specific to this gateway; the other keys in this module use
/// the generic `http.*` names.
pub const UPSTREAM_TARGET: &str = "http.upstream_target";
/// HTTP request scheme
/// Example: "http", "https"
pub const SCHEME: &str = "http.scheme";
/// Value of the HTTP User-Agent header
pub const USER_AGENT: &str = "http.user_agent";
/// Size of the request payload body in bytes
pub const REQUEST_CONTENT_LENGTH: &str = "http.request_content_length";
/// Size of the response payload body in bytes
pub const RESPONSE_CONTENT_LENGTH: &str = "http.response_content_length";
}
// =============================================================================
// Span Attributes - LLM Specific
// =============================================================================
/// Custom attributes for LLM operations
///
/// Every key in this module lives under a custom `llm.` prefix.
/// NOTE(review): the OTel GenAI semantic conventions use `gen_ai.*` names, so
/// these keys are not drop-in GenAI attributes — confirm whether a mapping is
/// expected downstream.
pub mod llm {
/// Name of the LLM model being called
/// Example: "gpt-4", "claude-3-sonnet", "llama-2-70b"
pub const MODEL_NAME: &str = "llm.model";
/// Provider of the LLM
/// Example: "openai", "anthropic", "azure-openai"
pub const PROVIDER: &str = "llm.provider";
/// Type of LLM operation
/// Example: "chat", "completion", "embedding"
pub const OPERATION_TYPE: &str = "llm.operation_type";
/// Whether the request is streaming
pub const IS_STREAMING: &str = "llm.is_streaming";
/// Total bytes received in the response
pub const RESPONSE_BYTES: &str = "llm.response_bytes";
/// Duration of the LLM call in milliseconds
pub const DURATION_MS: &str = "llm.duration_ms";
/// Time to first token in milliseconds (streaming only)
/// Note: unlike `DURATION_MS`, the key string itself carries no `_ms` suffix.
pub const TIME_TO_FIRST_TOKEN_MS: &str = "llm.time_to_first_token";
/// Number of prompt tokens used
pub const PROMPT_TOKENS: &str = "llm.usage.prompt_tokens";
/// Number of completion tokens generated
pub const COMPLETION_TOKENS: &str = "llm.usage.completion_tokens";
/// Total tokens used (prompt + completion)
pub const TOTAL_TOKENS: &str = "llm.usage.total_tokens";
/// Temperature parameter used
pub const TEMPERATURE: &str = "llm.temperature";
/// Max tokens parameter used
pub const MAX_TOKENS: &str = "llm.max_tokens";
/// Top-p parameter used
pub const TOP_P: &str = "llm.top_p";
/// List of tool names provided in the request
pub const TOOLS: &str = "llm.tools";
/// Preview of the user message (truncated)
pub const USER_MESSAGE_PREVIEW: &str = "llm.user_message_preview";
}
// =============================================================================
// Span Attributes - Routing & Gateway
// =============================================================================
/// Attributes specific to LLM routing and gateway operations
///
/// All keys live under the custom `routing.` prefix.
pub mod routing {
/// Strategy used to select the LLM endpoint
/// Example: "round-robin", "least-latency", "cost-optimized"
pub const STRATEGY: &str = "routing.strategy";
/// Selected upstream endpoint
pub const UPSTREAM_ENDPOINT: &str = "routing.upstream_endpoint";
/// Time taken to determine the route in milliseconds
pub const ROUTE_DETERMINATION_MS: &str = "routing.determination_ms";
/// Whether a fallback endpoint was used
pub const IS_FALLBACK: &str = "routing.is_fallback";
/// Reason for route selection
pub const SELECTION_REASON: &str = "routing.selection_reason";
}
// =============================================================================
// Span Attributes - Error Handling
// =============================================================================
/// Attributes for error and exception tracking
///
/// `ERROR` is the bare key `error`; the remaining keys are nested under `error.`.
pub mod error {
/// Whether an error occurred
pub const ERROR: &str = "error";
/// Type/class of the error
/// Example: "TimeoutError", "AuthenticationError"
pub const TYPE: &str = "error.type";
/// Error message
pub const MESSAGE: &str = "error.message";
/// Stack trace of the error
pub const STACK_TRACE: &str = "error.stack_trace";
}
// =============================================================================
// Operation Names
// =============================================================================
/// Canonical operation name components for the gateway
///
/// Each value has the form `plano(<stage>)`, one per stage of request handling.
pub mod operation_component {
/// Inbound request handling
pub const INBOUND: &str = "plano(inbound)";
/// Routing decision phase
pub const ROUTING: &str = "plano(routing)";
/// Handoff to upstream service
pub const HANDOFF: &str = "plano(handoff)";
/// Agent filter execution
/// Note: the label is `plano(filter)`, not `plano(agent filter)`.
pub const AGENT_FILTER: &str = "plano(filter)";
/// Agent execution
pub const AGENT: &str = "plano(agent)";
/// LLM call
pub const LLM: &str = "plano(llm)";
}
/// Builder for constructing standardized operation names
///
/// Format: `{method} {path} {target}`, or `{method} {path} ({operation}) {target}`
/// when an operation is set. Unset parts are omitted.
///
/// The operation component (see `operation_component`, e.g. "plano(llm)") is part
/// of the service name, so the operation name focuses on the HTTP request details
/// and target.
///
/// # Examples
/// ```
/// use brightstaff::tracing::OperationNameBuilder;
///
/// // LLM call operation: "POST /v1/chat/completions gpt-4"
/// // (service name will be "plano(llm)")
/// let op = OperationNameBuilder::new()
///     .with_method("POST")
///     .with_path("/v1/chat/completions")
///     .with_target("gpt-4")
///     .build();
///
/// // Agent filter operation: "POST /agents/v1/chat/completions hallucination-detector"
/// // (service name will be "plano(filter)")
/// let op = OperationNameBuilder::new()
///     .with_method("POST")
///     .with_path("/agents/v1/chat/completions")
///     .with_target("hallucination-detector")
///     .build();
///
/// // Routing operation: "POST /v1/chat/completions"
/// // (service name will be "plano(routing)")
/// let op = OperationNameBuilder::new()
///     .with_method("POST")
///     .with_path("/v1/chat/completions")
///     .build();
/// ```
#[derive(Default)]
pub struct OperationNameBuilder {
    method: Option<String>,
    path: Option<String>,
    operation: Option<String>,
    target: Option<String>,
}

impl OperationNameBuilder {
    /// Create a new operation name builder with every part unset
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the HTTP method
    ///
    /// # Arguments
    /// * `method` - HTTP method (e.g., "GET", "POST", "PUT")
    pub fn with_method(mut self, method: impl Into<String>) -> Self {
        self.method = Some(method.into());
        self
    }

    /// Set the request path
    ///
    /// # Arguments
    /// * `path` - Request path (e.g., "/v1/chat/completions", "/agents/v1/chat/completions")
    pub fn with_path(mut self, path: impl Into<String>) -> Self {
        self.path = Some(path.into());
        self
    }

    /// Set the operation type (optional, for MCP operations)
    ///
    /// The operation is rendered as a suffix of the path, so it only appears
    /// in the built name when a path is also set.
    ///
    /// # Arguments
    /// * `operation` - Operation type (e.g., "tool_call", "session_init", "notification")
    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
        self.operation = Some(operation.into());
        self
    }

    /// Set the target (model name, agent name, or filter name)
    ///
    /// # Arguments
    /// * `target` - Target identifier (e.g., "gpt-4", "my-agent", "hallucination-detector")
    pub fn with_target(mut self, target: impl Into<String>) -> Self {
        self.target = Some(target.into());
        self
    }

    /// Build the operation name string
    ///
    /// The parts that are set are joined with single spaces in the order
    /// method, path, target.
    ///
    /// # Format
    /// - With all components: `{method} {path} ({operation}) {target}`
    /// - Without operation: `{method} {path} {target}`
    /// - Without target: `{method} {path}`
    /// - Without path: `{method} {target}` (any operation is dropped, because it
    ///   is only rendered as part of the path)
    /// - Empty: returns empty string
    pub fn build(self) -> String {
        let Self {
            method,
            path,
            operation,
            target,
        } = self;

        // The operation decorates the path; without a path it has nowhere to go.
        let path_part = path.map(|path| match operation {
            Some(operation) => format!("{} ({})", path, operation),
            None => path,
        });

        vec![method, path_part, target]
            .into_iter()
            .flatten()
            .collect::<Vec<String>>()
            .join(" ")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_operation_name_full() {
        let op = OperationNameBuilder::new()
            .with_method("POST")
            .with_path("/v1/chat/completions")
            .with_target("gpt-4")
            .build();
        assert_eq!(op, "POST /v1/chat/completions gpt-4");
    }

    #[test]
    fn test_operation_name_no_target() {
        let op = OperationNameBuilder::new()
            .with_method("POST")
            .with_path("/v1/chat/completions")
            .build();
        assert_eq!(op, "POST /v1/chat/completions");
    }

    #[test]
    fn test_operation_name_agent_filter() {
        let op = OperationNameBuilder::new()
            .with_method("POST")
            .with_path("/agents/v1/chat/completions")
            .with_target("content-filter")
            .build();
        assert_eq!(op, "POST /agents/v1/chat/completions content-filter");
    }

    #[test]
    fn test_operation_name_minimal() {
        let op = OperationNameBuilder::new().build();
        assert_eq!(op, "");
    }

    // The operation is rendered in parentheses right after the path.
    #[test]
    fn test_operation_name_with_operation() {
        let op = OperationNameBuilder::new()
            .with_method("POST")
            .with_path("/mcp")
            .with_operation("tool_call")
            .with_target("my-agent")
            .build();
        assert_eq!(op, "POST /mcp (tool_call) my-agent");
    }

    // `Default` must behave exactly like `new()`.
    #[test]
    fn test_operation_name_default_matches_new() {
        let op = OperationNameBuilder::default()
            .with_method("GET")
            .with_path("/healthz")
            .build();
        assert_eq!(op, "GET /healthz");
        assert_eq!(OperationNameBuilder::default().build(), "");
    }
}

View file

@ -0,0 +1,3 @@
mod constants;
pub use constants::{OperationNameBuilder, operation_component, http, llm, error, routing};

1
crates/build.sh Normal file
View file

@ -0,0 +1 @@
# Release build: compile the prompt_gateway and llm_gateway packages for the
# wasm32-wasip1 target, then (only if that succeeds) compile brightstaff for
# the host target.
cargo build --release --target wasm32-wasip1 -p prompt_gateway -p llm_gateway && cargo build --release -p brightstaff

View file

@ -14,13 +14,25 @@ derivative = "2.2.0"
thiserror = "1.0.64"
tiktoken-rs = "0.5.9"
rand = "0.8.5"
serde_json = "1.0"
serde_json = { version = "1.0", features = ["preserve_order"] }
hex = "0.4.3"
urlencoding = "2.1.3"
url = "2.5.4"
hermesllm = { version = "0.1.0", path = "../hermesllm" }
serde_with = "3.13.0"
# Optional dependencies for trace collection (not available in WASM)
tokio = { version = "1.44", features = ["sync", "time"], optional = true }
reqwest = { version = "0.12", features = ["json"], optional = true }
tracing = { version = "0.1", optional = true }
[features]
default = []
trace-collection = ["tokio", "reqwest", "tracing"]
[dev-dependencies]
pretty_assertions = "1.4.1"
serde_json = "1.0.64"
serial_test = "3.2"
axum = "0.7"
tokio = { version = "1.44", features = ["sync", "time", "macros", "rt"] }

View file

@ -4,7 +4,6 @@ use crate::{
};
use core::{panic, str};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
@ -265,7 +264,7 @@ pub struct ToolCall {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
pub name: String,
pub arguments: Option<HashMap<String, Value>>,
pub arguments: String,
}
#[derive(Debug, Deserialize, Serialize)]

View file

@ -21,8 +21,11 @@ pub struct ModelAlias {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Agent {
pub id: String,
pub kind: Option<String>,
pub transport: Option<String>,
pub tool: Option<String>,
pub url: String,
#[serde(rename = "type")]
pub agent_type: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -41,6 +44,20 @@ pub struct Listener {
pub port: u16,
}
/// Settings selecting and configuring the conversation-state storage backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateStorageConfig {
/// Backend selector; (de)serialized under the key `type`.
#[serde(rename = "type")]
pub storage_type: StateStorageType,
/// Optional connection string for the backend.
/// NOTE(review): presumably required when `storage_type` is `Postgres` —
/// confirm against the code that constructs the backend.
pub connection_string: Option<String>,
}
/// Supported state-storage backends.
///
/// Variants are (de)serialized in lowercase: `memory` and `postgres`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum StateStorageType {
Memory,
Postgres,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
@ -57,7 +74,9 @@ pub struct Configuration {
pub mode: Option<GatewayMode>,
pub routing: Option<Routing>,
pub agents: Option<Vec<Agent>>,
pub filters: Option<Vec<Agent>>,
pub listeners: Vec<Listener>,
pub state_storage: Option<StateStorageConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
@ -252,6 +271,39 @@ pub struct RoutingPreference {
pub description: String,
}
/// Pairs a model name with the orchestration preferences that apply to it.
#[derive(Serialize, Deserialize, Debug)]
pub struct AgentUsagePreference {
/// Model identifier these preferences are associated with.
pub model: String,
/// Orchestration preferences for `model`.
pub orchestration_preferences: Vec<OrchestrationPreference>,
}
/// OrchestrationPreference with custom serialization to always include default parameters.
/// The parameters field is always serialized as:
/// {"type": "object", "properties": {}, "required": []}
///
/// Only `Deserialize` is derived; `Serialize` is implemented by hand so the
/// extra `parameters` field can be emitted even though it is not stored here.
#[derive(Debug, Clone, Deserialize)]
pub struct OrchestrationPreference {
/// Name of the preference.
pub name: String,
/// Human-readable description of the preference.
pub description: String,
}
/// Hand-written serializer: emits `name`, `description`, and a constant
/// `parameters` object describing an empty JSON-schema object.
impl serde::Serialize for OrchestrationPreference {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        // The parameters payload is fixed; it is not stored on the struct.
        let default_parameters = serde_json::json!({
            "type": "object",
            "properties": {},
            "required": []
        });

        let mut out = serializer.serialize_struct("OrchestrationPreference", 3)?;
        out.serialize_field("name", &self.name)?;
        out.serialize_field("description", &self.description)?;
        out.serialize_field("parameters", &default_parameters)?;
        out.end()
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
//TODO: use enum for model, but if there is a new model, we need to update the code
pub struct LlmProvider {

View file

@ -7,12 +7,13 @@ pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
pub const DEFAULT_TARGET_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
pub const API_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
pub const MODEL_SERVER_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const MODEL_SERVER_NAME: &str = "bright_staff";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const MESSAGES_KEY: &str = "messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const ARCH_IS_STREAMING_HEADER: &str = "x-arch-streaming-request";
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
pub const MESSAGES_PATH: &str = "/v1/messages";
pub const HEALTHZ_PATH: &str = "/healthz";
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
@ -31,3 +32,4 @@ pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
pub const OTEL_POST_PATH: &str = "/v1/traces";
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries";
/// Service name string for brightstaff.
pub const BRIGHT_STAFF_SERVICE_NAME: &str = "brightstaff";

View file

@ -11,4 +11,5 @@ pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
pub mod traces;
pub mod utils;

View file

@ -40,8 +40,14 @@ pub fn get_llm_provider(
let mut rng = thread_rng();
llm_providers
.iter()
.filter(|(_, provider)| {
provider.model
.as_ref()
.map(|m| !m.starts_with("Arch"))
.unwrap_or(true)
})
.choose(&mut rng)
.expect("There should always be at least one llm provider")
.expect("There should always be at least one non-Arch llm provider")
.1
.clone()
}

View file

@ -0,0 +1,285 @@
use super::shapes::Span;
use super::resource_span_builder::ResourceSpanBuilder;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::time::{interval, Duration};
use tracing::{debug, error, warn};
/// Split a W3C `traceparent` header into its trace id and parent span id.
///
/// Expected shape: `"{version}-{trace_id}-{parent_span_id}-{flags}"`, i.e.
/// exactly four `-`-separated fields.
///
/// Returns `(trace_id, parent_span_id)`:
/// - `parent_span_id` is `None` when it is the all-zero id
///   (`0000000000000000`), which marks a root span.
/// - When the header does not have four fields a warning is logged and
///   `(String::new(), None)` is returned.
pub fn parse_traceparent(traceparent: &str) -> (String, Option<String>) {
    const ROOT_PARENT_SPAN_ID: &str = "0000000000000000";

    let fields: Vec<&str> = traceparent.split('-').collect();
    match fields.as_slice() {
        [_version, trace_id, parent_span_id, _flags] => {
            // An all-zero parent id means "no parent".
            let parent = if *parent_span_id == ROOT_PARENT_SPAN_ID {
                None
            } else {
                Some(parent_span_id.to_string())
            };
            (trace_id.to_string(), parent)
        }
        _ => {
            warn!("Invalid traceparent format: {}", traceparent);
            // Malformed header: empty trace id, no parent.
            (String::new(), None)
        }
    }
}
/// Collects and batches spans, flushing them to an OTEL collector
///
/// Supports multiple services, with each service (e.g., "archgw(routing)", "archgw(llm)")
/// maintaining its own span queue. Flushes all services together periodically.
///
/// Tracing can be enabled/disabled in two ways:
/// 1. Via arch_config.yaml: presence of `tracing` configuration section
/// 2. Via environment variable: `OTEL_TRACING_ENABLED=true/false`
///
/// When disabled, span recording and flushing are no-ops.
pub struct TraceCollector {
/// Spans grouped by service name
/// Key: service name (e.g., "archgw(routing)", "archgw(llm)")
/// Value: queue of spans for that service
spans_by_service: Arc<Mutex<HashMap<String, VecDeque<Span>>>>,
/// Period used by the background flusher between flush attempts
flush_interval: Duration,
/// URL that flushed spans are POSTed to
otel_url: String,
/// Whether tracing is enabled
enabled: bool,
}
impl TraceCollector {
/// Create a new trace collector
///
/// # Arguments
/// * `enabled` - Whether tracing is enabled
///   - `Some(true)` - Force enable tracing
///   - `Some(false)` - Force disable tracing
///   - `None` - Check `OTEL_TRACING_ENABLED` env var (defaults to **false**
///     when the variable is unset or is not `true`/`false`)
///
/// Other parameters are read from environment variables:
/// - `TRACE_FLUSH_INTERVAL_MS` - Flush interval in milliseconds (default: 1000;
///   unparsable or zero values fall back to the default)
/// - `OTEL_COLLECTOR_URL` - OTEL collector endpoint (default: http://localhost:9903/v1/traces)
pub fn new(enabled: Option<bool>) -> Self {
    const DEFAULT_FLUSH_INTERVAL_MS: u64 = 1000;

    let flush_interval_ms = std::env::var("TRACE_FLUSH_INTERVAL_MS")
        .ok()
        .and_then(|s| s.parse::<u64>().ok())
        // `tokio::time::interval` panics on a zero period, so a configured
        // value of 0 must not reach `start_background_flusher`.
        .filter(|ms| *ms > 0)
        .unwrap_or(DEFAULT_FLUSH_INTERVAL_MS);

    let otel_url = std::env::var("OTEL_COLLECTOR_URL")
        .unwrap_or_else(|_| "http://localhost:9903/v1/traces".to_string());

    // Determine if tracing is enabled:
    // 1. Use explicit parameter if provided
    // 2. Otherwise check OTEL_TRACING_ENABLED env var
    // 3. Default to false if neither is set (tracing opt-in, not opt-out)
    let enabled = enabled.unwrap_or_else(|| {
        std::env::var("OTEL_TRACING_ENABLED")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(false)
    });

    debug!(
        "TraceCollector initialized: flush_interval={}ms, url={}, enabled={}",
        flush_interval_ms, otel_url, enabled
    );

    Self {
        spans_by_service: Arc::new(Mutex::new(HashMap::new())),
        flush_interval: Duration::from_millis(flush_interval_ms),
        otel_url,
        enabled,
    }
}
/// Buffer a span under the given service name.
///
/// # Arguments
/// * `service_name` - Name of the service (e.g., "archgw(routing)", "archgw(llm)")
/// * `span` - The span to record
///
/// Does nothing when tracing is disabled. The span is dropped (with a debug
/// log) if the buffer lock cannot be acquired immediately, so this call never
/// waits. Sending is left to `flush` / the background flusher.
pub fn record_span(&self, service_name: impl Into<String>, span: Span) {
    if !self.enabled {
        return;
    }

    let service_name = service_name.into();

    // `try_lock` keeps telemetry from ever blocking the caller.
    match self.spans_by_service.try_lock() {
        Ok(mut queues) => {
            // Creates the per-service queue on first use.
            queues.entry(service_name).or_default().push_back(span);
        }
        Err(_) => {
            debug!("Skipped span recording due to lock contention");
        }
    }
}
/// Drain every buffered span and send them to the OTEL collector.
///
/// One `ResourceSpan` is built per service that has buffered spans. Returns
/// `Ok(())` without sending anything when tracing is disabled or nothing is
/// buffered. Spans are removed from the buffer before the send is attempted,
/// so a failed send returns the error and those spans are not retried.
pub async fn flush(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    if !self.enabled {
        return Ok(());
    }

    // Hold the lock only while draining; it is released before the HTTP call.
    let service_batches = {
        let mut queues = self.spans_by_service.lock().await;
        let drained: Vec<(String, Vec<Span>)> = queues
            .iter_mut()
            .filter(|(_, queue)| !queue.is_empty())
            .map(|(service, queue)| (service.clone(), queue.drain(..).collect()))
            .collect();
        drained
    };

    if service_batches.is_empty() {
        return Ok(());
    }

    let service_count = service_batches.len();
    let total_spans: usize = service_batches.iter().map(|(_, spans)| spans.len()).sum();
    debug!("Flushing {} spans across {} services to OTEL collector", total_spans, service_count);

    // Build canonical OTEL payload structure - one ResourceSpan per service
    let resource_spans = self.build_resource_spans(service_batches);

    if let Err(e) = self.send_to_otel(resource_spans).await {
        warn!("Failed to send spans to OTEL collector: {:?}", e);
        return Err(e);
    }

    debug!("Successfully flushed {} spans", total_spans);
    Ok(())
}
/// Turn each `(service name, spans)` batch into its own `ResourceSpan`,
/// preserving the order of `service_batches`.
fn build_resource_spans(&self, service_batches: Vec<(String, Vec<Span>)>) -> Vec<super::shapes::ResourceSpan> {
    let mut resource_spans = Vec::with_capacity(service_batches.len());
    for (service_name, spans) in service_batches {
        let builder = ResourceSpanBuilder::new(service_name).add_spans(spans);
        resource_spans.push(builder.build());
    }
    resource_spans
}
/// POST the resource spans to the configured OTEL collector URL.
///
/// The body is `{"resourceSpans": [...]}` per the OTEL spec, sent as JSON
/// with a 5 second request timeout. Transport errors are propagated; any
/// non-success HTTP status is logged and returned as an error.
async fn send_to_otel(
    &self,
    resource_spans: Vec<super::shapes::ResourceSpan>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // NOTE(review): a new client is constructed on every call.
    let http = reqwest::Client::new();

    let body = serde_json::json!({
        "resourceSpans": resource_spans
    });

    let request = http
        .post(&self.otel_url)
        .header("Content-Type", "application/json")
        .json(&body)
        .timeout(Duration::from_secs(5));
    let response = request.send().await?;

    let status = response.status();
    if status.is_success() {
        return Ok(());
    }

    warn!(
        "OTEL collector returned non-success status: {}",
        status
    );
    Err(format!("OTEL collector error: {}", status).into())
}
/// Spawn a task that calls `flush` once per `flush_interval`, forever.
///
/// Flush failures are logged and do not stop the loop. The returned join
/// handle can be used to stop the flusher.
pub fn start_background_flusher(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
    let period = self.flush_interval;
    let collector = self;

    tokio::spawn(async move {
        let mut ticker = interval(period);
        loop {
            ticker.tick().await;
            match collector.flush().await {
                Ok(()) => {}
                Err(e) => {
                    error!("Background trace flush failed: {:?}", e);
                }
            }
        }
    })
}
/// Get current number of buffered spans across all services (for testing/monitoring)
pub async fn buffered_count(&self) -> usize {
self.spans_by_service
.lock()
.await
.values()
.map(|spans| spans.len())
.sum()
}
}
/// Unit tests for `TraceCollector` buffering; nothing here sends over the network.
#[cfg(test)]
mod tests {
use super::*;
use crate::traces::SpanBuilder;
/// A single recorded span shows up in the buffered count.
#[tokio::test]
async fn test_collector_basic() {
let collector = TraceCollector::new(Some(true));
let span = SpanBuilder::new("test_operation")
.with_trace_id("abc123")
.build();
collector.record_span("test-service", span);
assert_eq!(collector.buffered_count().await, 1);
}
/// Despite the name, this verifies that recording does NOT trigger a flush:
/// spans stay buffered until `flush` is called.
#[tokio::test]
async fn test_collector_auto_flush() {
// Since batch-triggered flush behavior was removed, record two spans and verify both are buffered
let collector = Arc::new(TraceCollector::new(Some(true)));
let span1 = SpanBuilder::new("test1").build();
let span2 = SpanBuilder::new("test2").build();
collector.record_span("test-service", span1);
collector.record_span("test-service", span2);
// With no batch-triggered flush, both spans should remain buffered
assert_eq!(collector.buffered_count().await, 2);
}
}

View file

@ -0,0 +1,27 @@
//! OpenTelemetry semantic convention constants for tracing
//!
//! These constants ensure consistency across the codebase and prevent typos

/// Resource attribute keys following OTEL semantic conventions
pub mod resource {
/// Logical name of the service
pub const SERVICE_NAME: &str = "service.name";
/// Version of the service
pub const SERVICE_VERSION: &str = "service.version";
/// Service namespace/environment
pub const SERVICE_NAMESPACE: &str = "service.namespace";
/// Service instance ID
pub const SERVICE_INSTANCE_ID: &str = "service.instance.id";
}
/// Instrumentation scope defaults
///
/// Used as the initial scope name/version of a `ResourceSpanBuilder` unless overridden.
pub mod scope {
/// Default scope name for tracing instrumentation
pub const DEFAULT_NAME: &str = "brightstaff.tracing";
/// Default scope version
pub const DEFAULT_VERSION: &str = "1.0.0";
}

View file

@ -0,0 +1,26 @@
// Original tracing types (OTEL structures)
mod shapes;
// New tracing utilities
mod span_builder;
mod resource_span_builder;
mod constants;
#[cfg(feature = "trace-collection")]
mod collector;
#[cfg(all(test, feature = "trace-collection"))]
mod tests;
// Re-export original types
pub use shapes::{
Span, Event, Traceparent, TraceparentNewError,
ResourceSpan, Resource, ScopeSpan, Scope, Attribute, AttributeValue,
};
// Re-export new utilities
pub use span_builder::{SpanBuilder, SpanKind, generate_random_span_id};
pub use resource_span_builder::ResourceSpanBuilder;
pub use constants::*;
#[cfg(feature = "trace-collection")]
pub use collector::{TraceCollector, parse_traceparent};

View file

@ -0,0 +1,121 @@
use super::shapes::{ResourceSpan, Resource, ScopeSpan, Scope, Span, Attribute, AttributeValue};
use super::constants::{resource, scope};
use std::collections::HashMap;
/// Builder for creating OTEL ResourceSpan structures
///
/// Provides a fluent API for building the resource/scope/span hierarchy
pub struct ResourceSpanBuilder {
/// Value emitted as the `service.name` resource attribute
service_name: String,
/// Extra resource attributes added after `service.name`
resource_attributes: HashMap<String, String>,
/// Instrumentation scope name (defaults to `scope::DEFAULT_NAME`)
scope_name: String,
/// Instrumentation scope version (defaults to `scope::DEFAULT_VERSION`)
scope_version: String,
/// Spans placed in the single scope of the built `ResourceSpan`
spans: Vec<Span>,
}
impl ResourceSpanBuilder {
/// Create a new ResourceSpan builder with service name
pub fn new(service_name: impl Into<String>) -> Self {
Self {
service_name: service_name.into(),
resource_attributes: HashMap::new(),
scope_name: scope::DEFAULT_NAME.to_string(),
scope_version: scope::DEFAULT_VERSION.to_string(),
spans: Vec::new(),
}
}
/// Add a resource attribute (e.g., deployment.environment, host.name)
pub fn with_resource_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.resource_attributes.insert(key.into(), value.into());
self
}
/// Set the instrumentation scope name
pub fn with_scope_name(mut self, name: impl Into<String>) -> Self {
self.scope_name = name.into();
self
}
/// Set the instrumentation scope version
pub fn with_scope_version(mut self, version: impl Into<String>) -> Self {
self.scope_version = version.into();
self
}
/// Add a single span
pub fn add_span(mut self, span: Span) -> Self {
self.spans.push(span);
self
}
/// Add multiple spans
pub fn add_spans(mut self, spans: Vec<Span>) -> Self {
self.spans.extend(spans);
self
}
/// Build the ResourceSpan
pub fn build(self) -> ResourceSpan {
// Build resource attributes
let mut attributes = vec![
Attribute {
key: resource::SERVICE_NAME.to_string(),
value: AttributeValue {
string_value: Some(self.service_name),
},
}
];
// Add custom resource attributes
for (key, value) in self.resource_attributes {
attributes.push(Attribute {
key,
value: AttributeValue {
string_value: Some(value),
},
});
}
let resource = Resource { attributes };
let scope = Scope {
name: self.scope_name,
version: self.scope_version,
attributes: Vec::new(),
};
let scope_span = ScopeSpan {
scope,
spans: self.spans,
};
ResourceSpan {
resource,
scope_spans: vec![scope_span],
}
}
}
/// Unit tests for `ResourceSpanBuilder`.
#[cfg(test)]
mod tests {
use super::*;
use crate::traces::SpanBuilder;
/// Builds a resource span with one custom attribute and two spans and checks
/// the resulting attribute, scope and span counts.
#[test]
fn test_resource_span_builder() {
let span1 = SpanBuilder::new("operation1").build();
let span2 = SpanBuilder::new("operation2").build();
let resource_span = ResourceSpanBuilder::new("test-service")
.with_resource_attribute("deployment.environment", "production")
.with_scope_name("test-scope")
.add_span(span1)
.add_span(span2)
.build();
assert_eq!(resource_span.resource.attributes.len(), 2); // service.name + custom
assert_eq!(resource_span.scope_spans.len(), 1);
assert_eq!(resource_span.scope_spans[0].spans.len(), 2);
}
}

View file

@ -0,0 +1,123 @@
use serde::{Deserialize, Serialize};
/// A resource together with the scope-level span batches recorded for it.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResourceSpan {
/// Attributes describing the entity that produced the spans.
pub resource: Resource,
/// Span batches grouped by instrumentation scope; serialized as `scopeSpans`.
#[serde(rename = "scopeSpans")]
pub scope_spans: Vec<ScopeSpan>,
}
/// Resource description: a list of key/value attributes.
#[derive(Serialize, Deserialize, Debug)]
pub struct Resource {
/// Resource attributes (e.g. `service.name`).
pub attributes: Vec<Attribute>,
}
/// Spans that share one instrumentation scope.
#[derive(Serialize, Deserialize, Debug)]
pub struct ScopeSpan {
/// Instrumentation scope the spans belong to.
pub scope: Scope,
/// Spans recorded under `scope`.
pub spans: Vec<Span>,
}
/// Instrumentation scope metadata.
#[derive(Serialize, Deserialize, Debug)]
pub struct Scope {
/// Scope name.
pub name: String,
/// Scope version.
pub version: String,
/// Scope-level attributes.
pub attributes: Vec<Attribute>,
}
/// A single span; field names are serialized in camelCase via explicit renames.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Span {
/// Trace this span belongs to; serialized as `traceId`.
#[serde(rename = "traceId")]
pub trace_id: String,
/// Identifier of this span; serialized as `spanId`.
#[serde(rename = "spanId")]
pub span_id: String,
/// Identifier of the parent span; serialized as `parentSpanId`.
#[serde(rename = "parentSpanId")]
pub parent_span_id: Option<String>, // Optional in case there's no parent span
/// Operation name.
pub name: String,
/// Start time in nanoseconds since the UNIX epoch, as a decimal string.
#[serde(rename = "startTimeUnixNano")]
pub start_time_unix_nano: String,
/// End time in nanoseconds since the UNIX epoch, as a decimal string.
#[serde(rename = "endTimeUnixNano")]
pub end_time_unix_nano: String,
/// Numeric span kind.
pub kind: u32,
/// Span attributes.
pub attributes: Vec<Attribute>,
/// Optional events attached to the span.
pub events: Option<Vec<Event>>,
}
/// A timestamped, named event with its own attributes.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Event {
/// Event time in nanoseconds since the UNIX epoch, as a decimal string;
/// serialized as `timeUnixNano`.
#[serde(rename = "timeUnixNano")]
pub time_unix_nano: String,
/// Event name.
pub name: String,
/// Event attributes.
pub attributes: Vec<Attribute>,
}
impl Event {
    /// Create an event with the given name and timestamp and no attributes.
    /// The timestamp is stored as its decimal string representation.
    pub fn new(name: String, time_unix_nano: u128) -> Self {
        Self {
            time_unix_nano: time_unix_nano.to_string(),
            name,
            attributes: Vec::new(),
        }
    }

    /// Append a string-valued attribute to the event.
    pub fn add_attribute(&mut self, key: String, value: String) {
        let value = AttributeValue {
            string_value: Some(value),
        };
        self.attributes.push(Attribute { key, value });
    }
}
/// Key/value attribute attached to a resource, scope, span or event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Attribute {
/// Attribute key.
pub key: String,
/// Attribute value wrapper.
pub value: AttributeValue,
}
/// Attribute value; only string values are modelled here.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AttributeValue {
/// String payload; serialized as `stringValue`.
#[serde(rename = "stringValue")]
pub string_value: Option<String>, // Use Option to handle different value types
}
/// The four `-`-separated fields of a traceparent header, kept as raw strings.
/// No field is validated beyond the field count checked in `TryFrom<String>`.
pub struct Traceparent {
/// First field of the header.
pub version: String,
/// Second field of the header.
pub trace_id: String,
/// Third field of the header.
pub parent_id: String,
/// Fourth field of the header.
pub flags: String,
}
impl std::fmt::Display for Traceparent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}-{}-{}-{}",
self.version, self.trace_id, self.parent_id, self.flags
)
}
}
/// Error returned when a string cannot be converted into a `Traceparent`.
#[derive(thiserror::Error, Debug)]
pub enum TraceparentNewError {
/// The input did not split into exactly four `-`-separated fields;
/// carries the original input.
#[error("Invalid traceparent: \'{0}\'")]
InvalidTraceparent(String),
}
/// Parses `version-trace_id-parent_id-flags`.
///
/// Fails with `InvalidTraceparent` (carrying the original input) unless the
/// string splits into exactly four `-`-separated fields. Field contents are
/// not validated.
impl TryFrom<String> for Traceparent {
    type Error = TraceparentNewError;

    fn try_from(traceparent: String) -> Result<Self, Self::Error> {
        let fields: Vec<String> = traceparent.split('-').map(str::to_owned).collect();

        // Converting to a fixed-size array succeeds only for exactly four fields.
        match <[String; 4]>::try_from(fields) {
            Ok([version, trace_id, parent_id, flags]) => Ok(Traceparent {
                version,
                trace_id,
                parent_id,
                flags,
            }),
            Err(_) => Err(TraceparentNewError::InvalidTraceparent(traceparent)),
        }
    }
}

View file

@ -0,0 +1,200 @@
use super::shapes::{Span, Attribute, AttributeValue};
use std::collections::HashMap;
use std::time::SystemTime;
/// OpenTelemetry span kinds
/// https://opentelemetry.io/docs/specs/otel/trace/api/#spankind
///
/// The discriminants are written verbatim into the span's numeric `kind`
/// field, so they must match the OTLP `SpanKind` enum
/// (`UNSPECIFIED = 0`, `INTERNAL = 1`, `SERVER = 2`, `CLIENT = 3`, ...).
#[derive(Debug, Clone, Copy)]
pub enum SpanKind {
    /// Default value. Indicates that the span represents an internal operation within an application
    // OTLP reserves 0 for SPAN_KIND_UNSPECIFIED; INTERNAL is 1.
    Internal = 1,
    /// Indicates that the span describes a request to some remote service
    Client = 3,
}
/// Builder for creating OTEL-compliant spans with a fluent API
///
/// This is the recommended way to create spans with proper trace context.
///
/// # Example
/// ```no_run
/// use common::traces::{SpanBuilder, SpanKind};
/// use std::time::SystemTime;
///
/// let span = SpanBuilder::new("router_chat")
/// .with_trace_id("abc123")
/// .with_parent_span_id("parent456")
/// .with_kind(SpanKind::Internal)
/// .with_attribute("http.method", "POST")
/// .with_attribute("http.path", "/v1/chat/completions")
/// .build();
/// ```
pub struct SpanBuilder {
/// Operation name of the span
name: String,
/// Trace id; a random one is generated at build time when `None`
trace_id: Option<String>,
/// Parent span id; `None` leaves the built span without a parent
parent_span_id: Option<String>,
/// Start time; initialised to the builder's creation time
start_time: SystemTime,
/// End time; the build time is used when `None`
end_time: Option<SystemTime>,
/// Span kind; initialised to `SpanKind::Internal`
kind: SpanKind,
/// String attributes keyed by name (a repeated key overwrites the earlier value)
attributes: HashMap<String, String>,
/// Span id; a random one is generated at build time when `None`
span_id: Option<String>,
}
impl SpanBuilder {
/// Create a new span builder
///
/// # Arguments
/// * `name` - The operation name for this span (e.g., "router_chat", "determine_route")
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
trace_id: None,
parent_span_id: None,
start_time: SystemTime::now(),
end_time: None,
kind: SpanKind::Internal,
attributes: HashMap::new(),
span_id: None,
}
}
/// Set the trace ID (extracted from traceparent or OpenTelemetry context)
pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
self.trace_id = Some(trace_id.into());
self
}
pub fn with_span_id(mut self, span_id: impl Into<String>) -> Self {
self.span_id = Some(span_id.into());
self
}
/// Set the parent span ID to link this span to its parent
pub fn with_parent_span_id(mut self, parent_span_id: impl Into<String>) -> Self {
self.parent_span_id = Some(parent_span_id.into());
self
}
/// Set the span kind (defaults to Internal)
pub fn with_kind(mut self, kind: SpanKind) -> Self {
self.kind = kind;
self
}
/// Set explicit start time (defaults to now)
pub fn with_start_time(mut self, start_time: SystemTime) -> Self {
self.start_time = start_time;
self
}
/// Set explicit end time (defaults to build time)
pub fn with_end_time(mut self, end_time: SystemTime) -> Self {
self.end_time = Some(end_time);
self
}
/// Add a single attribute to the span
pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.attributes.insert(key.into(), value.into());
self
}
/// Add multiple attributes at once
pub fn with_attributes(mut self, attrs: HashMap<String, String>) -> Self {
self.attributes.extend(attrs);
self
}
/// Build the span, consuming the builder
///
/// Creates a complete OTEL-compliant span with all provided attributes,
/// generating span_id and using provided or random trace_id.
pub fn build(self) -> Span {
let end_time = self.end_time.unwrap_or_else(SystemTime::now);
let start_nanos = system_time_to_nanos(self.start_time);
let end_nanos = system_time_to_nanos(end_time);
// Generate trace_id if not provided
let trace_id = self.trace_id.unwrap_or_else(|| generate_random_trace_id());
// Create attributes in OTEL format
let attributes: Vec<Attribute> = self.attributes
.into_iter()
.map(|(key, value)| Attribute {
key,
value: AttributeValue {
string_value: Some(value),
},
})
.collect();
// Build span directly without going through Span::new()
Span {
trace_id,
span_id: self.span_id.unwrap_or_else(|| generate_random_span_id()),
parent_span_id: self.parent_span_id,
name: self.name,
start_time_unix_nano: format!("{}", start_nanos),
end_time_unix_nano: format!("{}", end_nanos),
kind: self.kind as u32,
attributes,
events: None,
}
}
}
/// Nanoseconds elapsed between the UNIX epoch and `time`.
///
/// Times before the epoch yield 0.
fn system_time_to_nanos(time: SystemTime) -> u128 {
    match time.duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    }
}
/// Produce a random span ID: 8 random bytes, hex-encoded (16 hex characters).
pub fn generate_random_span_id() -> String {
    use rand::RngCore;

    let mut id_bytes = [0u8; 8];
    rand::thread_rng().fill_bytes(&mut id_bytes);
    hex::encode(id_bytes)
}
/// Produce a random trace ID: 16 random bytes, hex-encoded (32 hex characters).
fn generate_random_trace_id() -> String {
    use rand::RngCore;

    let mut id_bytes = [0u8; 16];
    rand::thread_rng().fill_bytes(&mut id_bytes);
    hex::encode(id_bytes)
}
/// Unit tests for `SpanBuilder`.
#[cfg(test)]
mod tests {
use super::*;
/// Name, trace id, parent id and attributes set on the builder appear on the span.
#[test]
fn test_span_builder_basic() {
let span = SpanBuilder::new("test_operation")
.with_trace_id("abc123")
.with_parent_span_id("parent123")
.with_attribute("key", "value")
.build();
assert_eq!(span.name, "test_operation");
assert_eq!(span.trace_id, "abc123");
assert_eq!(span.parent_span_id, Some("parent123".to_string()));
assert_eq!(span.attributes.len(), 1);
}
/// Without `with_parent_span_id` the built span has no parent.
#[test]
fn test_span_builder_no_parent() {
let span = SpanBuilder::new("root_span")
.with_trace_id("xyz789")
.build();
assert_eq!(span.name, "root_span");
assert_eq!(span.trace_id, "xyz789");
assert_eq!(span.parent_span_id, None);
}
}

View file

@ -0,0 +1,101 @@
//! Mock OTEL Collector for testing trace output
//!
//! This module provides a simple HTTP server that mimics an OTEL collector.
//! It exposes three endpoints:
//! - POST /v1/traces: Capture incoming OTLP JSON payloads
//! - GET /v1/traces: Return all captured payloads as JSON array
//! - DELETE /v1/traces: Clear all captured payloads
//!
//! Each test creates its own MockOtelCollector instance.
use axum::{
extract::State,
http::StatusCode,
routing::{delete, get, post},
Json, Router,
};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::RwLock;
type SharedTraces = Arc<RwLock<Vec<Value>>>;
/// POST /v1/traces - append the incoming JSON payload to the captured list
/// and reply `200 OK`.
async fn post_traces(
    State(traces): State<SharedTraces>,
    Json(payload): Json<Value>,
) -> StatusCode {
    let mut captured = traces.write().await;
    captured.push(payload);
    StatusCode::OK
}
/// GET /v1/traces - reply with a copy of every payload captured so far.
async fn get_traces(State(traces): State<SharedTraces>) -> Json<Vec<Value>> {
    let snapshot: Vec<Value> = traces.read().await.clone();
    Json(snapshot)
}
/// DELETE /v1/traces - discard every captured payload and reply `204 No Content`.
async fn delete_traces(State(traces): State<SharedTraces>) -> StatusCode {
    let mut captured = traces.write().await;
    captured.clear();
    StatusCode::NO_CONTENT
}
/// Mock OTEL collector server
///
/// Owns the spawned server task and an HTTP client used to query it.
pub struct MockOtelCollector {
/// Base URL of the server, e.g. `http://127.0.0.1:<port>` (no trailing path)
address: String,
/// Client used by `get_traces` to read back captured payloads
client: reqwest::Client,
/// Handle of the spawned server task; stored but never read
#[allow(dead_code)]
server_handle: tokio::task::JoinHandle<()>,
}
impl MockOtelCollector {
/// Create and start a new mock collector on a random port
pub async fn start() -> Self {
let traces = Arc::new(RwLock::new(Vec::new()));
let app = Router::new()
.route("/v1/traces", post(post_traces))
.route("/v1/traces", get(get_traces))
.route("/v1/traces", delete(delete_traces))
.with_state(traces.clone());
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("Failed to bind to random port");
let addr = listener.local_addr().expect("Failed to get local address");
let address = format!("http://127.0.0.1:{}", addr.port());
let server_handle = tokio::spawn(async move {
axum::serve(listener, app)
.await
.expect("Server failed");
});
// Give server a moment to start
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
Self {
address,
client: reqwest::Client::new(),
server_handle,
}
}
/// Get the address of the collector
pub fn address(&self) -> &str {
&self.address
}
/// GET /v1/traces - fetch all captured payloads
pub async fn get_traces(&self) -> Vec<Value> {
self.client
.get(format!("{}/v1/traces", self.address))
.send()
.await
.expect("Failed to GET traces")
.json()
.await
.expect("Failed to parse traces JSON")
}
}

View file

@ -0,0 +1,4 @@
mod mock_otel_collector;
mod trace_integration_test;
pub use mock_otel_collector::MockOtelCollector;

View file

@ -0,0 +1,304 @@
//! Integration tests for OpenTelemetry tracing in router.rs
//!
//! These tests validate that the spans created for LLM requests contain
//! all expected attributes and events by checking the raw JSON payloads
//! sent to the mock OTEL collector.
//!
//! ## Test Design
//! Each test creates its own MockOtelCollector and TraceCollector:
//! 1. Start MockOtelCollector on random port
//! 2. Create TraceCollector (flush interval is read from `TRACE_FLUSH_INTERVAL_MS`)
//! 3. Record spans using TraceCollector
//! 4. Flush explicitly, then wait `TOTAL_WAIT_MS` (`FLUSH_INTERVAL_MS` + `FLUSH_BUFFER_MS`) for spans to arrive
//! 5. Get raw JSON payloads (GET /v1/traces) and validate structure
//! 6. Test cleanup happens automatically when collectors are dropped
//!
//! ## Serial Execution
//! Tests use the `#[serial]` attribute to run sequentially because they
//! use global environment variables (OTEL_COLLECTOR_URL, OTEL_TRACING_ENABLED,
//! TRACE_FLUSH_INTERVAL_MS). This ensures test isolation without requiring
//! the `--test-threads=1` command line flag.
const FLUSH_INTERVAL_MS: u64 = 50;
const FLUSH_BUFFER_MS: u64 = 50;
const TOTAL_WAIT_MS: u64 = FLUSH_INTERVAL_MS + FLUSH_BUFFER_MS;
use crate::traces::{SpanBuilder, SpanKind, TraceCollector};
use serde_json::Value;
use serial_test::serial;
use std::sync::Arc;
use super::MockOtelCollector;
/// Helper to extract all spans from OTLP JSON payloads
fn extract_spans(payloads: &[Value]) -> Vec<&Value> {
let mut spans = Vec::new();
for payload in payloads {
if let Some(resource_spans) = payload.get("resourceSpans").and_then(|v| v.as_array()) {
for resource_span in resource_spans {
if let Some(scope_spans) = resource_span.get("scopeSpans").and_then(|v| v.as_array()) {
for scope_span in scope_spans {
if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array()) {
spans.extend(span_list.iter());
}
}
}
}
}
}
spans
}
/// Helper to get string attribute value from a span
fn get_string_attr<'a>(span: &'a Value, key: &str) -> Option<&'a str> {
span.get("attributes")
.and_then(|attrs| attrs.as_array())
.and_then(|attrs| {
attrs.iter().find(|attr| {
attr.get("key").and_then(|k| k.as_str()) == Some(key)
})
})
.and_then(|attr| attr.get("value"))
.and_then(|v| v.get("stringValue"))
.and_then(|v| v.as_str())
}
#[tokio::test]
#[serial]
async fn test_llm_span_contains_basic_attributes() {
// Start mock OTEL collector
let mock_collector = MockOtelCollector::start().await;
// Create TraceCollector pointing to mock with 500ms flush intervalc
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
// Create a test span simulating router.rs behavior
let span = SpanBuilder::new("POST /v1/chat/completions >> /v1/chat/completions")
.with_kind(SpanKind::Client)
.with_trace_id("test-trace-123")
.with_attribute("http.method", "POST")
.with_attribute("http.target", "/v1/chat/completions")
.with_attribute("http.upstream_target", "/v1/chat/completions")
.with_attribute("llm.model", "gpt-4o")
.with_attribute("llm.provider", "openai")
.with_attribute("llm.is_streaming", "true")
.with_attribute("llm.temperature", "0.7")
.build();
trace_collector.record_span("archgw(llm)", span);
// Flush and wait for spans to arrive (500ms flush interval + 200ms buffer)
trace_collector.flush().await.expect("Failed to flush");
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
let payloads = mock_collector.get_traces().await;
let spans = extract_spans(&payloads);
assert_eq!(spans.len(), 1, "Expected exactly one span");
let span = spans[0];
// Validate HTTP attributes
assert_eq!(get_string_attr(span, "http.method"), Some("POST"));
assert_eq!(get_string_attr(span, "http.target"), Some("/v1/chat/completions"));
// Validate LLM attributes
assert_eq!(get_string_attr(span, "llm.model"), Some("gpt-4o"));
assert_eq!(get_string_attr(span, "llm.provider"), Some("openai"));
assert_eq!(get_string_attr(span, "llm.is_streaming"), Some("true"));
assert_eq!(get_string_attr(span, "llm.temperature"), Some("0.7"));
}
#[tokio::test]
#[serial]
async fn test_llm_span_contains_tool_information() {
let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
let tools_formatted = "get_weather(...)\nsearch_web(...)\ncalculate(...)";
let span = SpanBuilder::new("POST /v1/chat/completions")
.with_trace_id("test-trace-tools")
.with_attribute("llm.request.tools", tools_formatted)
.with_attribute("llm.model", "gpt-4o")
.build();
trace_collector.record_span("archgw(llm)", span);
trace_collector.flush().await.expect("Failed to flush");
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
let payloads = mock_collector.get_traces().await;
let spans = extract_spans(&payloads);
assert!(!spans.is_empty(), "No spans captured");
let span = spans[0];
let tools = get_string_attr(span, "llm.request.tools");
assert!(tools.is_some(), "Tools attribute missing");
assert!(tools.unwrap().contains("get_weather(...)"));
assert!(tools.unwrap().contains("search_web(...)"));
assert!(tools.unwrap().contains("calculate(...)"));
assert!(tools.unwrap().contains('\n'), "Tools should be newline-separated");
}
/// Records a span carrying a truncated user-message preview and checks that the
/// `llm.request.user_message_preview` attribute is exported with the ellipsis.
#[tokio::test]
#[serial]
async fn test_llm_span_contains_user_message_preview() {
    let mock_collector = MockOtelCollector::start().await;
    std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
    std::env::set_var("OTEL_TRACING_ENABLED", "true");
    std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
    let trace_collector = Arc::new(TraceCollector::new(Some(true)));

    let long_message = "This is a very long user message that should be truncated to 50 characters in the span";
    // Build the preview locally: first 50 bytes (the literal is ASCII) plus "...".
    let preview = if long_message.len() > 50 {
        format!("{}...", &long_message[..50])
    } else {
        long_message.to_string()
    };

    let span = SpanBuilder::new("POST /v1/messages")
        .with_trace_id("test-trace-preview")
        .with_attribute("llm.request.user_message_preview", &preview)
        .build();
    trace_collector.record_span("archgw(llm)", span);
    trace_collector.flush().await.expect("Failed to flush");
    tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;

    let payloads = mock_collector.get_traces().await;
    let spans = extract_spans(&payloads);
    // Fail with a readable message instead of an index-out-of-bounds panic
    // when nothing reached the collector.
    assert!(!spans.is_empty(), "No spans captured");
    let span = spans[0];

    let message_preview = get_string_attr(span, "llm.request.user_message_preview");
    assert!(message_preview.is_some());
    assert!(message_preview.unwrap().len() <= 53); // 50 chars + "..."
    assert!(message_preview.unwrap().contains("..."));
}
/// Records a streaming span with a time-to-first-token attribute and checks
/// that `llm.time_to_first_token_ms` is exported unchanged.
#[tokio::test]
#[serial]
async fn test_llm_span_contains_time_to_first_token() {
    let mock_collector = MockOtelCollector::start().await;
    std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
    std::env::set_var("OTEL_TRACING_ENABLED", "true");
    std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
    let trace_collector = Arc::new(TraceCollector::new(Some(true)));

    let ttft_ms = "245"; // milliseconds as string
    let span = SpanBuilder::new("POST /v1/chat/completions")
        .with_trace_id("test-trace-ttft")
        .with_attribute("llm.is_streaming", "true")
        .with_attribute("llm.time_to_first_token_ms", ttft_ms)
        .build();
    trace_collector.record_span("archgw(llm)", span);
    trace_collector.flush().await.expect("Failed to flush");
    tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;

    let payloads = mock_collector.get_traces().await;
    let spans = extract_spans(&payloads);
    // Fail with a readable message instead of an index-out-of-bounds panic
    // when nothing reached the collector.
    assert!(!spans.is_empty(), "No spans captured");
    let span = spans[0];

    // Check TTFT attribute
    let ttft_attr = get_string_attr(span, "llm.time_to_first_token_ms");
    assert_eq!(ttft_attr, Some("245"));
}
/// Records a span whose operation name shows a path rewrite (`>>`) and checks
/// both the exported span name and the `http.upstream_target` attribute.
#[tokio::test]
#[serial]
async fn test_llm_span_contains_upstream_path() {
    let mock_collector = MockOtelCollector::start().await;
    std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
    std::env::set_var("OTEL_TRACING_ENABLED", "true");
    std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
    let trace_collector = Arc::new(TraceCollector::new(Some(true)));

    // Test Zhipu provider with path transformation
    let span = SpanBuilder::new("POST /v1/chat/completions >> /api/paas/v4/chat/completions")
        .with_trace_id("test-trace-upstream")
        .with_attribute("http.upstream_target", "/api/paas/v4/chat/completions")
        .with_attribute("llm.provider", "zhipu")
        .with_attribute("llm.model", "glm-4")
        .build();
    trace_collector.record_span("archgw(llm)", span);
    trace_collector.flush().await.expect("Failed to flush");
    tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;

    let payloads = mock_collector.get_traces().await;
    let spans = extract_spans(&payloads);
    // Fail with a readable message instead of an index-out-of-bounds panic
    // when nothing reached the collector.
    assert!(!spans.is_empty(), "No spans captured");
    let span = spans[0];

    // Operation name should show the transformation
    let name = span.get("name").and_then(|v| v.as_str());
    assert!(name.is_some());
    assert!(name.unwrap().contains(">>"), "Operation name should show path transformation");

    // Check upstream target attribute
    let upstream = get_string_attr(span, "http.upstream_target");
    assert_eq!(upstream, Some("/api/paas/v4/chat/completions"));
}
/// Records two spans of the same trace under different service names and
/// checks that both are exported to the mock collector.
#[tokio::test]
#[serial]
async fn test_llm_span_multiple_services() {
    let mock_collector = MockOtelCollector::start().await;
    std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
    std::env::set_var("OTEL_TRACING_ENABLED", "true");
    std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
    let trace_collector = Arc::new(TraceCollector::new(Some(true)));

    // Span recorded under the LLM service.
    trace_collector.record_span(
        "archgw(llm)",
        SpanBuilder::new("LLM Request")
            .with_trace_id("test-multi")
            .with_attribute("service", "llm")
            .build(),
    );
    // Span recorded under the routing service, same trace id.
    trace_collector.record_span(
        "archgw(routing)",
        SpanBuilder::new("Routing Decision")
            .with_trace_id("test-multi")
            .with_attribute("service", "routing")
            .build(),
    );

    trace_collector.flush().await.expect("Failed to flush");
    tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;

    let payloads = mock_collector.get_traces().await;
    let captured = extract_spans(&payloads);
    assert_eq!(captured.len(), 2, "Should have captured both spans");
}
/// With the collector constructed as disabled, recording and flushing a span
/// must not deliver anything to the mock collector.
#[tokio::test]
#[serial]
async fn test_tracing_disabled_produces_no_spans() {
    let mock_collector = MockOtelCollector::start().await;

    // Tracing is switched off both in the environment and in the constructor.
    std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
    std::env::set_var("OTEL_TRACING_ENABLED", "false");
    std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
    let trace_collector = Arc::new(TraceCollector::new(Some(false)));

    let disabled_span = SpanBuilder::new("Test Span")
        .with_trace_id("test-disabled")
        .build();
    trace_collector.record_span("archgw(llm)", disabled_span);
    // The flush result is intentionally ignored; it should be a no-op when disabled.
    let _ = trace_collector.flush().await;
    tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;

    let payloads = mock_collector.get_traces().await;
    let captured = extract_spans(&payloads);
    assert_eq!(captured.len(), 0, "No spans should be captured when tracing is disabled");
}

View file

@ -10,3 +10,5 @@ serde_with = {version = "3.12.0", features = ["base64"]}
thiserror = "2.0.12"
aws-smithy-eventstream = "0.60"
bytes = "1.10"
uuid = { version = "1.11", features = ["v4"] }
log = "0.4"

View file

@ -7,7 +7,7 @@ use thiserror::Error;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::ProviderStreamResponse;
use crate::providers::streaming_response::ProviderStreamResponse;
// ============================================================================
// AMAZON BEDROCK CONVERSE API ENUMERATION
@ -200,6 +200,17 @@ impl ProviderRequest for ConverseRequest {
})
}
/// Returns the names of the tools declared in `tool_config`.
///
/// Yields `None` when there is no tool configuration or it has no tool list;
/// otherwise a (possibly empty) vector with one name per tool spec.
fn get_tool_names(&self) -> Option<Vec<String>> {
    let tools = self.tool_config.as_ref()?.tools.as_ref()?;
    let mut names = Vec::with_capacity(tools.len());
    for tool in tools {
        match tool {
            Tool::ToolSpec { tool_spec } => names.push(tool_spec.name.clone()),
        }
    }
    Some(names)
}
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
serde_json::to_vec(self).map_err(|e| ProviderRequestError {
message: format!("Failed to serialize Bedrock request: {}", e),
@ -218,6 +229,108 @@ impl ProviderRequest for ConverseRequest {
false
}
}
/// Returns the sampling temperature from `inference_config`, or `None` when
/// either the config or the temperature is absent.
fn get_temperature(&self) -> Option<f32> {
    self.inference_config
        .as_ref()
        .and_then(|config| config.temperature)
}
/// Converts this Bedrock request into a flat list of OpenAI-style messages.
///
/// Text system blocks come first as `Role::System` messages, followed by the
/// conversation messages. For each conversation message only the text content
/// blocks are kept, joined with `"\n"`; non-text blocks are dropped.
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
    use crate::apis::openai::{Message, MessageContent, Role};

    // All produced messages are plain text with no name or tool metadata.
    let text_message = |role: Role, text: String| Message {
        role,
        content: MessageContent::Text(text),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    };

    // System blocks: only the text variant is representable, others are skipped.
    let system_messages = self
        .system
        .iter()
        .flatten()
        .filter_map(|sys_block| match sys_block {
            SystemContentBlock::Text { text } => Some(text_message(Role::System, text.clone())),
            _ => None,
        });

    // Conversation messages: map the role and concatenate the text blocks.
    let conversation = self.messages.iter().flatten().map(|msg| {
        let role = match msg.role {
            ConversationRole::User => Role::User,
            ConversationRole::Assistant => Role::Assistant,
        };
        let mut text_parts = Vec::new();
        for block in &msg.content {
            if let ContentBlock::Text { text } = block {
                text_parts.push(text.as_str());
            }
        }
        text_message(role, text_parts.join("\n"))
    });

    system_messages.chain(conversation).collect()
}
/// Replaces this request's messages with the given OpenAI-style messages.
///
/// System messages with text content become Bedrock system blocks; `self.system`
/// is only overwritten when at least one such block was produced. User and
/// assistant messages become conversation messages (non-text content yields an
/// empty content list). Messages with any other role are ignored.
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
    use crate::apis::amazon_bedrock::{ContentBlock, ConversationRole, SystemContentBlock};
    use crate::apis::openai::{MessageContent, Role};

    let mut system_blocks = Vec::new();
    let mut bedrock_messages = Vec::new();

    for msg in messages {
        // System messages are collected separately; other roles are mapped or skipped.
        let role = match msg.role {
            Role::System => {
                if let MessageContent::Text(text) = &msg.content {
                    system_blocks.push(SystemContentBlock::Text { text: text.clone() });
                }
                continue;
            }
            Role::User => ConversationRole::User,
            Role::Assistant => ConversationRole::Assistant,
            _ => continue,
        };

        let content = if let MessageContent::Text(text) = &msg.content {
            vec![ContentBlock::Text { text: text.clone() }]
        } else {
            vec![]
        };

        bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content });
    }

    if !system_blocks.is_empty() {
        self.system = Some(system_blocks);
    }
    self.messages = Some(bedrock_messages);
}
}
// ============================================================================

View file

@ -6,7 +6,8 @@ use std::collections::HashMap;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
use crate::providers::response::ProviderResponse;
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::transforms::lib::ExtractText;
use crate::MESSAGES_PATH;
@ -397,6 +398,8 @@ pub enum MessagesContentDelta {
InputJsonDelta { partial_json: String },
#[serde(rename = "thinking_delta")]
ThinkingDelta { thinking: String },
#[serde(rename = "signature_delta")]
SignatureDelta { signature: String },
}
#[skip_serializing_none]
@ -512,6 +515,12 @@ impl ProviderRequest for MessagesRequest {
None
}
/// Returns the name of every tool declared on the request, or `None` when the
/// request has no `tools` list.
fn get_tool_names(&self) -> Option<Vec<String>> {
    let tools = self.tools.as_ref()?;
    let mut names = Vec::with_capacity(tools.len());
    for tool in tools {
        names.push(tool.name.clone());
    }
    Some(names)
}
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
serde_json::to_vec(self).map_err(|e| ProviderRequestError {
message: format!("Failed to serialize MessagesRequest: {}", e),
@ -530,6 +539,69 @@ impl ProviderRequest for MessagesRequest {
false
}
}
/// Returns the request's `temperature` field as-is (`None` when unset).
fn get_temperature(&self) -> Option<f32> {
self.temperature
}
/// Converts this Anthropic request into OpenAI-style messages.
///
/// The system prompt, when present, is converted first. Each Anthropic message
/// is then converted with `TryInto<Vec<Message>>`; messages whose conversion
/// fails are silently left out.
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
    use crate::apis::openai::Message;

    let system_message: Option<Message> = self
        .system
        .as_ref()
        .map(|system| system.clone().into());

    // One Anthropic message may expand into several OpenAI messages.
    let converted = self
        .messages
        .iter()
        .filter_map(|msg| TryInto::<Vec<Message>>::try_into(msg.clone()).ok())
        .flatten();

    system_message.into_iter().chain(converted).collect()
}
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
// Convert OpenAI messages to Anthropic format
// Separate system messages from regular messages
let mut system_messages = Vec::new();
let mut regular_messages = Vec::new();
for msg in messages {
if msg.role == crate::apis::openai::Role::System {
system_messages.push(msg.clone());
} else {
regular_messages.push(msg.clone());
}
}
// Set system prompt if there are system messages
if !system_messages.is_empty() {
// Combine all system messages into one
let system_text = system_messages.iter()
.filter_map(|msg| {
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
Some(text.as_str())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n");
self.system = Some(crate::apis::anthropic::MessagesSystemPrompt::Single(system_text));
}
// Convert regular messages
self.messages = regular_messages.iter()
.filter_map(|msg| {
msg.clone().try_into().ok()
})
.collect();
}
}
impl MessagesResponse {

View file

@ -1,8 +1,8 @@
pub mod amazon_bedrock;
pub mod amazon_bedrock_binary_frame;
pub mod anthropic;
pub mod openai;
pub mod sse;
pub mod openai_responses;
pub mod streaming_shapes;
// Explicit exports to avoid naming conflicts
pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
@ -88,8 +88,9 @@ mod tests {
fn test_all_variants_method() {
// Test that all_variants returns the expected variants
let openai_variants = OpenAIApi::all_variants();
assert_eq!(openai_variants.len(), 1);
assert_eq!(openai_variants.len(), 2);
assert!(openai_variants.contains(&OpenAIApi::ChatCompletions));
assert!(openai_variants.contains(&OpenAIApi::Responses));
let anthropic_variants = AnthropicApi::all_variants();
assert_eq!(anthropic_variants.len(), 1);

View file

@ -7,9 +7,10 @@ use thiserror::Error;
use super::ApiDefinition;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
use crate::providers::response::{ProviderResponse, TokenUsage};
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::transforms::lib::ExtractText;
use crate::CHAT_COMPLETIONS_PATH;
use crate::{CHAT_COMPLETIONS_PATH, OPENAI_RESPONSES_API_PATH};
// ============================================================================
// OPENAI API ENUMERATION
@ -19,6 +20,7 @@ use crate::CHAT_COMPLETIONS_PATH;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OpenAIApi {
ChatCompletions,
Responses,
// Future APIs can be added here:
// Embeddings,
// FineTuning,
@ -29,12 +31,14 @@ impl ApiDefinition for OpenAIApi {
/// Returns the request path constant associated with this API variant.
fn endpoint(&self) -> &'static str {
match self {
OpenAIApi::ChatCompletions => CHAT_COMPLETIONS_PATH,
OpenAIApi::Responses => OPENAI_RESPONSES_API_PATH,
}
}
/// Inverse of `endpoint`: maps an exact path string to its API variant.
/// Returns `None` for any path that matches neither known constant.
fn from_endpoint(endpoint: &str) -> Option<Self> {
match endpoint {
CHAT_COMPLETIONS_PATH => Some(OpenAIApi::ChatCompletions),
OPENAI_RESPONSES_API_PATH => Some(OpenAIApi::Responses),
_ => None,
}
}
@ -42,23 +46,26 @@ impl ApiDefinition for OpenAIApi {
/// Reports whether this API variant supports streaming; both current variants do.
/// The match is exhaustive so a newly added variant must be classified explicitly.
fn supports_streaming(&self) -> bool {
match self {
OpenAIApi::ChatCompletions => true,
OpenAIApi::Responses => true,
}
}
/// Reports whether this API variant supports tools; both current variants do.
/// The match is exhaustive so a newly added variant must be classified explicitly.
fn supports_tools(&self) -> bool {
match self {
OpenAIApi::ChatCompletions => true,
OpenAIApi::Responses => true,
}
}
/// Reports whether this API variant supports vision; both current variants do.
/// The match is exhaustive so a newly added variant must be classified explicitly.
fn supports_vision(&self) -> bool {
match self {
OpenAIApi::ChatCompletions => true,
OpenAIApi::Responses => true,
}
}
fn all_variants() -> Vec<Self> {
vec![OpenAIApi::ChatCompletions]
vec![OpenAIApi::ChatCompletions, OpenAIApi::Responses]
}
}
@ -101,6 +108,12 @@ pub struct ChatCompletionsRequest {
pub top_logprobs: Option<u32>,
pub user: Option<String>,
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
// VLLM-specific parameters (used by Arch-Function)
pub top_k: Option<u32>,
pub stop_token_ids: Option<Vec<u32>>,
pub continue_final_message: Option<bool>,
pub add_generation_prompt: Option<bool>,
}
impl ChatCompletionsRequest {
@ -385,6 +398,8 @@ pub struct ChatCompletionsResponse {
pub usage: Usage,
pub system_fingerprint: Option<String>,
pub service_tier: Option<String>,
// This isn't a standard OpenAI field, but we include it for extensibility
pub metadata: Option<HashMap<String, Value>>,
}
impl Default for ChatCompletionsResponse {
@ -398,6 +413,7 @@ impl Default for ChatCompletionsResponse {
usage: Usage::default(),
system_fingerprint: None,
service_tier: None,
metadata: None,
}
}
}
@ -671,6 +687,32 @@ impl ProviderRequest for ChatCompletionsRequest {
})
}
/// Returns the function names declared on the request.
///
/// The `tools` field (current API) is consulted first; if it is absent or
/// yields no names, the deprecated `functions` field is used instead.
/// Returns `None` when neither field produces at least one name.
fn get_tool_names(&self) -> Option<Vec<String>> {
    // Current API: `tools`.
    let from_tools = self
        .tools
        .as_ref()
        .map(|tools| {
            tools
                .iter()
                .map(|tool| tool.function.name.clone())
                .collect::<Vec<String>>()
        })
        .filter(|names| !names.is_empty());
    if from_tools.is_some() {
        return from_tools;
    }

    // Deprecated but still supported: `functions`.
    self.functions
        .as_ref()
        .map(|functions| {
            functions
                .iter()
                .map(|func| func.function.name.clone())
                .collect::<Vec<String>>()
        })
        .filter(|names| !names.is_empty())
}
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
message: format!("Failed to serialize OpenAI request: {}", e),
@ -689,6 +731,18 @@ impl ProviderRequest for ChatCompletionsRequest {
false
}
}
/// Returns the request's `temperature` field as-is (`None` when unset).
fn get_temperature(&self) -> Option<f32> {
self.temperature
}
/// Returns a clone of the request's messages; they are already in OpenAI
/// format, so no conversion is performed.
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
self.messages.clone()
}
/// Replaces the request's messages with a copy of `messages`, discarding the
/// previous list.
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
self.messages = messages.to_vec();
}
}
/// Implementation of ProviderResponse for ChatCompletionsResponse
@ -1068,8 +1122,9 @@ mod tests {
// Test all_variants
let all_variants = OpenAIApi::all_variants();
assert_eq!(all_variants.len(), 1);
assert_eq!(all_variants[0], OpenAIApi::ChatCompletions);
assert_eq!(all_variants.len(), 2);
assert!(all_variants.contains(&OpenAIApi::ChatCompletions));
assert!(all_variants.contains(&OpenAIApi::Responses));
}
#[test]

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,6 @@
use aws_smithy_eventstream::frame::DecodedFrame;
use aws_smithy_eventstream::frame::MessageFrameDecoder;
use bytes::Buf;
use std::collections::HashSet;
/// AWS Event Stream frame decoder wrapper
pub struct BedrockBinaryFrameDecoder<B>
@ -10,7 +9,6 @@ where
{
decoder: MessageFrameDecoder,
buffer: B,
content_block_start_indices: HashSet<i32>,
}
impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
@ -20,7 +18,6 @@ impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
Self {
decoder: MessageFrameDecoder::new(),
buffer,
content_block_start_indices: std::collections::HashSet::new(),
}
}
}
@ -33,7 +30,6 @@ where
Self {
decoder: MessageFrameDecoder::new(),
buffer,
content_block_start_indices: HashSet::new(),
}
}
@ -52,14 +48,4 @@ where
pub fn has_remaining(&self) -> bool {
self.buffer.has_remaining()
}
/// Check if a content_block_start event has been sent for the given index
pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
self.content_block_start_indices.contains(&index)
}
/// Mark that a content_block_start event has been sent for the given index
pub fn set_content_block_start_sent(&mut self, index: i32) {
self.content_block_start_indices.insert(index);
}
}

View file

@ -0,0 +1,507 @@
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use crate::apis::anthropic::MessagesStreamEvent;
use crate::providers::streaming_response::ProviderStreamResponseType;
use std::collections::HashSet;
/// SSE Stream Buffer for Anthropic Messages API streaming.
///
/// This buffer manages the wire format for Anthropic Messages API streaming,
/// handling the specific event sequencing requirements:
/// - MessageStart → ContentBlockStart → ContentBlockDelta(s) → ContentBlockStop → MessageDelta → MessageStop
///
/// When converting from OpenAI to Anthropic format, this buffer injects the required
/// ContentBlockStart and ContentBlockStop events to maintain proper Anthropic protocol.
///
/// Events are accumulated by `add_transformed_event` and drained (serialized to
/// bytes) by `into_bytes`; the lifecycle flags below persist across drains.
pub struct AnthropicMessagesStreamBuffer {
/// Buffered SSE events ready to be written to wire (emptied by `into_bytes`)
buffered_events: Vec<SseEvent>,
/// Track if we've seen (or injected) a message_start event
message_started: bool,
/// Track content block indices that have received ContentBlockStart events
content_block_start_indices: HashSet<i32>,
/// Track if we need to inject ContentBlockStop before message_delta;
/// set when a ContentBlockStart is buffered, cleared when a stop is injected or received
needs_content_block_stop: bool,
/// Track if we've seen a MessageDelta (so we need to send MessageStop at the end);
/// cleared when a MessageStop is received or injected
seen_message_delta: bool,
/// Model name to use when generating message_start events
/// (taken from the first event whose JSON data has a "model" field)
model: Option<String>,
}
impl AnthropicMessagesStreamBuffer {
    /// Creates an empty buffer: nothing buffered, no lifecycle events seen,
    /// and no model name known yet.
    pub fn new() -> Self {
        Self {
            buffered_events: Vec::new(),
            message_started: false,
            content_block_start_indices: HashSet::new(),
            needs_content_block_stop: false,
            seen_message_delta: false,
            model: None,
        }
    }

    /// Check if a content_block_start event has been sent for the given index
    fn has_content_block_start_been_sent(&self, index: i32) -> bool {
        self.content_block_start_indices.contains(&index)
    }

    /// Mark that a content_block_start event has been sent for the given index
    fn set_content_block_start_sent(&mut self, index: i32) {
        self.content_block_start_indices.insert(index);
    }

    /// Highest content block index that has been started so far (0 if none).
    ///
    /// Used as the index of an injected ContentBlockStop: content block indices
    /// grow as the stream progresses, so the highest started index is taken to
    /// be the block that is still open.
    fn latest_started_block_index(&self) -> i32 {
        self.content_block_start_indices
            .iter()
            .copied()
            .max()
            .unwrap_or(0)
    }

    /// Buffers a synthetic message_start event unless one was already seen or
    /// injected. Falls back to the model name "unknown" when no model has been
    /// extracted from the stream yet.
    fn ensure_message_started(&mut self) {
        if self.message_started {
            return;
        }
        let message_start =
            Self::create_message_start_event(self.model.as_deref().unwrap_or("unknown"));
        self.buffered_events.push(message_start);
        self.message_started = true;
    }

    /// Wraps an already-formatted SSE string in an `SseEvent` that carries no
    /// upstream data and no parsed provider response (i.e. an injected event).
    fn synthetic_event(event_name: &str, sse_string: String) -> SseEvent {
        SseEvent {
            data: None,
            event: Some(event_name.to_string()),
            raw_line: sse_string.clone(),
            sse_transformed_lines: sse_string,
            provider_stream_response: None,
        }
    }

    /// Helper to create and format a ContentBlockStart SSE event for block 0.
    /// Kept for callers outside this view; new code uses the `_at` variant.
    #[allow(dead_code)]
    fn create_content_block_start_event() -> SseEvent {
        Self::create_content_block_start_event_at(0)
    }

    /// Helper to create and format an empty-text ContentBlockStart SSE event
    /// for the content block at `index`.
    fn create_content_block_start_event_at(index: i32) -> SseEvent {
        let content_block_start = MessagesStreamEvent::ContentBlockStart {
            // `as _` converts to whatever integer type the event's index field uses.
            index: index as _,
            content_block: crate::apis::anthropic::MessagesContentBlock::Text {
                text: String::new(),
                cache_control: None,
            },
        };
        let sse_string: String = content_block_start.into();
        Self::synthetic_event("content_block_start", sse_string)
    }

    /// Helper to create and format a MessageStart SSE event
    ///
    /// The message gets a fresh `msg_<uuid>` id, the assistant role, empty
    /// content and zeroed usage; `model` is copied into the message.
    fn create_message_start_event(model: &str) -> SseEvent {
        let message_start = MessagesStreamEvent::MessageStart {
            message: crate::apis::anthropic::MessagesStreamMessage {
                id: format!("msg_{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
                obj_type: "message".to_string(),
                role: crate::apis::anthropic::MessagesRole::Assistant,
                content: vec![],
                model: model.to_string(),
                stop_reason: None,
                stop_sequence: None,
                usage: crate::apis::anthropic::MessagesUsage {
                    input_tokens: 0,
                    output_tokens: 0,
                    cache_creation_input_tokens: None,
                    cache_read_input_tokens: None,
                },
            },
        };
        let sse_string: String = message_start.into();
        Self::synthetic_event("message_start", sse_string)
    }

    /// Helper to create and format a ContentBlockStop SSE event for block 0.
    /// Kept for callers outside this view; new code uses the `_at` variant.
    #[allow(dead_code)]
    fn create_content_block_stop_event() -> SseEvent {
        Self::create_content_block_stop_event_at(0)
    }

    /// Helper to create and format a ContentBlockStop SSE event for the
    /// content block at `index`.
    fn create_content_block_stop_event_at(index: i32) -> SseEvent {
        let content_block_stop = MessagesStreamEvent::ContentBlockStop {
            // `as _` converts to whatever integer type the event's index field uses.
            index: index as _,
        };
        let sse_string: String = content_block_stop.into();
        Self::synthetic_event("content_block_stop", sse_string)
    }

    /// Helper to create and format a MessageStop SSE event
    fn create_message_stop_event() -> SseEvent {
        let sse_string: String = MessagesStreamEvent::MessageStop.into();
        Self::synthetic_event("message_stop", sse_string)
    }
}

impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
    /// Buffers one transformed event, injecting any lifecycle events
    /// (message_start, content_block_start, content_block_stop) that the
    /// Anthropic sequence requires but the upstream did not provide.
    fn add_transformed_event(&mut self, event: SseEvent) {
        // Skip ping messages
        if event.should_skip() {
            return;
        }

        // FIRST: Try to extract model name from the raw event data before transformation
        // The provider_stream_response has already been transformed to Anthropic format,
        // so we need to extract the model from the original raw data if available
        if self.model.is_none() {
            if let Some(data) = &event.data {
                // Try to parse as JSON and extract model field
                if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
                    if let Some(model) = json.get("model").and_then(|m| m.as_str()) {
                        self.model = Some(model.to_string());
                    }
                }
            }
        }

        // Match directly on the provider response type to handle event processing
        // We match on a reference first to determine the type, then move the event
        match &event.provider_stream_response {
            Some(ProviderStreamResponseType::MessagesStreamEvent(evt)) => {
                match evt {
                    MessagesStreamEvent::MessageStart { .. } => {
                        // Add the message_start event
                        self.buffered_events.push(event);
                        self.message_started = true;
                    }
                    MessagesStreamEvent::ContentBlockStart { index, .. } => {
                        let index = *index as i32;
                        // Inject message_start if needed
                        self.ensure_message_started();
                        // Add the content_block_start event (from tool calls or other sources)
                        self.buffered_events.push(event);
                        self.set_content_block_start_sent(index);
                        self.needs_content_block_stop = true;
                    }
                    MessagesStreamEvent::ContentBlockDelta { index, .. } => {
                        let index = *index as i32;
                        // Inject message_start if needed
                        self.ensure_message_started();
                        // Check if ContentBlockStart was sent for this index
                        if !self.has_content_block_start_been_sent(index) {
                            // Inject ContentBlockStart before delta, using the delta's own
                            // index so the start and its deltas refer to the same block
                            let content_block_start =
                                Self::create_content_block_start_event_at(index);
                            self.buffered_events.push(content_block_start);
                            self.set_content_block_start_sent(index);
                            self.needs_content_block_stop = true;
                        }
                        // Content deltas are between ContentBlockStart and ContentBlockStop
                        self.buffered_events.push(event);
                    }
                    MessagesStreamEvent::MessageDelta { usage, .. } => {
                        // Inject ContentBlockStop before message_delta, closing the most
                        // recently started block rather than always block 0
                        if self.needs_content_block_stop {
                            let stop_index = self.latest_started_block_index();
                            let content_block_stop =
                                Self::create_content_block_stop_event_at(stop_index);
                            self.buffered_events.push(content_block_stop);
                            self.needs_content_block_stop = false;
                        }
                        // Check if the last event was also a MessageDelta - if so, merge them
                        // This handles Bedrock's split of stop_reason (MessageStop) and usage (Metadata)
                        // NOTE(review): only `provider_stream_response` is updated here; whether the
                        // serialized bytes reflect the merged usage depends on the SseEvent -> bytes
                        // conversion, which is not visible in this file - confirm.
                        if let Some(last_event) = self.buffered_events.last_mut() {
                            if let Some(ProviderStreamResponseType::MessagesStreamEvent(
                                MessagesStreamEvent::MessageDelta {
                                    usage: last_usage, ..
                                },
                            )) = &mut last_event.provider_stream_response
                            {
                                // Merge: take stop_reason from first, usage from second (if non-zero)
                                if usage.input_tokens > 0 || usage.output_tokens > 0 {
                                    *last_usage = usage.clone();
                                }
                                // Mark that we've seen MessageDelta (need to send MessageStop later)
                                self.seen_message_delta = true;
                                // Don't push the new event, we've merged it
                                return;
                            }
                        }
                        // No previous MessageDelta to merge with, add this one
                        self.buffered_events.push(event);
                        self.seen_message_delta = true;
                    }
                    MessagesStreamEvent::ContentBlockStop { .. } => {
                        // ContentBlockStop received from upstream (e.g., Bedrock)
                        // Clear the flag so we don't inject another one
                        self.needs_content_block_stop = false;
                        self.buffered_events.push(event);
                    }
                    MessagesStreamEvent::MessageStop => {
                        // MessageStop received from upstream (e.g., OpenAI via [DONE])
                        // Clear the flag so we don't inject another one
                        self.seen_message_delta = false;
                        self.buffered_events.push(event);
                    }
                    _ => {
                        // Other Anthropic event types (Ping, etc.), just accumulate
                        self.buffered_events.push(event);
                    }
                }
            }
            _ => {
                // Non-Anthropic events or events without provider_stream_response, just accumulate
                self.buffered_events.push(event);
            }
        }
    }

    /// Serializes and drains all buffered events, returning their bytes in
    /// buffering order. Appends a message_stop first when a MessageDelta was
    /// seen without a following MessageStop.
    fn into_bytes(&mut self) -> Vec<u8> {
        // Convert all accumulated events to bytes and clear buffer
        // NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
        // or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.

        // Inject MessageStop after MessageDelta if we've seen one
        // This completes the Anthropic Messages API event sequence
        if self.seen_message_delta {
            self.buffered_events.push(Self::create_message_stop_event());
            self.seen_message_delta = false;
        }

        let mut buffer = Vec::new();
        for event in self.buffered_events.drain(..) {
            let event_bytes: Vec<u8> = event.into();
            buffer.extend_from_slice(&event_bytes);
        }
        buffer
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
use crate::apis::streaming_shapes::sse::SseStreamIter;
/// End-to-end check of a complete OpenAI ChatCompletions stream (two content
/// chunks, a finish chunk and `[DONE]`) converted to Anthropic Messages SSE:
/// the output must contain the full lifecycle (message_start, content_block_start,
/// two content_block_delta, content_block_stop, message_delta, message_stop).
#[test]
fn test_openai_to_anthropic_complete_transformation() {
// OpenAI ChatCompletions input that will be transformed to Anthropic Messages API
let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
data: [DONE]"#;
// Diagnostic output only; visible when the test runs with --nocapture
println!("\n{}", "=".repeat(80));
println!("TEST 1: OpenAI → Anthropic Messages API Complete Transformation");
println!("{}", "=".repeat(80));
println!("\nRAW INPUT (OpenAI ChatCompletions):");
println!("{}", "-".repeat(80));
println!("{}", raw_input);
// Setup API configuration for transformation (client wants Anthropic, upstream is OpenAI)
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// Parse events and apply transformation
let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
// Drain everything in one go; this is also where message_stop may be injected
let output_bytes = buffer.into_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
println!("{}", "-".repeat(80));
println!("{}", output);
// Assertions
assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("event: message_start"), "Should have message_start");
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
let delta_count = output.matches("event: content_block_delta").count();
assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events");
// Verify both pieces of content are present
assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'");
assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'");
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
assert!(output.contains("event: message_delta"), "Should have message_delta");
assert!(output.contains("event: message_stop"), "Should have message_stop");
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop");
println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count);
println!("✓ Complete stream with message_stop");
println!("✓ Proper Anthropic protocol sequencing\n");
}
#[test]
fn test_openai_to_anthropic_partial_transformation() {
    // Truncated OpenAI ChatCompletions stream: three content chunks, no finish_reason, no [DONE]
    let raw_input = r#"data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"The weather"},"finish_reason":null}]}
data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" in San Francisco"},"finish_reason":null}]}
data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]}"#;

    // Separator lines reused by every section of the diagnostic printout
    let heavy_rule = "=".repeat(80);
    let light_rule = "-".repeat(80);

    println!("\n{}", heavy_rule);
    println!("TEST 2: OpenAI → Anthropic Partial Transformation (NO [DONE])");
    println!("{}", heavy_rule);
    println!("\nRAW INPUT (OpenAI ChatCompletions - NO [DONE]):");
    println!("{}", light_rule);
    println!("{}", raw_input);

    // The client speaks Anthropic Messages, the upstream speaks OpenAI ChatCompletions
    let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

    // Transform each upstream event and hand it straight to the buffer
    let mut buffer = AnthropicMessagesStreamBuffer::new();
    for upstream_event in SseStreamIter::try_from(raw_input.as_bytes()).unwrap() {
        buffer.add_transformed_event(
            SseEvent::try_from((upstream_event, &client_api, &upstream_api)).unwrap(),
        );
    }

    let output_bytes = buffer.into_bytes();
    let output = String::from_utf8_lossy(&output_bytes);

    println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
    println!("{}", light_rule);
    println!("{}", output);

    assert!(!output_bytes.is_empty(), "Should have output");

    // Opening lifecycle events must be present
    for (needle, message) in [
        ("event: message_start", "Should have message_start"),
        ("event: content_block_start", "Should have content_block_start (injected)"),
    ] {
        assert!(output.contains(needle), "{}", message);
    }

    let delta_count = output.matches("event: content_block_delta").count();
    assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events");

    // Every piece of upstream content must survive the transformation
    for (needle, message) in [
        ("\"text\":\"The weather\"", "Should have first content delta"),
        ("\"text\":\" in San Francisco\"", "Should have second content delta"),
        ("\"text\":\" is\"", "Should have third content delta"),
    ] {
        assert!(output.contains(needle), "{}", message);
    }

    // A partial stream (no finish_reason, no [DONE]) may still continue, so the buffer
    // must not inject content_block_stop or any completion event: those are only
    // emitted on an explicit upstream signal.
    for (needle, message) in [
        ("event: content_block_stop", "Should NOT have content_block_stop for partial stream"),
        ("event: message_delta", "Should NOT have message_delta"),
        ("event: message_stop", "Should NOT have message_stop"),
    ] {
        assert!(!output.contains(needle), "{}", message);
    }

    println!("\nVALIDATION SUMMARY:");
    println!("{}", light_rule);
    println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)");
    println!("✓ Injected: message_start, content_block_start at beginning");
    println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count);
    println!("✓ NO completion events (partial stream, no [DONE])");
    println!("✓ Buffer maintains Anthropic protocol for active streams\n");
}
#[test]
fn test_openai_tool_calling_to_anthropic_transformation() {
    // Complete OpenAI ChatCompletions tool-calling stream, terminated by [DONE]
    let raw_input = r#"data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_2Uzw0AEZQeOex2CP2TKjcLKc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"uSpCcO"}
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"24WSqt08jtf"}
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"6CleV8twTxkKYg"}
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"San"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Francisco"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"1XLz89l3v"}
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":","}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"sh"}
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" CA"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}
data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"obfuscation":"I"}
data: [DONE]"#;

    // Separator lines reused by every section of the diagnostic printout
    let heavy_rule = "=".repeat(80);
    let light_rule = "-".repeat(80);

    println!("\n{}", heavy_rule);
    println!("TEST 3: OpenAI Tool Calling → Anthropic Messages API Transformation");
    println!("{}", heavy_rule);
    println!("\nRAW INPUT (OpenAI ChatCompletions with Tool Calls):");
    println!("{}", light_rule);
    println!("{}", raw_input);

    // The client speaks Anthropic Messages, the upstream speaks OpenAI ChatCompletions
    let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

    // Transform each upstream event and hand it straight to the buffer
    let mut buffer = AnthropicMessagesStreamBuffer::new();
    for upstream_event in SseStreamIter::try_from(raw_input.as_bytes()).unwrap() {
        buffer.add_transformed_event(
            SseEvent::try_from((upstream_event, &client_api, &upstream_api)).unwrap(),
        );
    }

    let output_bytes = buffer.into_bytes();
    let output = String::from_utf8_lossy(&output_bytes);

    println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
    println!("{}", light_rule);
    println!("{}", output);

    assert!(!output_bytes.is_empty(), "Should have output");

    // Lifecycle events, followed by the tool_use content block metadata
    for (needle, message) in [
        ("event: message_start", "Should have message_start (injected)"),
        ("event: content_block_start", "Should have content_block_start"),
        ("event: content_block_stop", "Should have content_block_stop (injected)"),
        ("event: message_delta", "Should have message_delta"),
        ("event: message_stop", "Should have message_stop"),
        ("\"type\":\"tool_use\"", "Should have tool_use type"),
        ("\"name\":\"get_weather\"", "Should have correct function name"),
        ("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\"", "Should have correct tool call ID"),
    ] {
        assert!(output.contains(needle), "{}", message);
    }

    // One content_block_delta is expected per upstream argument chunk
    let delta_count = output.matches("event: content_block_delta").count();
    assert!(delta_count >= 8, "Should have at least 8 input_json_delta events");

    // Argument deltas, the argument payload itself, and finally the stop reason
    for (needle, message) in [
        ("\"type\":\"input_json_delta\"", "Should have input_json_delta type"),
        ("\"partial_json\":", "Should have partial_json field"),
        ("San", "Arguments should contain 'San'"),
        ("Francisco", "Arguments should contain 'Francisco'"),
        ("CA", "Arguments should contain 'CA'"),
        ("\"stop_reason\":\"tool_use\"", "Should have stop_reason as tool_use"),
    ] {
        assert!(output.contains(needle), "{}", message);
    }

    println!("\nVALIDATION SUMMARY:");
    println!("{}", light_rule);
    println!("✓ Complete tool calling transformation: OpenAI → Anthropic Messages API");
    println!("✓ Injected lifecycle: message_start, content_block_stop");
    println!("✓ Tool metadata: name='get_weather', id='call_2Uzw0AEZQeOex2CP2TKjcLKc'");
    println!("✓ Argument deltas: {} events", delta_count);
    println!("✓ Complete JSON arguments: '{{\"location\":\"San Francisco, CA\"}}'");
    println!("✓ Stop reason: tool_use");
    println!("✓ Proper Anthropic tool_use protocol\n");
}
}

View file

@ -0,0 +1,39 @@
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
/// OpenAI Chat Completions SSE Stream Buffer for when client and upstream APIs match.
///
/// Holds already-transformed SSE events in arrival order until the
/// `SseStreamBufferTrait` implementation for this type drains them into wire bytes.
///
/// NOTE(review): the "client and upstream APIs match" wording is identical to the
/// summary on `PassthroughStreamBuffer`; confirm it describes when this buffer is
/// actually selected.
pub struct OpenAIChatCompletionsStreamBuffer {
/// Buffered SSE events ready to be written to wire (emptied on every flush)
buffered_events: Vec<SseEvent>,
}
impl OpenAIChatCompletionsStreamBuffer {
pub fn new() -> Self {
Self {
buffered_events: Vec::new(),
}
}
}
impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer {
    /// Queues `event` for the next flush unless the event reports that it should be
    /// skipped (e.g. ping messages).
    fn add_transformed_event(&mut self, event: SseEvent) {
        // Events arrive already transformed for OpenAI Chat Completions, so anything
        // that is not skippable is stored untouched for later wire transmission.
        if !event.should_skip() {
            self.buffered_events.push(event);
        }
    }

    /// Drains every queued event, in arrival order, into one contiguous byte vector.
    ///
    /// The buffer is empty afterwards. No finalization happens here: the [DONE]
    /// marker is already handled by the transformation layer.
    fn into_bytes(&mut self) -> Vec<u8> {
        self.buffered_events
            .drain(..)
            .flat_map(|queued| -> Vec<u8> { queued.into() })
            .collect()
    }
}

View file

@ -0,0 +1,7 @@
pub mod sse;
pub mod sse_chunk_processor;
pub mod amazon_bedrock_binary_frame;
pub mod anthropic_streaming_buffer;
pub mod chat_completions_streaming_buffer;
pub mod passthrough_streaming_buffer;
pub mod responses_api_streaming_buffer;

View file

@ -0,0 +1,95 @@
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
/// Passthrough SSE Stream Buffer for when client and upstream APIs match.
///
/// Events are stored exactly as received, in arrival order, until the
/// `SseStreamBufferTrait` implementation for this type drains them into wire bytes.
pub struct PassthroughStreamBuffer {
/// Buffered SSE events ready to be written to wire (emptied on every flush)
buffered_events: Vec<SseEvent>,
}
impl PassthroughStreamBuffer {
pub fn new() -> Self {
Self {
buffered_events: Vec::new(),
}
}
}
impl SseStreamBufferTrait for PassthroughStreamBuffer {
    /// Queues `event` unchanged for the next flush.
    ///
    /// Two kinds of events are dropped instead of queued: those that report they
    /// should be skipped (e.g. ping messages) and those whose transformed lines are
    /// empty (e.g. suppressed event-only lines).
    fn add_transformed_event(&mut self, event: SseEvent) {
        let nothing_to_emit = event.should_skip() || event.sse_transformed_lines.is_empty();
        if nothing_to_emit {
            return;
        }
        self.buffered_events.push(event);
    }

    /// Drains every queued event, in arrival order, into one contiguous byte vector.
    ///
    /// Passthrough needs no finalization step; the buffer is empty afterwards.
    fn into_bytes(&mut self) -> Vec<u8> {
        self.buffered_events
            .drain(..)
            .flat_map(|queued| -> Vec<u8> { queued.into() })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
    use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait};

    /// Feeds a ChatCompletions tool-calling stream through the passthrough buffer and
    /// checks that the flushed bytes equal the input (modulo surrounding whitespace).
    #[test]
    fn test_chat_completions_passthrough_buffer() {
        let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
data: [DONE]"#;

        // Separator lines reused by every section of the diagnostic printout
        let heavy_rule = "=".repeat(80);
        let light_rule = "-".repeat(80);

        println!("\n{}", heavy_rule);
        println!("TEST 1: ChatCompletions Passthrough Buffer");
        println!("{}", heavy_rule);
        println!("\nRAW INPUT (ChatCompletions):");
        println!("{}", light_rule);
        println!("{}", raw_input);

        // Parsed events go into the buffer untouched: no transformation step here
        let mut buffer = PassthroughStreamBuffer::new();
        for parsed_event in SseStreamIter::try_from(raw_input.as_bytes()).unwrap() {
            buffer.add_transformed_event(parsed_event);
        }

        let output_bytes = buffer.into_bytes();
        let output = String::from_utf8_lossy(&output_bytes);

        println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):");
        println!("{}", light_rule);
        println!("{}", output);

        assert!(!output_bytes.is_empty());
        for needle in ["chatcmpl-123", "[DONE]"] {
            assert!(output.contains(needle));
        }
        assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input");

        println!("\nVALIDATION SUMMARY:");
        println!("{}", light_rule);
        println!("✓ Passthrough buffer: input = output (no transformation)");
        println!("✓ All events preserved including [DONE]");
        println!("✓ Function calling events preserved\n");
    }
}

View file

@ -0,0 +1,649 @@
use std::collections::HashMap;
use log::debug;
use crate::apis::openai_responses::{
ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus,
ResponseStatus, TextConfig, TextFormat, Reasoning,
};
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
/// Wraps a `ResponsesAPIStreamEvent` in an `SseEvent` ready for buffering.
///
/// * `event` field: the dotted event name matching the variant, or `"unknown"`
///   (with a debug log) for variants not listed here.
/// * `data` field: the event serialized as JSON; an empty string (with a debug
///   log) if serialization fails.
/// * `raw_line` / `sse_transformed_lines`: both hold the event's `String`
///   conversion.
/// * `provider_stream_response`: always `None`.
fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
    let event_name = match &event {
        ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
        ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
        ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed",
        ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added",
        ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done",
        ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta",
        ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done",
        ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => {
            "response.function_call_arguments.delta"
        }
        ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => {
            "response.function_call_arguments.done"
        }
        unknown => {
            debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown);
            "unknown"
        }
    };

    // Serialize before the event is consumed by the wire-format conversion below
    let payload = serde_json::to_string(&event).unwrap_or_else(|e| {
        debug!("Error serializing ResponsesAPIStreamEvent to JSON: {}", e);
        String::new()
    });

    let wire: String = event.into();
    let raw_copy = wire.clone();

    SseEvent {
        data: Some(payload),
        event: Some(event_name.to_string()),
        raw_line: raw_copy,
        sse_transformed_lines: wire,
        provider_stream_response: None,
    }
}
/// SSE Stream Buffer for ResponsesAPIStreamEvent with full lifecycle management.
///
/// This buffer manages the wire format for v1/responses streaming, handling
/// delta events and emitting complete lifecycle events.
///
/// Incoming text / function-call-argument deltas are re-emitted with
/// buffer-assigned item ids and sequence numbers, preceded by synthesized
/// `response.created`, `response.in_progress` and `response.output_item.added`
/// events. `finalize()` appends the `*.done` events and `response.completed`.
pub struct ResponsesAPIStreamBuffer {
/// Sequence number that will be assigned to the next emitted event (starts at 0)
sequence_number: i32,
/// Track item IDs by output index (an id is generated the first time an index is seen)
item_ids: HashMap<i32, String>,
/// Response metadata: copied from the upstream response when one is seen,
/// otherwise filled with generated / placeholder values before `response.created`
response_id: Option<String>,
model: Option<String>,
created_at: Option<i64>,
/// Full response metadata from upstream (tools, temperature, etc.)
/// This is extracted from the first upstream event and used to build
/// complete response.created and response.in_progress events
upstream_response_metadata: Option<ResponsesAPIResponse>,
/// Lifecycle state flags: set once the corresponding event has been buffered
created_emitted: bool,
in_progress_emitted: bool,
/// Track which output items we've added (i.e. already announced via output_item.added)
output_items_added: HashMap<i32, String>, // output_index -> item_id
/// Accumulated content by item_id (deltas are appended as they arrive)
text_content: HashMap<String, String>,
function_arguments: HashMap<String, String>,
/// Tool call metadata by output_index
tool_call_metadata: HashMap<i32, (String, String)>, // output_index -> (call_id, name)
/// Final completed response (for logging/tracing/persistence); populated by `finalize()`
completed_response: Option<ResponsesAPIResponse>,
/// Buffered SSE events ready to be written to wire (drained by `into_bytes()`)
buffered_events: Vec<SseEvent>,
}
impl ResponsesAPIStreamBuffer {
    /// Creates an empty buffer: sequence numbers start at 0, no lifecycle event has
    /// been emitted yet and nothing is buffered.
    pub fn new() -> Self {
        Self {
            sequence_number: 0,
            item_ids: HashMap::new(),
            response_id: None,
            model: None,
            created_at: None,
            upstream_response_metadata: None,
            created_emitted: false,
            in_progress_emitted: false,
            output_items_added: HashMap::new(),
            text_content: HashMap::new(),
            function_arguments: HashMap::new(),
            tool_call_metadata: HashMap::new(),
            completed_response: None,
            buffered_events: Vec::new(),
        }
    }

    /// Returns the current sequence number and advances the counter by one.
    fn next_sequence_number(&mut self) -> i32 {
        let seq = self.sequence_number;
        self.sequence_number += 1;
        seq
    }

    /// Builds an item id of the form `<prefix>_<uuid v4 without dashes>`.
    fn generate_item_id(prefix: &str) -> String {
        format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", ""))
    }

    /// Returns the item id registered for `output_index`, generating and storing one
    /// with `prefix` on first use. `prefix` is ignored when an id already exists.
    fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String {
        self.item_ids
            .entry(output_index)
            .or_insert_with(|| Self::generate_item_id(prefix))
            .clone()
    }

    /// Returns `(call_id, name)` for the tool call at `output_index`.
    ///
    /// When upstream never supplied the metadata, a placeholder
    /// (`call_<uuid>`, `"unknown"`) is generated once and stored, so every later
    /// lookup for the same index sees the same `call_id`.
    fn resolve_tool_call_metadata(&mut self, output_index: i32) -> (String, String) {
        self.tool_call_metadata
            .entry(output_index)
            .or_insert_with(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()))
            .clone()
    }

    /// Create response.created event
    fn create_response_created_event(&mut self) -> SseEvent {
        let response = self.build_response(ResponseStatus::InProgress);
        let event = ResponsesAPIStreamEvent::ResponseCreated {
            response,
            sequence_number: self.next_sequence_number(),
        };
        event_to_sse(event)
    }

    /// Create response.in_progress event
    fn create_response_in_progress_event(&mut self) -> SseEvent {
        let response = self.build_response(ResponseStatus::InProgress);
        let event = ResponsesAPIStreamEvent::ResponseInProgress {
            response,
            sequence_number: self.next_sequence_number(),
        };
        event_to_sse(event)
    }

    /// Create output_item.added event for text (an in-progress assistant message
    /// with empty content)
    fn create_output_item_added_event(&mut self, output_index: i32, item_id: &str) -> SseEvent {
        let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
            output_index,
            item: OutputItem::Message {
                id: item_id.to_string(),
                status: OutputItemStatus::InProgress,
                role: "assistant".to_string(),
                content: vec![],
            },
            sequence_number: self.next_sequence_number(),
        };
        event_to_sse(event)
    }

    /// Create output_item.added event for tool call (in progress, empty arguments)
    fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent {
        let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
            output_index,
            item: OutputItem::FunctionCall {
                id: item_id.to_string(),
                status: OutputItemStatus::InProgress,
                call_id: call_id.to_string(),
                name: Some(name.to_string()),
                arguments: Some(String::new()),
            },
            sequence_number: self.next_sequence_number(),
        };
        event_to_sse(event)
    }

    /// Build the base response object with current state.
    ///
    /// Uses the captured upstream response as a template when available (only the
    /// status is overwritten); otherwise builds a minimal response from the locally
    /// tracked id / model / created_at with fixed default values for the rest.
    fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse {
        // If we have upstream metadata, use it as a base and update status/output
        if let Some(upstream) = &self.upstream_response_metadata {
            let mut response = upstream.clone();
            response.status = status;
            // Don't update output here - will be set in finalize()
            return response;
        }
        // Fallback: build a minimal response from local state
        ResponsesAPIResponse {
            id: self.response_id.clone().unwrap_or_default(),
            object: "response".to_string(),
            created_at: self.created_at.unwrap_or(0),
            status,
            error: None,
            incomplete_details: None,
            instructions: None,
            model: self.model.clone().unwrap_or_else(|| "unknown".to_string()),
            output: vec![],
            usage: None,
            parallel_tool_calls: true,
            conversation: None,
            previous_response_id: None,
            tools: vec![],
            tool_choice: "auto".to_string(),
            temperature: 1.0,
            top_p: 1.0,
            metadata: HashMap::new(),
            truncation: Some("disabled".to_string()),
            max_output_tokens: None,
            reasoning: Some(Reasoning {
                effort: None,
                summary: None,
            }),
            store: Some(true),
            text: Some(TextConfig {
                format: TextFormat::Text,
            }),
            audio: None,
            modalities: None,
            service_tier: Some("auto".to_string()),
            background: Some(false),
            top_logprobs: Some(0),
            max_tool_calls: None,
        }
    }

    /// Get the completed response after finalization (for logging/tracing/persistence)
    pub fn get_completed_response(&self) -> Option<&ResponsesAPIResponse> {
        self.completed_response.as_ref()
    }

    /// Finalize the response by emitting all *.done events and response.completed.
    /// Call this when the stream is complete (after seeing [DONE] or end_of_stream).
    ///
    /// Calling it again after a successful finalization is a no-op, so a caller that
    /// finalizes at end of stream does not duplicate the events already produced
    /// when the [DONE] marker was processed.
    ///
    /// Events are emitted in a deterministic order: text items by ascending output
    /// index, then function-call items by ascending output index, then
    /// `response.completed`.
    pub fn finalize(&mut self) {
        use crate::apis::openai_responses::OutputContent;

        // Already finalized: response.completed must be emitted exactly once
        if self.completed_response.is_some() {
            return;
        }

        let mut events = Vec::new();

        // Reverse lookup (item_id -> output_index), built once instead of scanning
        // output_items_added for every item
        let index_by_item: HashMap<String, i32> = self
            .output_items_added
            .iter()
            .map(|(idx, id)| (id.clone(), *idx))
            .collect();

        // Text content done events. Sorting makes the order independent of HashMap
        // iteration order; items without a known index fall back to index 0.
        let mut text_items: Vec<(i32, String, String)> = self
            .text_content
            .iter()
            .map(|(id, content)| {
                (index_by_item.get(id).copied().unwrap_or(0), id.clone(), content.clone())
            })
            .collect();
        text_items.sort();

        for (output_index, item_id, content) in text_items {
            let seq1 = self.next_sequence_number();
            let text_done_event = ResponsesAPIStreamEvent::ResponseOutputTextDone {
                item_id: item_id.clone(),
                output_index,
                content_index: 0,
                text: content.clone(),
                logprobs: vec![],
                sequence_number: seq1,
            };
            events.push(event_to_sse(text_done_event));

            let seq2 = self.next_sequence_number();
            let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
                output_index,
                item: OutputItem::Message {
                    id: item_id,
                    status: OutputItemStatus::Completed,
                    role: "assistant".to_string(),
                    // The completed item carries its full text, matching the item
                    // reported in response.completed below
                    content: vec![OutputContent::OutputText {
                        text: content,
                        annotations: vec![],
                        logprobs: None,
                    }],
                },
                sequence_number: seq2,
            };
            events.push(event_to_sse(item_done_event));
        }

        // Function call done events, same deterministic ordering
        let mut func_items: Vec<(i32, String, String)> = self
            .function_arguments
            .iter()
            .map(|(id, args)| {
                (index_by_item.get(id).copied().unwrap_or(0), id.clone(), args.clone())
            })
            .collect();
        func_items.sort();

        for (output_index, item_id, arguments) in func_items {
            let seq1 = self.next_sequence_number();
            let args_done_event = ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone {
                output_index,
                item_id: item_id.clone(),
                arguments: arguments.clone(),
                sequence_number: seq1,
            };
            events.push(event_to_sse(args_done_event));

            // Resolved through the shared helper so the final response below reports
            // the same call_id even when a placeholder had to be generated
            let (call_id, name) = self.resolve_tool_call_metadata(output_index);
            let seq2 = self.next_sequence_number();
            let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
                output_index,
                item: OutputItem::FunctionCall {
                    id: item_id,
                    status: OutputItemStatus::Completed,
                    call_id,
                    name: Some(name),
                    arguments: Some(arguments),
                },
                sequence_number: seq2,
            };
            events.push(event_to_sse(item_done_event));
        }

        // Build the complete output array in output-index order
        let mut ordered_items: Vec<(i32, String)> = self
            .output_items_added
            .iter()
            .map(|(idx, id)| (*idx, id.clone()))
            .collect();
        ordered_items.sort();

        let mut output_items = Vec::with_capacity(ordered_items.len());
        for (output_index, item_id) in ordered_items {
            let arguments = self.function_arguments.get(&item_id).cloned();
            if let Some(arguments) = arguments {
                // Function call takes precedence if an item has both kinds of content
                let (call_id, name) = self.resolve_tool_call_metadata(output_index);
                output_items.push(OutputItem::FunctionCall {
                    id: item_id,
                    status: OutputItemStatus::Completed,
                    call_id,
                    name: Some(name),
                    arguments: Some(arguments),
                });
            } else if let Some(text) = self.text_content.get(&item_id).cloned() {
                output_items.push(OutputItem::Message {
                    id: item_id,
                    status: OutputItemStatus::Completed,
                    role: "assistant".to_string(),
                    content: vec![OutputContent::OutputText {
                        text,
                        annotations: vec![],
                        logprobs: None,
                    }],
                });
            }
        }

        let mut final_response = self.build_response(ResponseStatus::Completed);
        final_response.output = output_items;

        // Store completed response (also marks the buffer as finalized)
        self.completed_response = Some(final_response.clone());

        // Emit response.completed
        let seq_final = self.next_sequence_number();
        let completed_event = ResponsesAPIStreamEvent::ResponseCompleted {
            response: final_response,
            sequence_number: seq_final,
        };
        events.push(event_to_sse(completed_event));

        // Add all finalization events to the buffer
        self.buffered_events.extend(events);
    }
}
impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
fn add_transformed_event(&mut self, event: SseEvent) {
// Skip ping messages
if event.should_skip() {
return;
}
// Handle [DONE] marker - trigger finalization
if event.is_done() {
self.finalize();
return;
}
// Extract the ResponseAPIStreamEvent from the SseEvent's provider_stream_response
let provider_response = match event.provider_stream_response.as_ref() {
Some(response) => response,
None => {
eprintln!("Warning: Event missing provider_stream_response");
return;
}
};
// Extract ResponseAPIStreamEvent from the enum
let stream_event = match provider_response {
crate::providers::streaming_response::ProviderStreamResponseType::ResponseAPIStreamEvent(evt) => evt,
_ => {
eprintln!("Warning: Expected ResponseAPIStreamEvent in provider_stream_response");
return;
}
};
let mut events = Vec::new();
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present
match stream_event {
ResponsesAPIStreamEvent::ResponseCreated { response, .. } |
ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
if self.upstream_response_metadata.is_none() {
// Store the full upstream response as our metadata template
self.upstream_response_metadata = Some(response.clone());
// Also extract basic fields
self.response_id = Some(response.id.clone());
self.model = Some(response.model.clone());
self.created_at = Some(response.created_at);
}
// Don't emit these - we'll generate our own lifecycle events
return;
}
_ => {}
}
// Emit lifecycle events if not yet emitted
if !self.created_emitted {
// Initialize metadata from first event if needed
if self.response_id.is_none() {
self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")));
self.created_at = Some(std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64);
self.model = Some("unknown".to_string()); // Will be set by caller if available
}
events.push(self.create_response_created_event());
self.created_emitted = true;
}
if !self.in_progress_emitted {
events.push(self.create_response_in_progress_event());
self.in_progress_emitted = true;
}
// Process the delta event
match stream_event {
ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => {
let item_id = self.get_or_create_item_id(*output_index, "msg");
// Emit output_item.added if this is the first time we see this output index
if !self.output_items_added.contains_key(output_index) {
self.output_items_added.insert(*output_index, item_id.clone());
events.push(self.create_output_item_added_event(*output_index, &item_id));
}
// Accumulate text content
self.text_content.entry(item_id.clone())
.and_modify(|content| content.push_str(delta))
.or_insert_with(|| delta.clone());
// Emit text delta with filled-in item_id and sequence_number
let mut delta_event = stream_event.clone();
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
*id = item_id;
*seq = self.next_sequence_number();
}
events.push(event_to_sse(delta_event));
}
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => {
let item_id = self.get_or_create_item_id(*output_index, "fc");
// Store metadata if provided (from initial tool call event)
if let (Some(cid), Some(n)) = (call_id, name) {
self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone()));
}
// Emit output_item.added if this is the first time we see this tool call
if !self.output_items_added.contains_key(output_index) {
self.output_items_added.insert(*output_index, item_id.clone());
// For tool calls, we need call_id and name from metadata
// These should now be populated from the event itself
let (call_id, name) = self.tool_call_metadata.get(output_index)
.cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name));
}
// Accumulate function arguments
self.function_arguments.entry(item_id.clone())
.and_modify(|args| args.push_str(delta))
.or_insert_with(|| delta.clone());
// Emit function call arguments delta with filled-in item_id and sequence_number
let mut delta_event = stream_event.clone();
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
*id = item_id;
*seq = self.next_sequence_number();
}
events.push(event_to_sse(delta_event));
}
_ => {
// For other event types, just pass through with sequence number
let other_event = stream_event.clone();
// TODO: Add sequence number to other event types if needed
events.push(event_to_sse(other_event));
}
}
// Store all generated events in the buffer
self.buffered_events.extend(events);
}
fn into_bytes(&mut self) -> Vec<u8> {
// For Responses API, we need special handling:
// - Most events are already in buffered_events from add_transformed_event
// - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream
// - Just flush the accumulated events and clear the buffer
// Convert all accumulated events to bytes and clear buffer
let mut buffer = Vec::new();
for event in self.buffered_events.drain(..) {
let event_bytes: Vec<u8> = event.into();
buffer.extend_from_slice(&event_bytes);
}
buffer
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::openai::OpenAIApi;
use crate::apis::streaming_shapes::sse::SseStreamIter;
#[test]
fn test_chat_completions_to_responses_api_transformation() {
    // Complete ChatCompletions text stream (two content chunks, stop, [DONE]) that is
    // transformed into Responses API events
    let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
data: [DONE]"#;

    // Separator lines reused by every section of the diagnostic printout
    let heavy_rule = "=".repeat(80);
    let light_rule = "-".repeat(80);

    println!("\n{}", heavy_rule);
    println!("TEST 2: ChatCompletions → ResponsesAPI Transformation (with [DONE])");
    println!("{}", heavy_rule);
    println!("\nRAW INPUT (ChatCompletions):");
    println!("{}", light_rule);
    println!("{}", raw_input);

    // The client speaks the Responses API, the upstream speaks ChatCompletions
    let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

    // Transform each upstream event and hand it straight to the buffer
    let mut buffer = ResponsesAPIStreamBuffer::new();
    for upstream_event in SseStreamIter::try_from(raw_input.as_bytes()).unwrap() {
        buffer.add_transformed_event(
            SseEvent::try_from((upstream_event, &client_api, &upstream_api)).unwrap(),
        );
    }

    let output_bytes = buffer.into_bytes();
    let output = String::from_utf8_lossy(&output_bytes);

    println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
    println!("{}", light_rule);
    println!("{}", output);

    assert!(!output_bytes.is_empty(), "Should have output");

    // Every lifecycle stage must appear in the flushed output
    for (needle, message) in [
        ("response.created", "Should have response.created"),
        ("response.in_progress", "Should have response.in_progress"),
        ("response.output_item.added", "Should have output_item.added"),
        ("response.output_text.delta", "Should have text deltas"),
        ("response.output_text.done", "Should have text.done"),
        ("response.output_item.done", "Should have output_item.done"),
        ("response.completed", "Should have response.completed"),
    ] {
        assert!(output.contains(needle), "{}", message);
    }

    println!("\nVALIDATION SUMMARY:");
    println!("{}", light_rule);
    println!("✓ Lifecycle events: response.created, response.in_progress, response.completed");
    println!("✓ Output item lifecycle: output_item.added, output_item.done");
    println!("✓ Text streaming: output_text.delta (2 deltas), output_text.done");
    println!("✓ Complete transformation with finalization ([DONE] processed)\n");
}
// Feeds a truncated ChatCompletions tool-call stream (four chunks, no `[DONE]`) through the
// ChatCompletions -> ResponsesAPI transformation and checks that the flushed output carries
// the function-call metadata and argument deltas but none of the completion events.
#[test]
fn test_partial_streaming_incremental_output() {
// Fixture: the first chunk opens the tool call with empty arguments; the next three
// chunks each carry one fragment of the arguments string.
let raw_input = r#"data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_mD5ggLKk3SMKGPFqFdcpKg6q","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"PCFrpy"}
data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}
data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"TC58A3QEIx8"}
data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"PK4oFzlVlGTUP5"}"#;
// Diagnostic output only; visible when the test is run with `--nocapture`.
println!("\n{}", "=".repeat(80));
println!("TEST 3: Partial Streaming - Function Calling (NO [DONE])");
println!("{}", "=".repeat(80));
println!("\nRAW INPUT (ChatCompletions - NO [DONE]):");
println!("{}", "-".repeat(80));
println!("{}", raw_input);
// Setup API configuration for transformation
let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// Transform all events
let stream_iter = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
let mut buffer = ResponsesAPIStreamBuffer::new();
for raw_event in stream_iter {
let transformed = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed);
}
// Flush whatever has been accumulated so far; the stream is deliberately left unfinished.
let output_bytes = buffer.into_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
println!("{}", "-".repeat(80));
println!("{}", output);
// Assertions
assert!(output.contains("response.created"), "Should have response.created");
assert!(output.contains("response.in_progress"), "Should have response.in_progress");
assert!(output.contains("response.output_item.added"), "Should have output_item.added");
assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type");
assert!(output.contains("\"name\":\"get_weather\""), "Should have function name");
assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id");
// One delta per input chunk is expected: the initial chunk plus three argument fragments.
let delta_count = output.matches("event: response.function_call_arguments.delta").count();
assert_eq!(delta_count, 4, "Should have 4 delta events");
// No `[DONE]` was fed in, so none of the closing events may appear.
assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done");
assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done");
assert!(!output.contains("response.completed"), "Should NOT have response.completed");
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
println!("✓ Lifecycle events: response.created, response.in_progress");
println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'");
println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)");
println!("✓ NO completion events (partial stream, no [DONE])");
println!("✓ Arguments accumulated: '{{\"location\":\"'\n");
}
}

View file

@ -1,10 +1,73 @@
use crate::providers::response::ProviderStreamResponse;
use crate::providers::response::ProviderStreamResponseType;
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::providers::streaming_response::ProviderStreamResponseType;
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer;
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
use std::str::FromStr;
/// Trait defining the interface for SSE stream buffers.
///
/// This trait is implemented by both the enum `SseStreamBuffer` (which forwards each call
/// to the buffer it wraps) and individual buffer implementations (for direct use).
///
/// Intended usage is a repeated cycle: call `add_transformed_event()` for every event of an
/// upstream chunk, then call `into_bytes()` once to flush what was accumulated.
pub trait SseStreamBufferTrait: Send + Sync {
/// Add a transformed SSE event to the buffer.
///
/// The buffer may inject additional events as needed based on internal state.
/// For example, Anthropic buffers inject ContentBlockStart before the first ContentBlockDelta.
///
/// All events (original + injected) are accumulated internally for the next `into_bytes()` call.
///
/// # Arguments
/// * `event` - A transformed SSE event to accumulate
fn add_transformed_event(&mut self, event: SseEvent);
/// Get bytes for all accumulated events since the last call.
///
/// This method:
/// - Converts all buffered events to wire format bytes
/// - Clears the internal event buffer
/// - Preserves state for subsequent `add_transformed_event()` calls
///
/// Call this after processing each chunk of upstream events to get bytes for immediate transmission.
///
/// Note that, despite the `into_` prefix, this takes `&mut self` rather than consuming the
/// buffer, so the same buffer is reused for the whole stream.
///
/// # Returns
/// Bytes ready for wire transmission (may be empty if no events were accumulated)
fn into_bytes(&mut self) -> Vec<u8>;
}
/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction
///
/// Each variant owns one concrete buffer type; the `SseStreamBufferTrait` implementation
/// for this enum selects the wrapped buffer with a `match` instead of a trait object.
pub enum SseStreamBuffer {
/// Wraps a `PassthroughStreamBuffer`.
/// NOTE(review): presumably used when no format translation is required — confirm at the construction site.
Passthrough(PassthroughStreamBuffer),
/// Wraps an `OpenAIChatCompletionsStreamBuffer`.
OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer),
/// Wraps an `AnthropicMessagesStreamBuffer`.
AnthropicMessages(AnthropicMessagesStreamBuffer),
/// Wraps a `ResponsesAPIStreamBuffer`.
OpenAIResponses(ResponsesAPIStreamBuffer),
}
impl SseStreamBufferTrait for SseStreamBuffer {
fn add_transformed_event(&mut self, event: SseEvent) {
match self {
Self::Passthrough(buffer) => buffer.add_transformed_event(event),
Self::OpenAIChatCompletions(buffer) => buffer.add_transformed_event(event),
Self::AnthropicMessages(buffer) => buffer.add_transformed_event(event),
Self::OpenAIResponses(buffer) => buffer.add_transformed_event(event),
}
}
fn into_bytes(&mut self) -> Vec<u8> {
match self {
Self::Passthrough(buffer) => buffer.into_bytes(),
Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(),
Self::AnthropicMessages(buffer) => buffer.into_bytes(),
Self::OpenAIResponses(buffer) => buffer.into_bytes(),
}
}
}
// ============================================================================
// SSE EVENT CONTAINER
// ============================================================================
@ -22,16 +85,31 @@ pub struct SseEvent {
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
pub sse_transformed_lines: String, // The complete line as received including "data: " prefix and "\n\n"
#[serde(skip_serializing, skip_deserializing)]
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
}
impl SseEvent {
/// Create an SseEvent from a ProviderStreamResponseType
/// This is useful for binary frame formats (like Bedrock) that need to be converted to SSE
pub fn from_provider_response(response: ProviderStreamResponseType) -> Self {
// Convert the provider response to SSE format string
let sse_string: String = response.clone().into();
SseEvent {
data: None, // Data is embedded in sse_transformed_lines
event: None, // Event type is embedded in sse_transformed_lines
raw_line: sse_string.clone(),
sse_transformed_lines: sse_string,
provider_stream_response: Some(response),
}
}
/// Check if this event represents the end of the stream
pub fn is_done(&self) -> bool {
self.data == Some("[DONE]".into())
self.data == Some("[DONE]".into()) || self.event == Some("message_stop".into())
}
/// Check if this event should be skipped during processing
@ -61,23 +139,35 @@ impl FromStr for SseEvent {
type Err = SseParseError;
fn from_str(line: &str) -> Result<Self, Self::Err> {
if line.starts_with("data: ") {
let data: String = line[6..].to_string(); // Remove "data: " prefix
if data.is_empty() {
// Trim leading/trailing whitespace for parsing
let trimmed_line = line.trim();
// Skip empty or whitespace-only lines (SSE event separators)
if trimmed_line.is_empty() {
return Err(SseParseError {
message: "Empty line (SSE event separator)".to_string(),
});
}
if trimmed_line.starts_with("data: ") {
let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix
// Allow empty data content after "data: " prefix
// This handles cases like "data: " followed by newline
if data.trim().is_empty() {
return Err(SseParseError {
message: "Empty data field is not a valid SSE event".to_string(),
message: "Empty data field after 'data: ' prefix".to_string(),
});
}
Ok(SseEvent {
data: Some(data),
event: None,
raw_line: line.to_string(),
sse_transform_buffer: line.to_string(),
// Preserve original line format for passthrough, use trimmed for transformations
sse_transformed_lines: line.to_string(),
provider_stream_response: None,
})
} else if line.starts_with("event: ") {
//used by Anthropic
let event_type = line[7..].to_string();
} else if trimmed_line.starts_with("event: ") {
let event_type = trimmed_line[7..].to_string();
if event_type.is_empty() {
return Err(SseParseError {
message: "Empty event field is not a valid SSE event".to_string(),
@ -87,12 +177,13 @@ impl FromStr for SseEvent {
data: None,
event: Some(event_type),
raw_line: line.to_string(),
sse_transform_buffer: line.to_string(),
// Preserve original line format for passthrough, use trimmed for transformations
sse_transformed_lines: line.to_string(),
provider_stream_response: None,
})
} else {
Err(SseParseError {
message: format!("Line does not start with 'data: ' or 'event: ': {}", line),
message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line),
})
}
}
@ -100,14 +191,22 @@ impl FromStr for SseEvent {
impl fmt::Display for SseEvent {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.sse_transform_buffer)
write!(f, "{}", self.sse_transformed_lines)
}
}
// Into implementation to convert SseEvent to bytes for response buffer
impl Into<Vec<u8>> for SseEvent {
fn into(self) -> Vec<u8> {
format!("{}\n\n", self.sse_transform_buffer).into_bytes()
// For generated events (like ResponsesAPI), sse_transformed_lines already includes trailing \n\n
// For parsed events (like passthrough), we need to add the \n\n separator
if self.sse_transformed_lines.ends_with("\n\n") {
// Already properly formatted with trailing newlines
self.sse_transformed_lines.into_bytes()
} else {
// Add SSE event separator
format!("{}\n\n", self.sse_transformed_lines).into_bytes()
}
}
}

View file

@ -0,0 +1,241 @@
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamIter};
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
/// Stateful processor for handling SSE chunks that may contain incomplete events.
///
/// This processor buffers incomplete SSE event bytes when transformation fails
/// (e.g., due to incomplete JSON) and prepends them to the next chunk for retry.
///
/// One instance is meant to live for the duration of a stream, because the carry-over
/// bytes from one `process_chunk` call are consumed by the next.
pub struct SseChunkProcessor {
/// Buffered bytes from incomplete SSE events across chunks
/// (empty when the previous chunk ended on a complete event).
incomplete_event_buffer: Vec<u8>,
}
impl SseChunkProcessor {
    /// Create a processor with an empty carry-over buffer.
    pub fn new() -> Self {
        Self {
            incomplete_event_buffer: Vec::new(),
        }
    }

    /// Process a chunk of SSE data, handling incomplete events across chunk boundaries.
    ///
    /// Bytes carried over from the previous call are prepended to `chunk`, the result is
    /// split into SSE events, and each event is transformed from the upstream format to
    /// the client format.
    ///
    /// Failure handling per event:
    /// * If the transformation fails with an "unexpected end of JSON" style error **and the
    ///   event is the last one in the chunk**, its raw line is stored and retried together
    ///   with the next chunk.
    /// * Any other failure (unsupported event type, validation error, or truncated JSON
    ///   that is followed by further events and therefore cannot be completed by more
    ///   data) causes that single event to be skipped; later events are still processed.
    ///
    /// # Arguments
    /// * `chunk` - Raw bytes from upstream SSE stream
    /// * `client_api` - The API format the client expects
    /// * `upstream_api` - The API format from the upstream provider
    ///
    /// # Returns
    /// * `Ok(Vec<SseEvent>)` - Successfully transformed events ready for client
    /// * `Err(String)` - Fatal error that cannot be recovered by buffering
    pub fn process_chunk(
        &mut self,
        chunk: &[u8],
        client_api: &SupportedAPIsFromClient,
        upstream_api: &SupportedUpstreamAPIs,
    ) -> Result<Vec<SseEvent>, String> {
        // Combine the buffered incomplete event (if any) with the new chunk. `take`
        // leaves the carry-over buffer empty so it only holds data set by this call.
        let mut combined_data = std::mem::take(&mut self.incomplete_event_buffer);
        combined_data.extend_from_slice(chunk);

        // Peekable so we can tell whether a failing event is the final one in the chunk.
        let mut sse_iter = SseStreamIter::try_from(combined_data.as_slice())
            .map_err(|e| format!("Failed to create SSE iterator: {}", e))?
            .into_iter()
            .peekable();

        let mut transformed_events = Vec::new();

        while let Some(sse_event) = sse_iter.next() {
            // Only the raw line is needed if the transformation fails, so keep a copy of
            // that instead of cloning the whole event.
            let raw_line = sse_event.raw_line.clone();

            match SseEvent::try_from((sse_event, client_api, upstream_api)) {
                Ok(transformed) => transformed_events.push(transformed),
                Err(e) => {
                    // Only a trailing event can be completed by the next chunk. Buffering
                    // an event from the middle of the chunk would drop everything after
                    // it and glue the broken fragment onto the start of the next chunk.
                    let is_last_event = sse_iter.peek().is_none();
                    if is_last_event && Self::is_incomplete_json_error(&e.to_string()) {
                        self.incomplete_event_buffer = raw_line.into_bytes();
                    }
                    // Otherwise: skip this event and keep processing the rest.
                }
            }
        }

        Ok(transformed_events)
    }

    /// Check if there are buffered incomplete bytes
    pub fn has_buffered_data(&self) -> bool {
        !self.incomplete_event_buffer.is_empty()
    }

    /// Get the size of buffered incomplete data (for debugging/logging)
    pub fn buffered_size(&self) -> usize {
        self.incomplete_event_buffer.len()
    }

    /// Returns true when an error message indicates JSON that ended prematurely
    /// (as opposed to JSON that is complete but invalid or unsupported).
    fn is_incomplete_json_error(message: &str) -> bool {
        let lowered = message.to_lowercase();
        ["eof while parsing", "unexpected end of json", "unexpected eof"]
            .iter()
            .any(|needle| lowered.contains(needle))
    }
}

impl Default for SseChunkProcessor {
    /// Equivalent to [`SseChunkProcessor::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::openai::OpenAIApi;
// A chunk holding exactly one complete, `\n\n`-terminated event must yield that event
// right away and leave nothing in the carry-over buffer.
#[test]
fn test_complete_events_process_immediately() {
let mut processor = SseChunkProcessor::new();
// Same format on both sides (ChatCompletions -> ChatCompletions).
let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
assert_eq!(events.len(), 1);
assert!(!processor.has_buffered_data());
}
// An event whose JSON is cut off at the chunk boundary must be held back on the first
// call and emitted once the second chunk supplies the remainder.
#[test]
fn test_incomplete_json_buffered_and_completed() {
let mut processor = SseChunkProcessor::new();
let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// First chunk with incomplete JSON
// (the split falls inside the string value "chat.completion.chunk")
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chu";
let events1 = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
assert_eq!(events1.len(), 0, "Incomplete event should not be processed");
assert!(processor.has_buffered_data(), "Incomplete data should be buffered");
// Second chunk completes the JSON
let chunk2 = b"nk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events2 = processor.process_chunk(chunk2, &client_api, &upstream_api).unwrap();
assert_eq!(events2.len(), 1, "Complete event should be processed");
assert!(!processor.has_buffered_data(), "Buffer should be cleared after completion");
}
// When a chunk carries two complete events followed by a truncated third one, the two
// complete events must be returned and the truncated tail kept for the next call.
#[test]
fn test_multiple_events_with_one_incomplete() {
let mut processor = SseChunkProcessor::new();
let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// Chunk with 2 complete events and 1 incomplete
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"A\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"B\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-125\",\"object\":\"chat.completion.chu";
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
assert_eq!(events.len(), 2, "Two complete events should be processed");
assert!(processor.has_buffered_data(), "Incomplete third event should be buffered");
}
// Regression test: an Anthropic `signature_delta` event followed by `content_block_stop`
// (Anthropic -> Anthropic, no format change) must be processed without error and without
// leaving anything in the carry-over buffer.
#[test]
fn test_anthropic_signature_delta_from_production_logs() {
use crate::apis::anthropic::AnthropicApi;
let mut processor = SseChunkProcessor::new();
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
// Exact chunk from production logs - signature_delta event followed by content_block_stop
let chunk = br#"event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":"ErECCkYIChgCKkC7lAf/BOatd0I4NnANYNEDKl5/WSsjNK44AETnLoy3i5FfdYMAb0m4qMLJD6A04QnM4Hf3VpGqq/snA/9vvNxCEgw3CYcHcj0aTdqOisQaDOhlVBtAUKkoh3WopSIwAbJp4jG/41vVWBj63eaR7KFJ37OdY1byjlPkaGDUJRcWc/YfUWIDSAToomq2fB4VKpgBk+swVYxLZ709gQvyTCT+3vO/I+yexZpkx6eBl/+YCgQXTeviZ+hTxSoPVayf5vEQoc19ZA4MEkZ7yBInRgk8vUxAJITSf+vOvDIBsElpgkLfSjARCasjh78wONg39AkAoIbKzU+Q2l1htUwXcqQ2b+b5DrY9+Oxae4pBVGQlWU36XAHsa/KG+ejfdwhWJM7FNL3uphwAf0oYAQ=="}}
event: content_block_stop
data: {"type":"content_block_stop","index":0}
"#;
let result = processor.process_chunk(chunk, &client_api, &upstream_api);
match result {
Ok(events) => {
println!("Successfully processed {} events", events.len());
for (i, event) in events.iter().enumerate() {
println!("Event {}: event={:?}, has_data={}", i, event.event, event.data.is_some());
}
// Should successfully process both events (signature_delta + content_block_stop)
// The bound is ">= 2" rather than an exact count, so the test does not depend on
// whether `event:` and `data:` lines are reported as separate events.
assert!(events.len() >= 2, "Should process at least 2 complete events (signature_delta + stop), got {}", events.len());
assert!(!processor.has_buffered_data(), "Complete events should not be buffered");
}
Err(e) => {
panic!("Failed to process signature_delta chunk - this means SignatureDelta is not properly handled: {}", e);
}
}
}
// An event that fails for a reason other than truncated JSON must not stop later events
// in the same chunk from being processed, and must not be placed in the carry-over buffer.
#[test]
fn test_unsupported_event_does_not_block_subsequent_events() {
let mut processor = SseChunkProcessor::new();
let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// Chunk with an unsupported/invalid event followed by a valid event
// First event has invalid JSON structure that will fail validation (not incomplete)
// Second event is valid and should be processed
// NOTE(review): the first event is syntactically valid JSON with an unknown delta field;
// whether it is actually rejected depends on how the chunk type is deserialized. The
// ">= 1" assertion below passes either way — confirm this exercises the skip path.
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"unsupported_field_causing_validation_error\":true},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
// Should skip the invalid event and process the valid one
// (If we were buffering all errors, we'd get 0 events and have buffered data)
assert!(events.len() >= 1, "Should process at least the valid event, got {} events", events.len());
assert!(!processor.has_buffered_data(), "Invalid (non-incomplete) events should not be buffered");
}
// An Anthropic delta of an unknown type, placed between two valid `text_delta` events, must
// be skipped: the call still succeeds, the valid events are returned and nothing is buffered.
#[test]
fn test_unknown_delta_type_skipped_others_processed() {
use crate::apis::anthropic::AnthropicApi;
let mut processor = SseChunkProcessor::new();
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
// Chunk with valid event, unsupported delta type, then another valid event
// This simulates a future API change where Anthropic adds a new delta type we don't support yet
let chunk = br#"event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"future_unsupported_delta","future_field":"some_value"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" World"}}
"#;
let result = processor.process_chunk(chunk, &client_api, &upstream_api);
match result {
Ok(events) => {
println!("Processed {} events (unsupported event should be skipped)", events.len());
// Should process the 2 valid text_delta events and skip the unsupported one
// We expect at least 2 events (the valid ones), unsupported should be skipped
assert!(events.len() >= 2, "Should process at least 2 valid events, got {}", events.len());
assert!(!processor.has_buffered_data(), "Unsupported events should be skipped, not buffered");
}
Err(e) => {
panic!("Should not fail on unsupported delta type, should skip it: {}", e);
}
}
}
}

View file

@ -4,9 +4,10 @@ use std::fmt;
/// Unified enum representing all supported API endpoints across providers
#[derive(Debug, Clone, PartialEq)]
pub enum SupportedAPIs {
pub enum SupportedAPIsFromClient {
OpenAIChatCompletions(OpenAIApi),
AnthropicMessagesAPI(AnthropicApi),
OpenAIResponsesAPI(OpenAIApi),
}
#[derive(Debug, Clone, PartialEq)]
@ -15,17 +16,21 @@ pub enum SupportedUpstreamAPIs {
AnthropicMessagesAPI(AnthropicApi),
AmazonBedrockConverse(AmazonBedrockApi),
AmazonBedrockConverseStream(AmazonBedrockApi),
OpenAIResponsesAPI(OpenAIApi),
}
impl fmt::Display for SupportedAPIs {
impl fmt::Display for SupportedAPIsFromClient {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SupportedAPIs::OpenAIChatCompletions(api) => {
SupportedAPIsFromClient::OpenAIChatCompletions(api) => {
write!(f, "OpenAI ({})", api.endpoint())
}
SupportedAPIs::AnthropicMessagesAPI(api) => {
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => {
write!(f, "Anthropic AI ({})", api.endpoint())
}
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => {
write!(f, "OpenAI Responses ({})", api.endpoint())
}
}
}
}
@ -45,19 +50,27 @@ impl fmt::Display for SupportedUpstreamAPIs {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(api) => {
write!(f, "Amazon Bedrock ({})", api.endpoint())
}
SupportedUpstreamAPIs::OpenAIResponsesAPI(api) => {
write!(f, "OpenAI Responses ({})", api.endpoint())
}
}
}
}
impl SupportedAPIs {
impl SupportedAPIsFromClient {
/// Create a SupportedApi from an endpoint path
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
return Some(SupportedAPIs::OpenAIChatCompletions(openai_api));
// Check if this is the Responses API endpoint
if openai_api == OpenAIApi::Responses {
return Some(SupportedAPIsFromClient::OpenAIResponsesAPI(openai_api));
}
// Otherwise it's ChatCompletions
return Some(SupportedAPIsFromClient::OpenAIChatCompletions(openai_api));
}
if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) {
return Some(SupportedAPIs::AnthropicMessagesAPI(anthropic_api));
return Some(SupportedAPIsFromClient::AnthropicMessagesAPI(anthropic_api));
}
None
@ -66,8 +79,9 @@ impl SupportedAPIs {
/// Get the endpoint path for this API
pub fn endpoint(&self) -> &'static str {
match self {
SupportedAPIs::OpenAIChatCompletions(api) => api.endpoint(),
SupportedAPIs::AnthropicMessagesAPI(api) => api.endpoint(),
SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(),
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(),
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(),
}
}
@ -94,8 +108,62 @@ impl SupportedAPIs {
}
};
// Helper function to route based on provider with a specific endpoint suffix
//
// `endpoint_suffix` is the API path without the version prefix (for example
// "/chat/completions"). For every provider listed below, a request path that does not
// start with "/v1/" falls back to `build_endpoint("/v1", endpoint_suffix)`, which is also
// what all unlisted providers get. Provider-specific prefixes only apply to "/v1/" requests.
let route_by_provider = |endpoint_suffix: &str| -> String {
match provider_id {
ProviderId::Groq => {
if request_path.starts_with("/v1/") {
// NOTE(review): unlike the other branches, this forwards the original
// `request_path` and ignores `endpoint_suffix`. Confirm that is intended for
// callers whose suffix differs from the incoming request path.
build_endpoint("/openai", request_path)
} else {
build_endpoint("/v1", endpoint_suffix)
}
}
ProviderId::Zhipu => {
if request_path.starts_with("/v1/") {
build_endpoint("/api/paas/v4", endpoint_suffix)
} else {
build_endpoint("/v1", endpoint_suffix)
}
}
ProviderId::Qwen => {
if request_path.starts_with("/v1/") {
build_endpoint("/compatible-mode/v1", endpoint_suffix)
} else {
build_endpoint("/v1", endpoint_suffix)
}
}
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
// Azure addresses a deployment by model id; the suffix is appended without its
// leading slash and the API version is pinned via the query string.
let suffix = endpoint_suffix.trim_start_matches('/');
build_endpoint("/openai/deployments", &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix))
} else {
build_endpoint("/v1", endpoint_suffix)
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
build_endpoint("/v1beta/openai", endpoint_suffix)
} else {
build_endpoint("/v1", endpoint_suffix)
}
}
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") {
// Bedrock ignores the suffix entirely: the path is derived from the model id and
// from whether the request is streaming.
if !is_streaming {
build_endpoint("", &format!("/model/{}/converse", model_id))
} else {
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
}
} else {
build_endpoint("/v1", endpoint_suffix)
}
}
_ => build_endpoint("/v1", endpoint_suffix),
}
};
match self {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") && !is_streaming {
@ -108,59 +176,57 @@ impl SupportedAPIs {
}
_ => build_endpoint("/v1", "/chat/completions"),
},
_ => match provider_id {
ProviderId::Groq => {
if request_path.starts_with("/v1/") {
build_endpoint("/openai", request_path)
} else {
build_endpoint("/v1", "/chat/completions")
}
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
// For Responses API, check if provider supports it, otherwise translate to chat/completions
match provider_id {
// OpenAI and compatible providers that support /v1/responses
ProviderId::OpenAI => route_by_provider("/responses"),
// All other providers: translate to /chat/completions
_ => route_by_provider("/chat/completions"),
}
ProviderId::Zhipu => {
if request_path.starts_with("/v1/") {
build_endpoint("/api/paas/v4", "/chat/completions")
} else {
build_endpoint("/v1", "/chat/completions")
}
}
ProviderId::Qwen => {
if request_path.starts_with("/v1/") {
build_endpoint("/compatible-mode/v1", "/chat/completions")
} else {
build_endpoint("/v1", "/chat/completions")
}
}
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
build_endpoint("/openai/deployments", &format!("/{}/chat/completions?api-version=2025-01-01-preview", model_id))
} else {
build_endpoint("/v1", "/chat/completions")
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
build_endpoint("/v1beta/openai", "/chat/completions")
} else {
build_endpoint("/v1", "/chat/completions")
}
}
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") {
if !is_streaming {
build_endpoint("", &format!("/model/{}/converse", model_id))
} else {
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
}
} else {
build_endpoint("/v1", "/chat/completions")
}
}
_ => build_endpoint("/v1", "/chat/completions"),
},
}
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
// For Chat Completions API, use the standard chat/completions path
route_by_provider("/chat/completions")
}
}
}
}
impl SupportedUpstreamAPIs {
    /// Resolve an endpoint path to the upstream API it belongs to.
    ///
    /// The provider APIs are tried in a fixed order — OpenAI, then Anthropic, then
    /// Amazon Bedrock — and the first one that recognises `endpoint` wins. Returns
    /// `None` when none of them recognises the path.
    pub fn from_endpoint(endpoint: &str) -> Option<Self> {
        // OpenAI: the Responses endpoint gets its own variant; every other OpenAI
        // endpoint is treated as ChatCompletions.
        if let Some(api) = OpenAIApi::from_endpoint(endpoint) {
            let resolved = if api == OpenAIApi::Responses {
                Self::OpenAIResponsesAPI(api)
            } else {
                Self::OpenAIChatCompletions(api)
            };
            return Some(resolved);
        }

        // Anthropic, falling through to Bedrock (evaluated lazily) if it does not match.
        AnthropicApi::from_endpoint(endpoint)
            .map(Self::AnthropicMessagesAPI)
            .or_else(|| {
                AmazonBedrockApi::from_endpoint(endpoint).map(|api| match api {
                    AmazonBedrockApi::Converse => Self::AmazonBedrockConverse(api),
                    AmazonBedrockApi::ConverseStream => Self::AmazonBedrockConverseStream(api),
                })
            })
    }
}
/// Get all supported endpoint paths
pub fn supported_endpoints() -> Vec<&'static str> {
let mut endpoints = Vec::new();
@ -198,22 +264,23 @@ mod tests {
#[test]
fn test_is_supported_endpoint() {
// OpenAI endpoints
assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
assert!(SupportedAPIsFromClient::from_endpoint("/v1/chat/completions").is_some());
// Anthropic endpoints
assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some());
// Unsupported endpoints
assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some());
assert!(!SupportedAPIs::from_endpoint("").is_some());
assert!(!SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_some());
assert!(!SupportedAPIsFromClient::from_endpoint("/v2/chat").is_some());
assert!(!SupportedAPIsFromClient::from_endpoint("").is_some());
}
#[test]
fn test_supported_endpoints() {
let endpoints = supported_endpoints();
assert_eq!(endpoints.len(), 2); // We have 2 APIs defined
assert_eq!(endpoints.len(), 3); // We have 3 APIs defined
assert!(endpoints.contains(&"/v1/chat/completions"));
assert!(endpoints.contains(&"/v1/messages"));
assert!(endpoints.contains(&"/v1/responses"));
}
#[test]
@ -263,7 +330,7 @@ mod tests {
#[test]
fn test_target_endpoint_without_base_url_prefix() {
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// Test default OpenAI provider
assert_eq!(
@ -340,7 +407,7 @@ mod tests {
#[test]
fn test_target_endpoint_with_base_url_prefix() {
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// Test Zhipu with custom base_url_path_prefix
assert_eq!(
@ -405,7 +472,7 @@ mod tests {
#[test]
fn test_target_endpoint_with_empty_base_url_prefix() {
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// Test with just slashes - trims to empty, uses provider default
assert_eq!(
@ -434,7 +501,7 @@ mod tests {
#[test]
fn test_amazon_bedrock_endpoints() {
let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
// Test Bedrock non-streaming without prefix
assert_eq!(
@ -487,7 +554,7 @@ mod tests {
#[test]
fn test_anthropic_messages_endpoint() {
let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
// Test Anthropic without prefix
assert_eq!(
@ -516,7 +583,7 @@ mod tests {
#[test]
fn test_non_v1_request_paths() {
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// Test Groq with non-v1 path (should use default)
assert_eq!(
@ -557,7 +624,7 @@ mod tests {
#[test]
fn test_azure_openai_with_query_params() {
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
// Test Azure without prefix - should include query params
assert_eq!(

View file

@ -1,9 +1,8 @@
pub mod endpoints;
pub mod lib;
pub mod transformer;
// Re-export the main items for easier access
pub use endpoints::{identify_provider, SupportedAPIs};
pub use endpoints::*;
pub use lib::*;
// Note: transformer module contains TryFrom trait implementations that are automatically available

View file

@ -1,694 +0,0 @@
// Re-export new transformation modules for backward compatibility
// Keeping these tests to make sure the refactoring didn't break anything.
// ============================================================================
// TESTS
// ============================================================================
#[cfg(test)]
mod tests {
use crate::apis::anthropic::*;
use crate::apis::openai::*;
use crate::transforms::*;
use serde_json::json;
type AnthropicMessagesRequest = MessagesRequest;
/// A fully populated Anthropic Messages request converts to an OpenAI Chat
/// Completions request that keeps the model, token limit, sampling
/// parameters, streaming flag and stop sequences.
#[test]
fn test_anthropic_to_openai_basic_request() {
    let user_turn = MessagesMessage {
        role: MessagesRole::User,
        content: MessagesMessageContent::Single("Hello, world!".to_string()),
    };
    let source = AnthropicMessagesRequest {
        model: "claude-3-sonnet-20240229".to_string(),
        system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())),
        messages: vec![user_turn],
        max_tokens: 1024,
        container: None,
        mcp_servers: None,
        service_tier: None,
        thinking: None,
        temperature: Some(0.7),
        top_p: Some(0.9),
        top_k: Some(50),
        stream: Some(false),
        stop_sequences: Some(vec!["STOP".to_string()]),
        tools: None,
        tool_choice: None,
        metadata: None,
    };

    let converted: ChatCompletionsRequest = source.try_into().unwrap();

    assert_eq!(converted.model, "claude-3-sonnet-20240229");
    // The system prompt is counted alongside the single user message.
    assert_eq!(converted.messages.len(), 2);
    assert_eq!(converted.max_completion_tokens, Some(1024));
    assert_eq!(converted.temperature, Some(0.7));
    assert_eq!(converted.top_p, Some(0.9));
    assert_eq!(converted.stream, Some(false));
    assert_eq!(converted.stop, Some(vec!["STOP".to_string()]));
}
/// Anthropic -> OpenAI -> Anthropic must leave the key request fields
/// (model, token limit, sampling parameters, stream flag, message count)
/// unchanged.
#[test]
fn test_roundtrip_consistency() {
    let before = AnthropicMessagesRequest {
        model: "claude-3-sonnet".to_string(),
        system: Some(MessagesSystemPrompt::Single("System prompt".to_string())),
        messages: vec![MessagesMessage {
            role: MessagesRole::User,
            content: MessagesMessageContent::Single("User message".to_string()),
        }],
        max_tokens: 1000,
        container: None,
        mcp_servers: None,
        service_tier: None,
        thinking: None,
        temperature: Some(0.5),
        top_p: Some(1.0),
        top_k: None,
        stream: Some(false),
        stop_sequences: None,
        tools: None,
        tool_choice: None,
        metadata: None,
    };

    // Forward conversion works on a clone so `before` stays available for
    // the comparisons below.
    let as_openai: ChatCompletionsRequest = before.clone().try_into().unwrap();
    let after: AnthropicMessagesRequest = as_openai.try_into().unwrap();

    assert_eq!(before.model, after.model);
    assert_eq!(before.max_tokens, after.max_tokens);
    assert_eq!(before.temperature, after.temperature);
    assert_eq!(before.top_p, after.top_p);
    assert_eq!(before.stream, after.stream);
    assert_eq!(before.messages.len(), after.messages.len());
}
/// An Anthropic request with one tool, `Auto` tool choice and parallel tool
/// use disabled converts to an OpenAI request with one tool, `Auto` tool
/// choice and `parallel_tool_calls == Some(false)`.
#[test]
fn test_tool_choice_auto() {
    let only_tool = MessagesTool {
        name: "test_tool".to_string(),
        description: Some("A test tool".to_string()),
        input_schema: json!({"type": "object"}),
    };
    let auto_choice = MessagesToolChoice {
        kind: MessagesToolChoiceType::Auto,
        name: None,
        disable_parallel_tool_use: Some(true),
    };
    let source = AnthropicMessagesRequest {
        model: "claude-3".to_string(),
        system: None,
        messages: vec![],
        max_tokens: 100,
        container: None,
        mcp_servers: None,
        service_tier: None,
        thinking: None,
        temperature: None,
        top_p: None,
        top_k: None,
        stream: None,
        stop_sequences: None,
        tools: Some(vec![only_tool]),
        tool_choice: Some(auto_choice),
        metadata: None,
    };

    let converted: ChatCompletionsRequest = source.try_into().unwrap();

    assert!(converted.tools.is_some());
    assert_eq!(converted.tools.as_ref().unwrap().len(), 1);

    match converted.tool_choice {
        Some(ToolChoice::Type(choice)) => assert_eq!(choice, ToolChoiceType::Auto),
        _ => panic!("Expected auto tool choice"),
    }

    // `disable_parallel_tool_use: Some(true)` is expected to come out inverted.
    assert_eq!(converted.parallel_tool_calls, Some(false));
}
/// When the OpenAI request carries no `max_tokens`, the converted Anthropic
/// request must use `DEFAULT_MAX_TOKENS`.
#[test]
fn test_default_max_tokens_used_when_openai_has_none() {
    let greeting = Message {
        role: Role::User,
        content: MessageContent::Text("Hello".to_string()),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    };
    let source = ChatCompletionsRequest {
        model: "gpt-4".to_string(),
        messages: vec![greeting],
        // Deliberately left unset so the default has to kick in.
        max_tokens: None,
        ..Default::default()
    };

    let converted: AnthropicMessagesRequest = source.try_into().unwrap();

    assert_eq!(converted.max_tokens, DEFAULT_MAX_TOKENS);
}
/// A `MessageStart` event becomes a single-choice OpenAI chunk that carries
/// the message id and model and announces the assistant role with no content.
#[test]
fn test_anthropic_message_start_streaming() {
    let initial_usage = MessagesUsage {
        input_tokens: 5,
        output_tokens: 0,
        cache_creation_input_tokens: None,
        cache_read_input_tokens: None,
    };
    let started = MessagesStreamMessage {
        id: "msg_stream_123".to_string(),
        obj_type: "message".to_string(),
        role: MessagesRole::Assistant,
        content: vec![],
        model: "claude-3".to_string(),
        stop_reason: None,
        stop_sequence: None,
        usage: initial_usage,
    };
    let event = MessagesStreamEvent::MessageStart { message: started };

    let chunk: ChatCompletionsStreamResponse = event.try_into().unwrap();

    assert_eq!(chunk.id, "msg_stream_123");
    assert_eq!(chunk.object.as_deref(), Some("chat.completion.chunk"));
    assert_eq!(chunk.model, "claude-3");
    assert_eq!(chunk.choices.len(), 1);

    let first = &chunk.choices[0];
    assert_eq!(first.index, 0);
    assert_eq!(first.delta.role, Some(Role::Assistant));
    assert_eq!(first.delta.content, None);
    assert_eq!(first.finish_reason, None);
}
/// A text `ContentBlockDelta` becomes an OpenAI chunk whose only choice
/// carries the text as content, with no role and no finish reason.
#[test]
fn test_anthropic_content_block_delta_streaming() {
    let text_delta = MessagesContentDelta::TextDelta {
        text: "Hello, world!".to_string(),
    };
    let event = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: text_delta,
    };

    let chunk: ChatCompletionsStreamResponse = event.try_into().unwrap();

    assert_eq!(chunk.object.as_deref(), Some("chat.completion.chunk"));
    assert_eq!(chunk.choices.len(), 1);

    let first = &chunk.choices[0];
    assert_eq!(first.index, 0);
    assert_eq!(first.delta.content, Some("Hello, world!".to_string()));
    assert_eq!(first.delta.role, None);
    assert_eq!(first.finish_reason, None);
}
/// A `ContentBlockStart` holding a `ToolUse` block becomes an OpenAI chunk
/// with exactly one tool-call delta carrying the tool id and function name.
#[test]
fn test_anthropic_tool_use_streaming() {
    let tool_block = MessagesContentBlock::ToolUse {
        id: "call_123".to_string(),
        name: "get_weather".to_string(),
        input: json!({}),
        cache_control: None,
    };
    let start_event = MessagesStreamEvent::ContentBlockStart {
        index: 0,
        content_block: tool_block,
    };

    let chunk: ChatCompletionsStreamResponse = start_event.try_into().unwrap();
    assert_eq!(chunk.choices.len(), 1);

    let first = &chunk.choices[0];
    assert!(first.delta.tool_calls.is_some());

    let calls = first.delta.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].id, Some("call_123".to_string()));
    assert_eq!(
        calls[0].function.as_ref().unwrap().name,
        Some("get_weather".to_string())
    );
}
/// An `InputJsonDelta` is forwarded verbatim as the `arguments` fragment of
/// a single OpenAI tool-call delta, even when the JSON is still incomplete.
#[test]
fn test_anthropic_tool_input_delta_streaming() {
    let json_fragment = MessagesContentDelta::InputJsonDelta {
        partial_json: r#"{"location": "San Francisco"#.to_string(),
    };
    let event = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: json_fragment,
    };

    let chunk: ChatCompletionsStreamResponse = event.try_into().unwrap();
    assert_eq!(chunk.choices.len(), 1);

    let first = &chunk.choices[0];
    assert!(first.delta.tool_calls.is_some());

    let calls = first.delta.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    assert_eq!(
        calls[0].function.as_ref().unwrap().arguments,
        Some(r#"{"location": "San Francisco"#.to_string())
    );
}
/// A `MessageDelta` with `EndTurn` and usage counters becomes an OpenAI
/// chunk with finish reason `Stop` and prompt/completion/total token counts.
#[test]
fn test_anthropic_message_delta_with_usage() {
    let stop_info = MessagesMessageDelta {
        stop_reason: MessagesStopReason::EndTurn,
        stop_sequence: None,
    };
    let counters = MessagesUsage {
        input_tokens: 10,
        output_tokens: 25,
        cache_creation_input_tokens: None,
        cache_read_input_tokens: None,
    };
    let event = MessagesStreamEvent::MessageDelta {
        delta: stop_info,
        usage: counters,
    };

    let chunk: ChatCompletionsStreamResponse = event.try_into().unwrap();
    assert_eq!(chunk.choices.len(), 1);

    let first = &chunk.choices[0];
    assert_eq!(first.finish_reason, Some(FinishReason::Stop));

    assert!(chunk.usage.is_some());
    let reported = chunk.usage.unwrap();
    assert_eq!(reported.prompt_tokens, 10);
    assert_eq!(reported.completion_tokens, 25);
    // The total is expected to be the sum of input and output tokens.
    assert_eq!(reported.total_tokens, 35);
}
/// `MessageStop` converts to an OpenAI chunk with one choice whose finish
/// reason is `Stop`.
#[test]
fn test_anthropic_message_stop_streaming() {
    let chunk: ChatCompletionsStreamResponse =
        MessagesStreamEvent::MessageStop.try_into().unwrap();

    assert_eq!(chunk.choices.len(), 1);
    let first = &chunk.choices[0];
    assert_eq!(first.finish_reason, Some(FinishReason::Stop));
}
/// `Ping` converts to an OpenAI chunk object that has no choices at all.
#[test]
fn test_anthropic_ping_streaming() {
    let chunk: ChatCompletionsStreamResponse = MessagesStreamEvent::Ping.try_into().unwrap();

    assert_eq!(chunk.object.as_deref(), Some("chat.completion.chunk"));
    // A ping carries no payload, hence zero choices.
    assert_eq!(chunk.choices.len(), 0);
}
/// An OpenAI chunk whose delta only announces the assistant role converts to
/// an Anthropic `MessageStart` carrying the chunk id, role and model.
#[test]
fn test_openai_to_anthropic_streaming_role_start() {
    let role_only = MessageDelta {
        role: Some(Role::Assistant),
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    let chunk = ChatCompletionsStreamResponse {
        id: "chatcmpl-123".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "gpt-4".to_string(),
        choices: vec![StreamChoice {
            index: 0,
            delta: role_only,
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let event: MessagesStreamEvent = chunk.try_into().unwrap();

    if let MessagesStreamEvent::MessageStart { message } = event {
        assert_eq!(message.id, "chatcmpl-123");
        assert_eq!(message.role, MessagesRole::Assistant);
        assert_eq!(message.model, "gpt-4");
    } else {
        panic!("Expected MessageStart event");
    }
}
/// An OpenAI chunk carrying only text content converts to an Anthropic
/// `ContentBlockDelta` at index 0 holding a `TextDelta` with the same text.
#[test]
fn test_openai_to_anthropic_streaming_content_delta() {
    let text_only = MessageDelta {
        role: None,
        content: Some("Hello there!".to_string()),
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    let chunk = ChatCompletionsStreamResponse {
        id: "chatcmpl-123".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "gpt-4".to_string(),
        choices: vec![StreamChoice {
            index: 0,
            delta: text_only,
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let event: MessagesStreamEvent = chunk.try_into().unwrap();

    match event {
        MessagesStreamEvent::ContentBlockDelta { index, delta } => {
            assert_eq!(index, 0);
            if let MessagesContentDelta::TextDelta { text } = delta {
                assert_eq!(text, "Hello there!");
            } else {
                panic!("Expected TextDelta");
            }
        }
        _ => panic!("Expected ContentBlockDelta event"),
    }
}
/// An OpenAI chunk that opens a tool call (id + function name, empty
/// arguments) converts to an Anthropic `ContentBlockStart` with a `ToolUse`
/// block carrying the same id and name.
#[test]
fn test_openai_to_anthropic_streaming_tool_calls() {
    let opening_call = ToolCallDelta {
        index: 0,
        id: Some("call_abc123".to_string()),
        call_type: Some("function".to_string()),
        function: Some(FunctionCallDelta {
            name: Some("get_current_weather".to_string()),
            arguments: Some("".to_string()),
        }),
    };
    let tool_delta = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: Some(vec![opening_call]),
    };
    let chunk = ChatCompletionsStreamResponse {
        id: "chatcmpl-123".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "gpt-4".to_string(),
        choices: vec![StreamChoice {
            index: 0,
            delta: tool_delta,
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let event: MessagesStreamEvent = chunk.try_into().unwrap();

    match event {
        MessagesStreamEvent::ContentBlockStart {
            index,
            content_block,
        } => {
            assert_eq!(index, 0);
            if let MessagesContentBlock::ToolUse { id, name, .. } = content_block {
                assert_eq!(id, "call_abc123");
                assert_eq!(name, "get_current_weather");
            } else {
                panic!("Expected ToolUse content block");
            }
        }
        _ => panic!("Expected ContentBlockStart event"),
    }
}
/// The final OpenAI chunk (finish reason `Stop` plus usage) converts to an
/// Anthropic `MessageDelta` with stop reason `EndTurn` and matching
/// input/output token counts.
#[test]
fn test_openai_to_anthropic_streaming_final_usage() {
    let empty_delta = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    let totals = Usage {
        prompt_tokens: 15,
        completion_tokens: 30,
        total_tokens: 45,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    };
    let chunk = ChatCompletionsStreamResponse {
        id: "chatcmpl-123".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "gpt-4".to_string(),
        choices: vec![StreamChoice {
            index: 0,
            delta: empty_delta,
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        usage: Some(totals),
        system_fingerprint: None,
        service_tier: None,
    };

    let event: MessagesStreamEvent = chunk.try_into().unwrap();

    if let MessagesStreamEvent::MessageDelta { delta, usage } = event {
        assert_eq!(delta.stop_reason, MessagesStopReason::EndTurn);
        assert_eq!(usage.input_tokens, 15);
        assert_eq!(usage.output_tokens, 30);
    } else {
        panic!("Expected MessageDelta event");
    }
}
/// An OpenAI chunk with an empty `choices` list converts to an Anthropic
/// `Ping` event.
#[test]
fn test_openai_empty_choices_to_anthropic_ping() {
    let chunk = ChatCompletionsStreamResponse {
        id: "chatcmpl-123".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "gpt-4".to_string(),
        // Nothing to convert: no choices at all.
        choices: vec![],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let event: MessagesStreamEvent = chunk.try_into().unwrap();

    if !matches!(event, MessagesStreamEvent::Ping) {
        panic!("Expected Ping event for empty choices");
    }
}
/// A text `ContentBlockDelta` must survive Anthropic -> OpenAI -> Anthropic
/// with its index and text intact.
#[test]
fn test_streaming_roundtrip_consistency() {
    let source_event = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::TextDelta {
            text: "Test message".to_string(),
        },
    };

    // There and back again.
    let as_openai: ChatCompletionsStreamResponse = source_event.try_into().unwrap();
    let back: MessagesStreamEvent = as_openai.try_into().unwrap();

    match back {
        MessagesStreamEvent::ContentBlockDelta { index, delta } => {
            assert_eq!(index, 0);
            if let MessagesContentDelta::TextDelta { text } = delta {
                assert_eq!(text, "Test message");
            } else {
                panic!("Expected TextDelta after roundtrip");
            }
        }
        _ => panic!("Expected ContentBlockDelta after roundtrip"),
    }
}
/// A tool-use start followed by two JSON argument fragments: each event is
/// converted on its own and must expose the tool id/name and the exact
/// argument fragment it carried.
#[test]
fn test_streaming_tool_argument_accumulation() {
    /// Clone the `arguments` fragment of the first tool call in the first choice.
    fn first_arguments(chunk: &ChatCompletionsStreamResponse) -> Option<String> {
        chunk.choices[0].delta.tool_calls.as_ref().unwrap()[0]
            .function
            .as_ref()
            .unwrap()
            .arguments
            .clone()
    }

    let start_event = MessagesStreamEvent::ContentBlockStart {
        index: 0,
        content_block: MessagesContentBlock::ToolUse {
            id: "call_weather".to_string(),
            name: "get_weather".to_string(),
            input: json!({}),
            cache_control: None,
        },
    };
    let first_fragment = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::InputJsonDelta {
            partial_json: r#"{"location": "#.to_string(),
        },
    };
    let second_fragment = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::InputJsonDelta {
            partial_json: r#"San Francisco", "unit": "fahrenheit"}"#.to_string(),
        },
    };

    // Every event is converted independently, in stream order.
    let start_chunk: ChatCompletionsStreamResponse = start_event.try_into().unwrap();
    let first_chunk: ChatCompletionsStreamResponse = first_fragment.try_into().unwrap();
    let second_chunk: ChatCompletionsStreamResponse = second_fragment.try_into().unwrap();

    // Tool start: id and function name.
    let calls = start_chunk.choices[0].delta.tool_calls.as_ref().unwrap();
    assert_eq!(calls[0].id, Some("call_weather".to_string()));
    assert_eq!(
        calls[0].function.as_ref().unwrap().name,
        Some("get_weather".to_string())
    );

    // Argument fragments pass through unchanged.
    assert_eq!(
        first_arguments(&first_chunk),
        Some(r#"{"location": "#.to_string())
    );
    assert_eq!(
        first_arguments(&second_chunk),
        Some(r#"San Francisco", "unit": "fahrenheit"}"#.to_string())
    );
}
/// Each Anthropic stop reason must map to the expected OpenAI finish reason,
/// and converting back must yield a `MessageDelta` with one of the four
/// known stop reasons.
#[test]
fn test_streaming_multiple_finish_reasons() {
    let expectations = vec![
        (MessagesStopReason::EndTurn, FinishReason::Stop),
        (MessagesStopReason::MaxTokens, FinishReason::Length),
        (MessagesStopReason::ToolUse, FinishReason::ToolCalls),
        (MessagesStopReason::StopSequence, FinishReason::Stop),
    ];

    for (stop_reason, finish_reason) in expectations {
        let stop_info = MessagesMessageDelta {
            stop_reason: stop_reason.clone(),
            stop_sequence: None,
        };
        let counters = MessagesUsage {
            input_tokens: 10,
            output_tokens: 20,
            cache_creation_input_tokens: None,
            cache_read_input_tokens: None,
        };
        let event = MessagesStreamEvent::MessageDelta {
            delta: stop_info,
            usage: counters,
        };

        let chunk: ChatCompletionsStreamResponse = event.try_into().unwrap();
        assert_eq!(chunk.choices[0].finish_reason, Some(finish_reason));

        // Reverse direction. The mapping is not one-to-one (two stop reasons
        // share `Stop`), so only membership in the known set is checked.
        let back: MessagesStreamEvent = chunk.try_into().unwrap();
        if let MessagesStreamEvent::MessageDelta { delta, .. } = back {
            assert!(matches!(
                delta.stop_reason,
                MessagesStopReason::EndTurn
                    | MessagesStopReason::MaxTokens
                    | MessagesStopReason::ToolUse
                    | MessagesStopReason::StopSequence
            ));
        } else {
            panic!("Expected MessageDelta after roundtrip");
        }
    }
}
/// An OpenAI chunk whose single choice has an entirely empty delta and no
/// finish reason must convert to an Anthropic `Ping` rather than fail.
#[test]
fn test_streaming_error_handling() {
    let nothing = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    let hollow_chunk = ChatCompletionsStreamResponse {
        id: "test".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "test".to_string(),
        choices: vec![StreamChoice {
            index: 0,
            delta: nothing,
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    // No meaningful content, so a Ping is expected.
    let event: MessagesStreamEvent = hollow_chunk.try_into().unwrap();
    assert!(matches!(event, MessagesStreamEvent::Ping));
}
/// `ContentBlockStop` converts to an OpenAI chunk with one choice whose
/// delta is empty and which has no finish reason.
#[test]
fn test_streaming_content_block_stop() {
    let stop_event = MessagesStreamEvent::ContentBlockStop { index: 0 };

    let chunk: ChatCompletionsStreamResponse = stop_event.try_into().unwrap();

    assert_eq!(chunk.object.as_deref(), Some("chat.completion.chunk"));
    assert_eq!(chunk.choices.len(), 1);

    // Every payload field of the single choice is expected to be empty.
    let first = &chunk.choices[0];
    assert_eq!(first.delta.role, None);
    assert_eq!(first.delta.content, None);
    assert_eq!(first.delta.tool_calls, None);
    assert_eq!(first.finish_reason, None);
}
}

View file

@ -6,18 +6,21 @@ pub mod clients;
pub mod providers;
pub mod transforms;
// Re-export important types and traits
pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
pub use apis::sse::{SseEvent, SseStreamIter};
pub use apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
pub use apis::streaming_shapes::sse::{SseEvent, SseStreamIter};
pub use aws_smithy_eventstream::frame::DecodedFrame;
pub use providers::id::ProviderId;
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use providers::response::{
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
ProviderStreamResponseType, TokenUsage,
ProviderResponse, ProviderResponseType, TokenUsage, ProviderResponseError
};
pub use providers::streaming_response::{
ProviderStreamResponse, ProviderStreamResponseType
};
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
pub const MESSAGES_PATH: &str = "/v1/messages";
#[cfg(test)]
@ -42,9 +45,9 @@ mod tests {
data: [DONE]
"#;
use crate::clients::endpoints::SupportedAPIs;
use crate::clients::endpoints::SupportedAPIsFromClient;
let client_api =
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
SupportedAPIsFromClient::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
let upstream_api =
SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
@ -79,9 +82,16 @@ mod tests {
assert_eq!(stream_response.content_delta(), Some("Hello"));
assert!(!stream_response.is_final());
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
// Test that stream ends properly with [DONE]
// The iterator should return the [DONE] event, then None
let done_event = streaming_iter.next();
assert!(done_event.is_some(), "Should get [DONE] event");
let done_event = done_event.unwrap();
assert!(done_event.is_done(), "[DONE] event should be marked as done");
// After [DONE], iterator should return None
let final_event = streaming_iter.next();
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
assert!(final_event.is_none(), "Iterator should return None after [DONE]");
}
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.

View file

@ -1,5 +1,5 @@
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs};
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use std::fmt::Display;
/// Provider identifier enum - simple enum for identifying providers
@ -51,19 +51,24 @@ impl ProviderId {
/// Given a client API, return the compatible upstream API for this provider
pub fn compatible_api_for_client(
&self,
client_api: &SupportedAPIs,
client_api: &SupportedAPIsFromClient,
is_streaming: bool,
) -> SupportedUpstreamAPIs {
match (self, client_api) {
// Claude/Anthropic providers natively support Anthropic APIs
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
}
(
ProviderId::Anthropic,
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
SupportedAPIsFromClient::OpenAIChatCompletions(_),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// Anthropic doesn't support Responses API, fall back to chat completions
(ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
}
// OpenAI-compatible providers only support OpenAI chat completions
(
ProviderId::OpenAI
@ -80,7 +85,7 @@ impl ProviderId {
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::AnthropicMessagesAPI(_),
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(
@ -98,11 +103,16 @@ impl ProviderId {
| ProviderId::Moonshotai
| ProviderId::Zhipu
| ProviderId::Qwen,
SupportedAPIs::OpenAIChatCompletions(_),
SupportedAPIsFromClient::OpenAIChatCompletions(_),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// OpenAI Responses API - only OpenAI supports this
(ProviderId::OpenAI, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses)
}
// Amazon Bedrock natively supports Bedrock APIs
(ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => {
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
@ -111,7 +121,7 @@ impl ProviderId {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
}
}
(ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => {
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
@ -120,6 +130,20 @@ impl ProviderId {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
}
}
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
if is_streaming {
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
AmazonBedrockApi::ConverseStream,
)
} else {
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
}
}
// Non-OpenAI providers: if client requested the Responses API, fall back to Chat Completions
(_, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
}
}
}
}

View file

@ -6,7 +6,9 @@
pub mod id;
pub mod request;
pub mod response;
pub mod streaming_response;
pub use id::ProviderId;
pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage};
pub use response::{ProviderResponse, ProviderResponseType, TokenUsage};
pub use streaming_response::{ProviderStreamResponse, ProviderStreamResponseType};

View file

@ -2,19 +2,21 @@ use crate::apis::anthropic::MessagesRequest;
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::openai_responses::ResponsesAPIRequest;
use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use serde_json::Value;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
#[derive(Clone)]
#[derive(Clone, Debug)]
/// A request in one of the supported wire formats.
///
/// Each variant wraps the request struct of one API format, so the rest of
/// the code can dispatch on the variant instead of on the concrete type.
pub enum ProviderRequestType {
/// OpenAI chat completions request (`crate::apis::openai`).
ChatCompletionsRequest(ChatCompletionsRequest),
/// Anthropic messages request (`crate::apis::anthropic`).
MessagesRequest(MessagesRequest),
/// Amazon Bedrock Converse request (`crate::apis::amazon_bedrock`).
BedrockConverse(ConverseRequest),
/// Amazon Bedrock ConverseStream request (`crate::apis::amazon_bedrock`).
BedrockConverseStream(ConverseStreamRequest),
/// OpenAI Responses API request (`crate::apis::openai_responses`).
ResponsesAPIRequest(ResponsesAPIRequest),
// Add more request types here.
}
pub trait ProviderRequest: Send + Sync {
@ -33,6 +35,9 @@ pub trait ProviderRequest: Send + Sync {
/// Extract the user message for tracing/logging purposes
fn get_recent_user_message(&self) -> Option<String>;
/// Get tool names if tools are defined in the request
fn get_tool_names(&self) -> Option<Vec<String>>;
/// Convert the request to bytes for transmission
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
@ -40,6 +45,30 @@ pub trait ProviderRequest: Send + Sync {
/// Remove a metadata key from the request and return true if the key was present
fn remove_metadata_key(&mut self, key: &str) -> bool;
fn get_temperature(&self) -> Option<f32>;
/// Get message history as OpenAI Message format
/// This is useful for processing chat history across different provider formats
fn get_messages(&self) -> Vec<crate::apis::openai::Message>;
/// Set message history from OpenAI Message format
/// This converts OpenAI messages to the appropriate format for each provider type
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]);
}
impl ProviderRequestType {
/// Set message history from OpenAI Message format
/// This converts OpenAI messages to the appropriate format for each provider type
pub fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
match self {
Self::ChatCompletionsRequest(r) => r.set_messages(messages),
Self::MessagesRequest(r) => r.set_messages(messages),
Self::BedrockConverse(r) => r.set_messages(messages),
Self::BedrockConverseStream(r) => r.set_messages(messages),
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
}
}
}
impl ProviderRequest for ProviderRequestType {
@ -49,6 +78,7 @@ impl ProviderRequest for ProviderRequestType {
Self::MessagesRequest(r) => r.model(),
Self::BedrockConverse(r) => r.model(),
Self::BedrockConverseStream(r) => r.model(),
Self::ResponsesAPIRequest(r) => r.model(),
}
}
@ -58,6 +88,7 @@ impl ProviderRequest for ProviderRequestType {
Self::MessagesRequest(r) => r.set_model(model),
Self::BedrockConverse(r) => r.set_model(model),
Self::BedrockConverseStream(r) => r.set_model(model),
Self::ResponsesAPIRequest(r) => r.set_model(model),
}
}
@ -67,6 +98,7 @@ impl ProviderRequest for ProviderRequestType {
Self::MessagesRequest(r) => r.is_streaming(),
Self::BedrockConverse(_) => false,
Self::BedrockConverseStream(_) => true,
Self::ResponsesAPIRequest(r) => r.is_streaming(),
}
}
@ -76,6 +108,7 @@ impl ProviderRequest for ProviderRequestType {
Self::MessagesRequest(r) => r.extract_messages_text(),
Self::BedrockConverse(r) => r.extract_messages_text(),
Self::BedrockConverseStream(r) => r.extract_messages_text(),
Self::ResponsesAPIRequest(r) => r.extract_messages_text(),
}
}
@ -85,6 +118,17 @@ impl ProviderRequest for ProviderRequestType {
Self::MessagesRequest(r) => r.get_recent_user_message(),
Self::BedrockConverse(r) => r.get_recent_user_message(),
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
Self::ResponsesAPIRequest(r) => r.get_recent_user_message(),
}
}
fn get_tool_names(&self) -> Option<Vec<String>> {
match self {
Self::ChatCompletionsRequest(r) => r.get_tool_names(),
Self::MessagesRequest(r) => r.get_tool_names(),
Self::BedrockConverse(r) => r.get_tool_names(),
Self::BedrockConverseStream(r) => r.get_tool_names(),
Self::ResponsesAPIRequest(r) => r.get_tool_names(),
}
}
@ -94,6 +138,7 @@ impl ProviderRequest for ProviderRequestType {
Self::MessagesRequest(r) => r.to_bytes(),
Self::BedrockConverse(r) => r.to_bytes(),
Self::BedrockConverseStream(r) => r.to_bytes(),
Self::ResponsesAPIRequest(r) => r.to_bytes(),
}
}
@ -103,6 +148,7 @@ impl ProviderRequest for ProviderRequestType {
Self::MessagesRequest(r) => r.metadata(),
Self::BedrockConverse(r) => r.metadata(),
Self::BedrockConverseStream(r) => r.metadata(),
Self::ResponsesAPIRequest(r) => r.metadata(),
}
}
@ -112,18 +158,49 @@ impl ProviderRequest for ProviderRequestType {
Self::MessagesRequest(r) => r.remove_metadata_key(key),
Self::BedrockConverse(r) => r.remove_metadata_key(key),
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key),
}
}
/// Returns the sampling temperature reported by the wrapped request, if any.
///
/// Pure dispatch: every variant forwards to its inner request's
/// `get_temperature`.
fn get_temperature(&self) -> Option<f32> {
match self {
Self::ChatCompletionsRequest(r) => r.get_temperature(),
Self::MessagesRequest(r) => r.get_temperature(),
Self::BedrockConverse(r) => r.get_temperature(),
Self::BedrockConverseStream(r) => r.get_temperature(),
Self::ResponsesAPIRequest(r) => r.get_temperature(),
}
}
/// Returns the wrapped request's messages as OpenAI-style `Message` values.
///
/// Pure dispatch: how each provider format is mapped onto
/// `crate::apis::openai::Message` is decided by the inner request type.
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
match self {
Self::ChatCompletionsRequest(r) => r.get_messages(),
Self::MessagesRequest(r) => r.get_messages(),
Self::BedrockConverse(r) => r.get_messages(),
Self::BedrockConverseStream(r) => r.get_messages(),
Self::ResponsesAPIRequest(r) => r.get_messages(),
}
}
/// Hands the given OpenAI-style messages to the wrapped request.
///
/// Pure dispatch: every variant forwards the slice to its inner request's
/// `set_messages`, which owns the conversion into its own format.
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
match self {
Self::ChatCompletionsRequest(r) => r.set_messages(messages),
Self::MessagesRequest(r) => r.set_messages(messages),
Self::BedrockConverse(r) => r.set_messages(messages),
Self::BedrockConverseStream(r) => r.set_messages(messages),
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
}
}
}
/// Parse the client API from a byte slice.
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType {
type Error = std::io::Error;
fn try_from((bytes, client_api): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClient)) -> Result<Self, Self::Error> {
// Use SupportedApi to determine the appropriate request type
match client_api {
SupportedAPIs::OpenAIChatCompletions(_) => {
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
let chat_completion_request: ChatCompletionsRequest =
ChatCompletionsRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
@ -131,11 +208,20 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
chat_completion_request,
))
}
SupportedAPIs::AnthropicMessagesAPI(_) => {
SupportedAPIsFromClient::AnthropicMessagesAPI(_) => {
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::MessagesRequest(messages_request))
}
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
let responses_apirequest: ResponsesAPIRequest =
ResponsesAPIRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::ResponsesAPIRequest(
responses_apirequest,
))
}
}
}
}
@ -148,17 +234,13 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
(client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> {
match (client_request, upstream_api) {
// Same API - no conversion needed, just clone the reference
// ============================================================================
// ChatCompletionsRequest conversions
// ============================================================================
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
// Cross-API conversion - cloning is necessary for transformation
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
@ -173,7 +255,45 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
})?;
Ok(ProviderRequestType::MessagesRequest(messages_req))
}
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
) => {
let bedrock_req = ConverseRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
source: Some(Box::new(e))
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
) => {
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock Stream request: {}", e),
source: Some(Box::new(e))
})?;
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
}
(
ProviderRequestType::ChatCompletionsRequest(_),
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
) => {
Err(ProviderRequestError {
message: "Conversion from ChatCompletionsRequest to ResponsesAPIRequest is not supported. ResponsesAPI can only be used as a client API, not as an upstream API.".to_string(),
source: None,
})
}
// ============================================================================
// MessagesRequest conversions
// ============================================================================
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
@ -189,31 +309,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
})?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
// Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
) => {
let bedrock_req = ConverseRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
source: Some(Box::new(e))
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(
ProviderRequestType::ChatCompletionsRequest(chat_req),
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
) => {
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
.map_err(|e| ProviderRequestError {
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
source: Some(Box::new(e))
})?;
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
(
ProviderRequestType::MessagesRequest(messages_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
@ -235,7 +330,97 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
"Failed to convert MessagesRequest to Amazon Bedrock Stream request: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
}
(
ProviderRequestType::MessagesRequest(_),
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
) => {
Err(ProviderRequestError {
message: "Conversion from MessagesRequest to ResponsesAPIRequest is not supported. ResponsesAPI can only be used as a client API, not as an upstream API.".to_string(),
source: None,
})
}
// ============================================================================
// ResponsesAPIRequest conversions (only converts TO other formats)
// ============================================================================
(
ProviderRequestType::ResponsesAPIRequest(responses_req),
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
) => Ok(ProviderRequestType::ResponsesAPIRequest(responses_req)),
// ResponsesAPI -> ChatCompletions (direct conversion)
(
ProviderRequestType::ResponsesAPIRequest(responses_req),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => {
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
}
// ResponsesAPI -> Anthropic Messages (via ChatCompletions)
(
ProviderRequestType::ResponsesAPIRequest(responses_req),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => {
// Chain: ResponsesAPI -> ChatCompletions -> MessagesRequest
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
let messages_req = MessagesRequest::try_from(chat_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to MessagesRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::MessagesRequest(messages_req))
}
// ResponsesAPI -> Bedrock Converse (via ChatCompletions)
(
ProviderRequestType::ResponsesAPIRequest(responses_req),
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
) => {
// Chain: ResponsesAPI -> ChatCompletions -> ConverseRequest
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
let bedrock_req = ConverseRequest::try_from(chat_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}",
e
),
source: Some(Box::new(e)),
@ -244,13 +429,50 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
}
// Amazon Bedrock to other APIs conversions
// ResponsesAPI -> Bedrock Converse Stream (via ChatCompletions)
(
ProviderRequestType::ResponsesAPIRequest(responses_req),
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
) => {
// Chain: ResponsesAPI -> ChatCompletions -> ConverseStreamRequest
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
e
),
source: Some(Box::new(e)),
}
})?;
let bedrock_req = ConverseStreamRequest::try_from(chat_req).map_err(|e| {
ProviderRequestError {
message: format!(
"Failed to convert ChatCompletionsRequest to Amazon Bedrock Stream request: {}",
e
),
source: Some(Box::new(e)),
}
})?;
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
}
// ============================================================================
// Amazon Bedrock conversions (not supported as client API)
// ============================================================================
(ProviderRequestType::BedrockConverse(_), _) => {
todo!("Amazon Bedrock to ChatCompletionsRequest conversion not implemented yet")
Err(ProviderRequestError {
message: "Amazon Bedrock Converse is not supported as a client API. Only OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses APIs are supported as client APIs.".to_string(),
source: None,
})
}
(ProviderRequestType::BedrockConverseStream(_), _) => {
todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet")
Err(ProviderRequestError {
message: "Amazon Bedrock Converse Stream is not supported as a client API. Only OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses APIs are supported as client APIs.".to_string(),
source: None,
})
}
}
}
@ -284,7 +506,7 @@ mod tests {
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::clients::endpoints::SupportedAPIs;
use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::transforms::lib::ExtractText;
use serde_json::json;
@ -298,7 +520,7 @@ mod tests {
]
});
let bytes = serde_json::to_vec(&req).unwrap();
let api = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
let api = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions);
let result = ProviderRequestType::try_from((bytes.as_slice(), &api));
assert!(result.is_ok());
match result.unwrap() {
@ -321,7 +543,7 @@ mod tests {
]
});
let bytes = serde_json::to_vec(&req).unwrap();
let endpoint = SupportedAPIs::AnthropicMessagesAPI(Messages);
let endpoint = SupportedAPIsFromClient::AnthropicMessagesAPI(Messages);
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
assert!(result.is_ok());
match result.unwrap() {
@ -343,7 +565,7 @@ mod tests {
]
});
let bytes = serde_json::to_vec(&req).unwrap();
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions);
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
assert!(result.is_ok());
match result.unwrap() {
@ -366,7 +588,7 @@ mod tests {
});
let bytes = serde_json::to_vec(&req).unwrap();
// Intentionally use OpenAI endpoint for Anthropic payload
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions);
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
// Should parse as ChatCompletionsRequest, not error
assert!(result.is_ok());
@ -486,4 +708,399 @@ mod tests {
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
assert_eq!(original_max_tokens, roundtrip_max_tokens);
}
/// Parsing a Responses API payload through the client-API entry point must
/// produce the `ResponsesAPIRequest` variant with the model preserved.
#[test]
fn test_responses_api_request_from_bytes() {
    use crate::apis::openai::OpenAIApi::Responses;

    // Minimal payload: a model plus plain-text input.
    let payload = json!({
        "model": "gpt-4o",
        "input": "Hello, how are you?"
    });
    let raw = serde_json::to_vec(&payload).unwrap();
    let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(Responses);

    let parsed = ProviderRequestType::try_from((raw.as_slice(), &client_api));
    assert!(parsed.is_ok());

    if let ProviderRequestType::ResponsesAPIRequest(r) = parsed.unwrap() {
        assert_eq!(r.model, "gpt-4o");
    } else {
        panic!("Expected ResponsesAPIRequest variant");
    }
}
// Converts a text-input Responses API request to the ChatCompletions upstream
// and checks that model, temperature, top_p and the token limit carry over,
// and that the text input yields exactly one message.
#[test]
fn test_responses_api_to_chat_completions_conversion() {
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};
let responses_req = ResponsesAPIRequest {
model: "gpt-4o".to_string(),
input: InputParam::Text("Hello, world!".to_string()),
temperature: Some(0.7),
top_p: Some(0.9),
max_output_tokens: Some(100),
stream: Some(false),
metadata: None,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
instructions: None,
modalities: None,
user: None,
store: None,
reasoning_effort: None,
include: None,
audio: None,
text: None,
service_tier: None,
top_logprobs: None,
stream_options: None,
truncation: None,
conversation: None,
previous_response_id: None,
max_tool_calls: None,
background: None,
};
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(ChatCompletions);
let result = ProviderRequestType::try_from((
ProviderRequestType::ResponsesAPIRequest(responses_req),
&upstream_api,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderRequestType::ChatCompletionsRequest(chat_req) => {
assert_eq!(chat_req.model, "gpt-4o");
assert_eq!(chat_req.temperature, Some(0.7));
assert_eq!(chat_req.top_p, Some(0.9));
// `max_output_tokens` is expected to land in `max_completion_tokens`.
assert_eq!(chat_req.max_completion_tokens, Some(100));
assert_eq!(chat_req.messages.len(), 1);
}
_ => panic!("Expected ChatCompletionsRequest variant"),
}
}
// Converts a text-input Responses API request (with `instructions`) to the
// Anthropic Messages upstream and checks model, temperature, token limit and
// the number of conversation messages.
#[test]
fn test_responses_api_to_anthropic_messages_conversion() {
use crate::apis::anthropic::AnthropicApi::Messages;
use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};
let responses_req = ResponsesAPIRequest {
model: "gpt-4o".to_string(),
input: InputParam::Text("Hello, Claude!".to_string()),
temperature: Some(0.8),
max_output_tokens: Some(150),
stream: Some(false),
metadata: None,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
instructions: Some("You are a helpful assistant".to_string()),
modalities: None,
user: None,
store: None,
reasoning_effort: None,
include: None,
audio: None,
text: None,
service_tier: None,
top_p: None,
top_logprobs: None,
stream_options: None,
truncation: None,
conversation: None,
previous_response_id: None,
max_tool_calls: None,
background: None,
};
let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(Messages);
let result = ProviderRequestType::try_from((
ProviderRequestType::ResponsesAPIRequest(responses_req),
&upstream_api,
));
assert!(result.is_ok());
match result.unwrap() {
ProviderRequestType::MessagesRequest(messages_req) => {
assert_eq!(messages_req.model, "gpt-4o");
assert_eq!(messages_req.temperature, Some(0.8));
assert_eq!(messages_req.max_tokens, 150);
// Intended chain: ResponsesAPI `instructions` -> ChatCompletions system
// message -> Anthropic system prompt.
// NOTE(review): nothing below asserts on `messages_req.system`; only the
// message count is checked, so this test does not prove the instructions
// survived the conversion.
assert_eq!(messages_req.messages.len(), 1);
}
_ => panic!("Expected MessagesRequest variant"),
}
}
/// Converts a text-input Responses API request to the Bedrock Converse
/// upstream and checks that the model id carries over and that the converted
/// request contains messages.
#[test]
fn test_responses_api_to_bedrock_conversion() {
    use crate::apis::amazon_bedrock::AmazonBedrockApi::Converse;
    use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};

    let responses_req = ResponsesAPIRequest {
        model: "gpt-4o".to_string(),
        input: InputParam::Text("Hello, Bedrock!".to_string()),
        temperature: Some(0.5),
        max_output_tokens: Some(200),
        stream: Some(false),
        metadata: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        instructions: None,
        modalities: None,
        user: None,
        store: None,
        reasoning_effort: None,
        include: None,
        audio: None,
        text: None,
        service_tier: None,
        top_p: None,
        top_logprobs: None,
        stream_options: None,
        truncation: None,
        conversation: None,
        previous_response_id: None,
        max_tool_calls: None,
        background: None,
    };

    let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverse(Converse);
    let result = ProviderRequestType::try_from((
        ProviderRequestType::ResponsesAPIRequest(responses_req),
        &upstream_api,
    ));
    assert!(result.is_ok());

    match result.unwrap() {
        ProviderRequestType::BedrockConverse(bedrock_req) => {
            assert_eq!(bedrock_req.model_id, "gpt-4o");
            // Bedrock receives the request via the ChatCompletions conversion;
            // the user input must have produced a message list.
            assert!(bedrock_req.messages.is_some());
        }
        _ => panic!("Expected BedrockConverse variant"),
    }
}
// A ChatCompletions client request targeted at the Responses upstream must be
// rejected with the "client API only" error message.
#[test]
fn test_chat_completions_to_responses_api_not_supported() {
use crate::apis::openai::OpenAIApi::Responses;
use crate::apis::openai::{Message, MessageContent, Role};
let chat_req = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![Message {
role: Role::User,
content: MessageContent::Text("Hello!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}],
..Default::default()
};
let upstream_api = SupportedUpstreamAPIs::OpenAIResponsesAPI(Responses);
let result = ProviderRequestType::try_from((
ProviderRequestType::ChatCompletionsRequest(chat_req),
&upstream_api,
));
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.message.contains("ResponsesAPI can only be used as a client API"));
}
// An Anthropic Messages client request targeted at the Responses upstream must
// be rejected with the "client API only" error message.
#[test]
fn test_anthropic_messages_to_responses_api_not_supported() {
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
use crate::apis::openai::OpenAIApi::Responses;
let messages_req = AnthropicMessagesRequest {
model: "claude-3-sonnet".to_string(),
messages: vec![crate::apis::anthropic::MessagesMessage {
role: crate::apis::anthropic::MessagesRole::User,
content: crate::apis::anthropic::MessagesMessageContent::Single(
"Hello!".to_string(),
),
}],
max_tokens: 100,
container: None,
mcp_servers: None,
service_tier: None,
thinking: None,
temperature: None,
top_p: None,
top_k: None,
stream: None,
stop_sequences: None,
system: None,
tools: None,
tool_choice: None,
metadata: None,
};
let upstream_api = SupportedUpstreamAPIs::OpenAIResponsesAPI(Responses);
let result = ProviderRequestType::try_from((
ProviderRequestType::MessagesRequest(messages_req),
&upstream_api,
));
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.message.contains("ResponsesAPI can only be used as a client API"));
}
/// A Bedrock Converse request used as the *client* request must be rejected,
/// and the error must name the client APIs that are supported.
#[test]
fn test_bedrock_as_client_api_not_supported() {
    use crate::apis::openai::OpenAIApi::ChatCompletions;

    // A default-constructed request is sufficient input for this check.
    let client_request = ProviderRequestType::BedrockConverse(ConverseRequest::default());
    let target_api = SupportedUpstreamAPIs::OpenAIChatCompletions(ChatCompletions);

    let outcome = ProviderRequestType::try_from((client_request, &target_api));
    assert!(outcome.is_err());

    let failure = outcome.unwrap_err();
    assert!(failure.message.contains("not supported as a client API"));
    assert!(failure
        .message
        .contains("OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses"));
}
// `get_messages` on a ChatCompletions request must return both messages with
// their roles in the original order (system, then user).
#[test]
fn test_get_message_history_chat_completions() {
use crate::apis::openai::{Message, MessageContent, Role};
let chat_req = ChatCompletionsRequest {
model: "gpt-4".to_string(),
messages: vec![
Message {
role: Role::System,
content: MessageContent::Text("You are helpful".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
Message {
role: Role::User,
content: MessageContent::Text("Hello!".to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
},
],
..Default::default()
};
let provider_req = ProviderRequestType::ChatCompletionsRequest(chat_req);
let messages = provider_req.get_messages();
assert_eq!(messages.len(), 2);
assert_eq!(messages[0].role, Role::System);
assert_eq!(messages[1].role, Role::User);
}
// `get_messages` on an Anthropic request must surface the separate `system`
// prompt as a leading system message, followed by the user message.
#[test]
fn test_get_message_history_anthropic_messages() {
use crate::apis::anthropic::{
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
MessagesSystemPrompt,
};
let anthropic_req = MessagesRequest {
model: "claude-3-sonnet".to_string(),
messages: vec![MessagesMessage {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Hello!".to_string()),
}],
system: Some(MessagesSystemPrompt::Single(
"You are helpful".to_string(),
)),
max_tokens: 100,
container: None,
mcp_servers: None,
metadata: None,
service_tier: None,
thinking: None,
temperature: None,
top_p: None,
top_k: None,
stream: None,
stop_sequences: None,
tools: None,
tool_choice: None,
};
let provider_req = ProviderRequestType::MessagesRequest(anthropic_req);
let messages = provider_req.get_messages();
// Should have system message + user message
assert_eq!(messages.len(), 2);
assert_eq!(
messages[0].role,
crate::apis::openai::Role::System
);
assert_eq!(
messages[1].role,
crate::apis::openai::Role::User
);
}
// `get_messages` on a Responses API request with plain-text input and
// `instructions` must return a system message followed by a user message.
#[test]
fn test_get_message_history_responses_api() {
use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};
let responses_req = ResponsesAPIRequest {
model: "gpt-4o".to_string(),
input: InputParam::Text("Hello, world!".to_string()),
instructions: Some("Be helpful".to_string()),
temperature: None,
max_output_tokens: None,
stream: None,
metadata: None,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
modalities: None,
user: None,
store: None,
reasoning_effort: None,
include: None,
audio: None,
text: None,
service_tier: None,
top_p: None,
top_logprobs: None,
stream_options: None,
truncation: None,
conversation: None,
previous_response_id: None,
max_tool_calls: None,
background: None,
};
let provider_req = ProviderRequestType::ResponsesAPIRequest(responses_req);
let messages = provider_req.get_messages();
// Should have system message (instructions) + user message (input)
assert_eq!(messages.len(), 2);
assert_eq!(
messages[0].role,
crate::apis::openai::Role::System
);
assert_eq!(
messages[1].role,
crate::apis::openai::Role::User
);
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -11,11 +11,13 @@
pub mod lib;
pub mod request;
pub mod response;
pub mod response_streaming;
// Re-export commonly used items for convenience
pub use lib::*;
pub use request::*;
pub use response::*;
pub use response_streaming::*;
// ============================================================================
// CONSTANTS

View file

@ -12,6 +12,10 @@ use crate::apis::anthropic::{
use crate::apis::openai::{
ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
};
use crate::apis::openai_responses::{
ResponsesAPIRequest, InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice
};
use crate::clients::TransformError;
use crate::transforms::lib::ExtractText;
use crate::transforms::lib::*;
@ -244,6 +248,212 @@ impl TryFrom<Message> for BedrockMessage {
}
}
/// Converts an OpenAI Responses API request into a Chat Completions request.
///
/// Mapping rules:
/// * `instructions`, when present, becomes a leading system message. This
///   applies to both plain-text and item-list input.
/// * Plain-text input becomes a single user message.
/// * Message items are converted one-to-one (the developer role maps to
///   system); non-message items are skipped.
/// * A content list holding exactly one text item collapses to plain text;
///   any other list becomes content parts, with file and audio items dropped.
/// * Only function tools can be converted; any other tool kind fails with
///   `TransformError::UnsupportedConversion`.
/// * Unknown string values for `tool_choice` fall back to `auto`.
impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
    type Error = TransformError;

    fn try_from(req: ResponsesAPIRequest) -> Result<Self, Self::Error> {
        let mut messages = Vec::new();

        // Instructions are independent of the input shape, so emit the system
        // message before looking at `input`; otherwise plain-text requests
        // would silently lose their instructions.
        if let Some(instructions) = req.instructions {
            messages.push(Message {
                role: Role::System,
                content: MessageContent::Text(instructions),
                name: None,
                tool_call_id: None,
                tool_calls: None,
            });
        }

        match req.input {
            // Simple text input becomes a user message
            InputParam::Text(text) => messages.push(Message {
                role: Role::User,
                content: MessageContent::Text(text),
                name: None,
                tool_call_id: None,
                tool_calls: None,
            }),
            InputParam::Items(items) => {
                for item in items {
                    // Non-message items (references, outputs) have no direct
                    // chat-completions representation and are skipped for now.
                    let input_msg = match item {
                        InputItem::Message(input_msg) => input_msg,
                        _ => continue,
                    };

                    let role = match input_msg.role {
                        MessageRole::User => Role::User,
                        MessageRole::Assistant => Role::Assistant,
                        MessageRole::System => Role::System,
                        MessageRole::Developer => Role::System, // Map developer to system
                    };

                    let content = match &input_msg.content {
                        crate::apis::openai_responses::MessageContent::Text(text) => {
                            MessageContent::Text(text.clone())
                        }
                        crate::apis::openai_responses::MessageContent::Items(content_items) => {
                            match content_items.as_slice() {
                                // A lone text item can use the simple text format.
                                [InputContent::InputText { text }] => {
                                    MessageContent::Text(text.clone())
                                }
                                // Everything else (single non-text item, several
                                // items, or none) uses the parts format.
                                parts => MessageContent::Parts(
                                    parts
                                        .iter()
                                        .filter_map(|c| match c {
                                            InputContent::InputText { text } => {
                                                Some(crate::apis::openai::ContentPart::Text {
                                                    text: text.clone(),
                                                })
                                            }
                                            InputContent::InputImage { image_url, .. } => {
                                                Some(crate::apis::openai::ContentPart::ImageUrl {
                                                    image_url: crate::apis::openai::ImageUrl {
                                                        url: image_url.clone(),
                                                        detail: None,
                                                    },
                                                })
                                            }
                                            InputContent::InputFile { .. } => None, // Skip files for now
                                            InputContent::InputAudio { .. } => None, // Skip audio for now
                                        })
                                        .collect(),
                                ),
                            }
                        }
                    };

                    messages.push(Message {
                        role,
                        content,
                        name: None,
                        tool_call_id: None,
                        tool_calls: None,
                    });
                }
            }
        }

        // Build the ChatCompletionsRequest
        Ok(ChatCompletionsRequest {
            model: req.model,
            messages,
            temperature: req.temperature,
            top_p: req.top_p,
            max_completion_tokens: req.max_output_tokens.map(|t| t as u32),
            stream: req.stream,
            metadata: req.metadata,
            user: req.user,
            store: req.store,
            service_tier: req.service_tier,
            top_logprobs: req.top_logprobs.map(|t| t as u32),
            modalities: req.modalities.map(|mods| {
                mods.into_iter()
                    .map(|m| match m {
                        Modality::Text => "text".to_string(),
                        Modality::Audio => "audio".to_string(),
                    })
                    .collect()
            }),
            stream_options: req.stream_options.map(|opts| crate::apis::openai::StreamOptions {
                include_usage: opts.include_usage,
            }),
            reasoning_effort: req.reasoning_effort.map(|effort| match effort {
                ReasoningEffort::Low => "low".to_string(),
                ReasoningEffort::Medium => "medium".to_string(),
                ReasoningEffort::High => "high".to_string(),
            }),
            // Only function tools exist in ChatCompletions; the first unsupported
            // tool aborts the whole conversion via `transpose()?`.
            tools: req
                .tools
                .map(|tools| {
                    tools
                        .into_iter()
                        .map(|tool| match tool {
                            ResponsesTool::Function { name, description, parameters, strict } => Ok(Tool {
                                tool_type: "function".to_string(),
                                function: crate::apis::openai::Function {
                                    name,
                                    description,
                                    // Fall back to an empty object schema when none is given.
                                    parameters: parameters.unwrap_or_else(|| {
                                        serde_json::json!({
                                            "type": "object",
                                            "properties": {}
                                        })
                                    }),
                                    strict,
                                },
                            }),
                            ResponsesTool::FileSearch { .. } => Err(TransformError::UnsupportedConversion(
                                "FileSearch tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
                            )),
                            ResponsesTool::WebSearchPreview { .. } => Err(TransformError::UnsupportedConversion(
                                "WebSearchPreview tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
                            )),
                            ResponsesTool::CodeInterpreter => Err(TransformError::UnsupportedConversion(
                                "CodeInterpreter tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
                            )),
                            ResponsesTool::Computer { .. } => Err(TransformError::UnsupportedConversion(
                                "Computer tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
                            )),
                        })
                        .collect::<Result<Vec<_>, _>>()
                })
                .transpose()?,
            tool_choice: req.tool_choice.map(|choice| match choice {
                ResponsesToolChoice::String(s) => match s.as_str() {
                    "auto" => ToolChoice::Type(ToolChoiceType::Auto),
                    "required" => ToolChoice::Type(ToolChoiceType::Required),
                    "none" => ToolChoice::Type(ToolChoiceType::None),
                    _ => ToolChoice::Type(ToolChoiceType::Auto), // Default to auto for unknown strings
                },
                ResponsesToolChoice::Named { function, .. } => ToolChoice::Function {
                    choice_type: "function".to_string(),
                    function: crate::apis::openai::FunctionChoice { name: function.name },
                },
            }),
            parallel_tool_calls: req.parallel_tool_calls,
            ..Default::default()
        })
    }
}
impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
type Error = TransformError;

View file

@ -1,3 +1,4 @@
//! Response transformation modules
pub mod output_to_input;
pub mod to_anthropic;
pub mod to_openai;

View file

@ -0,0 +1,178 @@
//! Conversions from response outputs to request inputs for conversation continuation
//!
//! This module provides utilities for converting OutputItem types from API responses
//! into InputItem types that can be used in subsequent requests. This is primarily used
//! for maintaining conversation history in the v1/responses API.
use crate::apis::openai_responses::{
InputContent, InputItem, InputMessage, MessageContent, MessageRole, OutputContent, OutputItem,
};
/// Turns one response `OutputItem` into an `InputItem` suitable for the next
/// request, so earlier responses can be replayed as conversation history.
///
/// * Message outputs keep their text and audio content; refusals are dropped.
///   A message left with no content yields `None`.
/// * Function calls are rendered as an assistant text message describing the
///   call.
/// * Every other output kind yields `None`.
pub fn convert_responses_output_to_input_items(output: &OutputItem) -> Option<InputItem> {
    let message = match output {
        OutputItem::Message { role, content, .. } => {
            let mut parts: Vec<InputContent> = Vec::new();
            for piece in content.iter() {
                match piece {
                    OutputContent::OutputText { text, .. } => {
                        parts.push(InputContent::InputText { text: text.clone() });
                    }
                    OutputContent::OutputAudio { data, .. } => {
                        parts.push(InputContent::InputAudio {
                            data: data.clone(),
                            format: None, // Format not preserved in output
                        });
                    }
                    // Refusals carry nothing to replay.
                    OutputContent::Refusal { .. } => {}
                }
            }

            if parts.is_empty() {
                return None;
            }

            // Role arrives as a string; anything unrecognised is treated as
            // assistant, which also covers the literal "assistant".
            let mapped_role = match role.as_str() {
                "user" => MessageRole::User,
                "system" => MessageRole::System,
                "developer" => MessageRole::Developer,
                _ => MessageRole::Assistant,
            };

            InputMessage {
                role: mapped_role,
                content: MessageContent::Items(parts),
            }
        }
        OutputItem::FunctionCall { name, arguments, .. } => {
            // Describe the call in plain text; a generic sentence is used when
            // either the name or the arguments are missing.
            let description = match (name, arguments) {
                (Some(n), Some(args)) => {
                    format!("Called function: {} with arguments: {}", n, args)
                }
                _ => "Called a function".to_string(),
            };

            InputMessage {
                role: MessageRole::Assistant,
                content: MessageContent::Items(vec![InputContent::InputText {
                    text: description,
                }]),
            }
        }
        // Other output types (tool outputs, etc.) don't convert to input.
        _ => return None,
    };

    Some(InputItem::Message(message))
}
/// Converts a Vec of OutputItems into InputItems for conversation continuation
pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec<InputItem> {
outputs
.iter()
.filter_map(convert_responses_output_to_input_items)
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::apis::openai_responses::OutputItemStatus;

    /// A text-only assistant message becomes a single `InputText` item with the same role.
    #[test]
    fn test_output_message_to_input() {
        let output = OutputItem::Message {
            id: "msg_123".to_string(),
            status: OutputItemStatus::Completed,
            role: "assistant".to_string(),
            content: vec![OutputContent::OutputText {
                text: "Hello!".to_string(),
                annotations: vec![],
                logprobs: None,
            }],
        };
        let input = convert_responses_output_to_input_items(&output).unwrap();
        match input {
            InputItem::Message(msg) => {
                assert!(matches!(msg.role, MessageRole::Assistant));
                match &msg.content {
                    MessageContent::Items(items) => {
                        assert_eq!(items.len(), 1);
                        match &items[0] {
                            InputContent::InputText { text } => assert_eq!(text, "Hello!"),
                            _ => panic!("Expected InputText"),
                        }
                    }
                    _ => panic!("Expected MessageContent::Items"),
                }
            }
            _ => panic!("Expected Message variant"),
        }
    }

    /// A function call is flattened into assistant text that names the function and
    /// carries its arguments.
    #[test]
    fn test_function_call_to_input() {
        let output = OutputItem::FunctionCall {
            id: "fc_123".to_string(),
            status: OutputItemStatus::Completed,
            call_id: "call_123".to_string(),
            name: Some("get_weather".to_string()),
            arguments: Some(r#"{"location":"SF"}"#.to_string()),
        };
        let input = convert_responses_output_to_input_items(&output).unwrap();
        match input {
            InputItem::Message(msg) => {
                assert!(matches!(msg.role, MessageRole::Assistant));
                match &msg.content {
                    MessageContent::Items(items) => {
                        // Guard the index below and pin the single-item shape.
                        assert_eq!(items.len(), 1);
                        match &items[0] {
                            InputContent::InputText { text } => {
                                assert!(text.contains("get_weather"));
                                assert!(text.contains(r#"{"location":"SF"}"#));
                            }
                            _ => panic!("Expected InputText"),
                        }
                    }
                    _ => panic!("Expected MessageContent::Items"),
                }
            }
            _ => panic!("Expected Message variant"),
        }
    }

    /// Both convertible item kinds survive the batch conversion.
    #[test]
    fn test_outputs_to_inputs() {
        let outputs = vec![
            OutputItem::Message {
                id: "msg_1".to_string(),
                status: OutputItemStatus::Completed,
                role: "assistant".to_string(),
                content: vec![OutputContent::OutputText {
                    text: "Hello".to_string(),
                    annotations: vec![],
                    logprobs: None,
                }],
            },
            OutputItem::FunctionCall {
                id: "fc_1".to_string(),
                status: OutputItemStatus::Completed,
                call_id: "call_1".to_string(),
                name: Some("test".to_string()),
                arguments: Some("{}".to_string()),
            },
        ];
        let inputs = outputs_to_inputs(&outputs);
        assert_eq!(inputs.len(), 2);
    }

    /// A message whose only content is a refusal has nothing to replay and is dropped.
    #[test]
    fn test_refusal_only_message_is_skipped() {
        let output = OutputItem::Message {
            id: "msg_refusal".to_string(),
            status: OutputItemStatus::Completed,
            role: "assistant".to_string(),
            content: vec![OutputContent::Refusal {
                refusal: "I can't help with that.".to_string(),
            }],
        };
        assert!(convert_responses_output_to_input_items(&output).is_none());
        assert!(outputs_to_inputs(&[output]).is_empty());
    }
}

View file

@ -1,16 +1,11 @@
use crate::apis::amazon_bedrock::{
ContentBlockDelta, ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
};
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesResponse,
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
};
use crate::apis::openai::{
ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta,
MessagesContentBlock, MessagesResponse,
MessagesRole, MessagesStopReason, MessagesUsage,
};
use crate::apis::openai::ChatCompletionsResponse;
use crate::clients::TransformError;
use crate::transforms::lib::*;
use serde_json::Value;
// ============================================================================
// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience
@ -120,289 +115,6 @@ impl TryFrom<ConverseResponse> for MessagesResponse {
}
}
impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
    type Error = TransformError;

    /// Translates one OpenAI streaming chunk into a single Anthropic stream event.
    ///
    /// Only the first choice is inspected. The checks run in a fixed priority order and
    /// the first one that applies wins:
    /// 1. no choices -> `Ping`
    /// 2. usage and finish reason both present -> `MessageDelta` with that usage
    /// 3. assistant role in the delta -> `MessageStart`
    /// 4. non-empty text content -> `ContentBlockDelta` (index 0)
    /// 5. tool calls -> delegated to `convert_tool_call_deltas`
    /// 6. finish reason without usage -> `MessageDelta` with zeroed usage
    /// 7. anything else -> `Ping`
    fn try_from(resp: ChatCompletionsStreamResponse) -> Result<Self, Self::Error> {
        let choice = match resp.choices.first() {
            Some(first) => first,
            None => return Ok(MessagesStreamEvent::Ping),
        };

        // Placeholder usage for events that carry no token counts.
        let zero_usage = || MessagesUsage {
            input_tokens: 0,
            output_tokens: 0,
            cache_creation_input_tokens: None,
            cache_read_input_tokens: None,
        };

        // Final chunk: finish reason and usage arrive together.
        if let (Some(usage), Some(finish_reason)) = (resp.usage, &choice.finish_reason) {
            let stop_reason: MessagesStopReason = finish_reason.clone().into();
            return Ok(MessagesStreamEvent::MessageDelta {
                delta: MessagesMessageDelta {
                    stop_reason,
                    stop_sequence: None,
                },
                usage: usage.into(),
            });
        }

        // Role announcement opens the message.
        if matches!(choice.delta.role, Some(Role::Assistant)) {
            return Ok(MessagesStreamEvent::MessageStart {
                message: MessagesStreamMessage {
                    id: resp.id,
                    obj_type: "message".to_string(),
                    role: MessagesRole::Assistant,
                    content: vec![],
                    model: resp.model,
                    stop_reason: None,
                    stop_sequence: None,
                    usage: zero_usage(),
                },
            });
        }

        // Text delta.
        match &choice.delta.content {
            Some(text) if !text.is_empty() => {
                return Ok(MessagesStreamEvent::ContentBlockDelta {
                    index: 0,
                    delta: MessagesContentDelta::TextDelta { text: text.clone() },
                });
            }
            _ => {}
        }

        // Tool call deltas.
        if let Some(tool_calls) = &choice.delta.tool_calls {
            return convert_tool_call_deltas(tool_calls.clone());
        }

        // A finish reason reaching this point never had usage attached (that case returned
        // above), so the delta is emitted with zeroed usage. MessageStop is not produced here.
        if let Some(finish_reason) = &choice.finish_reason {
            let stop_reason: MessagesStopReason = finish_reason.clone().into();
            return Ok(MessagesStreamEvent::MessageDelta {
                delta: MessagesMessageDelta {
                    stop_reason,
                    stop_sequence: None,
                },
                usage: zero_usage(),
            });
        }

        // Nothing actionable in this chunk.
        Ok(MessagesStreamEvent::Ping)
    }
}
/// Renders a stream event as a Server-Sent Events frame:
/// `event: <type>\ndata: <json>\n\n`.
///
/// Implemented as `From` rather than `Into`; the standard blanket impl still provides
/// `Into<String>` for existing callers.
impl From<MessagesStreamEvent> for String {
    fn from(event: MessagesStreamEvent) -> String {
        let event_type = match &event {
            MessagesStreamEvent::MessageStart { .. } => "message_start",
            MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
            MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
            MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
            MessagesStreamEvent::MessageDelta { .. } => "message_delta",
            MessagesStreamEvent::MessageStop => "message_stop",
            MessagesStreamEvent::Ping => "ping",
        };
        // A serialization failure degrades to an empty `data:` payload instead of an error.
        let payload = serde_json::to_string(&event).unwrap_or_default();
        format!("event: {}\ndata: {}\n\n", event_type, payload)
    }
}
impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
    type Error = TransformError;

    /// Translates one Bedrock Converse stream event into one Anthropic stream event.
    ///
    /// * `MessageStart` -> `MessageStart` with a synthesized id and placeholder model name.
    /// * `ContentBlockStart` / `ContentBlockDelta` / `ContentBlockStop` -> their Anthropic
    ///   counterparts, keeping the block index.
    /// * `MessageStop` -> `MessageDelta` carrying the mapped stop reason and zeroed usage.
    /// * `Metadata` -> `MessageDelta` carrying the token usage.
    /// * Exception events -> `Ping`.
    fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
        // Placeholder usage for events that carry no token counts.
        let zero_usage = || MessagesUsage {
            input_tokens: 0,
            output_tokens: 0,
            cache_creation_input_tokens: None,
            cache_read_input_tokens: None,
        };

        match event {
            // MessageStart - convert to Anthropic MessageStart
            ConverseStreamEvent::MessageStart(start_event) => {
                let role = match start_event.role {
                    crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
                    crate::apis::amazon_bedrock::ConversationRole::Assistant => {
                        MessagesRole::Assistant
                    }
                };
                Ok(MessagesStreamEvent::MessageStart {
                    message: MessagesStreamMessage {
                        // The event supplies no id here, so one is derived from the clock.
                        id: format!(
                            "bedrock-stream-{}",
                            std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_nanos()
                        ),
                        obj_type: "message".to_string(),
                        role,
                        content: vec![],
                        model: "bedrock-model".to_string(),
                        stop_reason: None,
                        stop_sequence: None,
                        usage: zero_usage(),
                    },
                })
            }
            // ContentBlockStart - convert to Anthropic ContentBlockStart
            ConverseStreamEvent::ContentBlockStart(start_event) => {
                // Bedrock sends tool_use_id and name at start, with input coming in subsequent
                // deltas. Anthropic expects the same pattern, so input starts as an empty object.
                match start_event.start {
                    crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => {
                        Ok(MessagesStreamEvent::ContentBlockStart {
                            index: start_event.content_block_index as u32,
                            content_block: MessagesContentBlock::ToolUse {
                                id: tool_use.tool_use_id,
                                name: tool_use.name,
                                input: Value::Object(serde_json::Map::new()), // filled by deltas
                                cache_control: None,
                            },
                        })
                    }
                }
            }
            // ContentBlockDelta - convert to Anthropic ContentBlockDelta
            ConverseStreamEvent::ContentBlockDelta(delta_event) => {
                let delta = match delta_event.delta {
                    ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text },
                    ContentBlockDelta::ToolUse { tool_use } => {
                        MessagesContentDelta::InputJsonDelta {
                            partial_json: tool_use.input,
                        }
                    }
                };
                Ok(MessagesStreamEvent::ContentBlockDelta {
                    index: delta_event.content_block_index as u32,
                    delta,
                })
            }
            // ContentBlockStop - convert to Anthropic ContentBlockStop
            ConverseStreamEvent::ContentBlockStop(stop_event) => {
                Ok(MessagesStreamEvent::ContentBlockStop {
                    index: stop_event.content_block_index as u32,
                })
            }
            // MessageStop - convert to Anthropic MessageDelta with the stop reason
            ConverseStreamEvent::MessageStop(stop_event) => {
                let anthropic_stop_reason = match stop_event.stop_reason {
                    StopReason::EndTurn => MessagesStopReason::EndTurn,
                    StopReason::ToolUse => MessagesStopReason::ToolUse,
                    StopReason::MaxTokens => MessagesStopReason::MaxTokens,
                    // Keep the stop-sequence reason instead of collapsing it into EndTurn.
                    // The matched sequence itself is not read from this event, so
                    // `stop_sequence` below stays None.
                    StopReason::StopSequence => MessagesStopReason::StopSequence,
                    StopReason::GuardrailIntervened => MessagesStopReason::Refusal,
                    StopReason::ContentFiltered => MessagesStopReason::Refusal,
                };
                // Only MessageDelta is returned; this conversion never emits MessageStop.
                Ok(MessagesStreamEvent::MessageDelta {
                    delta: MessagesMessageDelta {
                        stop_reason: anthropic_stop_reason,
                        stop_sequence: None,
                    },
                    usage: zero_usage(),
                })
            }
            // Metadata - convert usage information to MessageDelta
            ConverseStreamEvent::Metadata(metadata_event) => {
                Ok(MessagesStreamEvent::MessageDelta {
                    delta: MessagesMessageDelta {
                        // NOTE(review): the stop reason is hard-coded to EndTurn here and may
                        // differ from the one derived from MessageStop - confirm how the
                        // streaming handler reconciles the two deltas.
                        stop_reason: MessagesStopReason::EndTurn,
                        stop_sequence: None,
                    },
                    usage: MessagesUsage {
                        input_tokens: metadata_event.usage.input_tokens,
                        output_tokens: metadata_event.usage.output_tokens,
                        cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens,
                        cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens,
                    },
                })
            }
            // Exception events - convert to Ping (could be enhanced to return error events)
            ConverseStreamEvent::InternalServerException(_)
            | ConverseStreamEvent::ModelStreamErrorException(_)
            | ConverseStreamEvent::ServiceUnavailableException(_)
            | ConverseStreamEvent::ThrottlingException(_)
            | ConverseStreamEvent::ValidationException(_) => {
                // TODO: Consider adding proper error handling/events
                Ok(MessagesStreamEvent::Ping)
            }
        }
    }
}
/// Picks the first tool-call delta that maps to an Anthropic stream event.
///
/// * A delta with an `id` and a function `name` starts a tool-use block
///   (`ContentBlockStart` with empty input).
/// * A delta without an `id` but with function `arguments` continues one
///   (`ContentBlockDelta` with `InputJsonDelta`).
/// * If no delta qualifies, the result is `Ping`.
fn convert_tool_call_deltas(
    tool_calls: Vec<ToolCallDelta>,
) -> Result<MessagesStreamEvent, TransformError> {
    let first_event = tool_calls.iter().find_map(|call| {
        // Without a function payload neither case can apply.
        let function = call.function.as_ref()?;
        match &call.id {
            // Tool call start
            Some(id) => function
                .name
                .as_ref()
                .map(|name| MessagesStreamEvent::ContentBlockStart {
                    index: call.index,
                    content_block: MessagesContentBlock::ToolUse {
                        id: id.clone(),
                        name: name.clone(),
                        input: Value::Object(serde_json::Map::new()),
                        cache_control: None,
                    },
                }),
            // Tool arguments delta
            None => function
                .arguments
                .as_ref()
                .map(|arguments| MessagesStreamEvent::ContentBlockDelta {
                    index: call.index,
                    delta: MessagesContentDelta::InputJsonDelta {
                        partial_json: arguments.clone(),
                    },
                }),
        }
    });

    // Fallback to ping if no valid tool call found
    Ok(first_event.unwrap_or(MessagesStreamEvent::Ping))
}
/// Convert Bedrock Message to Anthropic content blocks
///

View file

@ -1,15 +1,13 @@
use crate::apis::amazon_bedrock::{
ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
ConverseOutput, ConverseResponse, StopReason,
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesResponse, MessagesStopReason,
MessagesStreamEvent, MessagesUsage,
MessagesContentBlock, MessagesResponse, MessagesUsage,
};
use crate::apis::openai::{
ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason,
FunctionCallDelta, MessageContent, MessageDelta, ResponseMessage, Role, StreamChoice,
ToolCallDelta, Usage,
ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage,
};
use crate::apis::openai_responses::ResponsesAPIResponse;
use crate::clients::TransformError;
use crate::transforms::lib::*;
@ -30,6 +28,182 @@ impl Into<Usage> for MessagesUsage {
}
}
impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
    type Error = TransformError;

    /// Builds a Responses API response from a Chat Completions response.
    ///
    /// Only the first choice is converted:
    /// * text, audio and refusal become the content of one `OutputItem::Message`
    ///   (omitted entirely when the choice has none of them);
    /// * each tool call becomes an `OutputItem::FunctionCall`;
    /// * the finish reason drives `status` (`Length` -> `Incomplete`,
    ///   `ContentFilter` -> `Failed`, everything else -> `Completed`).
    ///
    /// The response id is kept when it already starts with `resp_`; otherwise a fresh
    /// `resp_<32 hex chars>` id is generated. Request-side settings that a chat completion
    /// does not carry (`tools`, `tool_choice`, `temperature`, `top_p`, ...) are filled with
    /// fixed values.
    fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
        use crate::apis::openai_responses::{
            IncompleteDetails, IncompleteReason, OutputContent, OutputItem, OutputItemStatus,
            ResponseStatus, ResponseUsage, ResponsesAPIResponse,
        };

        // Convert the first choice's message to output items
        let output = if let Some(choice) = resp.choices.first() {
            let mut items = Vec::new();
            // Content parts of the (optional) message output item
            let mut content = Vec::new();

            // Add text content if present
            if let Some(text) = &choice.message.content {
                content.push(OutputContent::OutputText {
                    text: text.clone(),
                    annotations: vec![],
                    logprobs: None,
                });
            }

            // Add audio content if present. Audio is a serde_json::Value, so `data` and
            // `transcript` are read defensively; a non-object value is ignored.
            if let Some(audio) = &choice.message.audio {
                if let Some(audio_obj) = audio.as_object() {
                    let data = audio_obj
                        .get("data")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string());
                    let transcript = audio_obj
                        .get("transcript")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string());
                    content.push(OutputContent::OutputAudio { data, transcript });
                }
            }

            // Add refusal content if present
            if let Some(refusal) = &choice.message.refusal {
                content.push(OutputContent::Refusal {
                    refusal: refusal.clone(),
                });
            }

            // Only add the message item if there's actual content (text, audio, or refusal).
            // Don't add empty message items when there are only tool calls.
            if !content.is_empty() {
                // Message id: keep an existing `msg_` id, otherwise strip a known
                // `resp_` / `chatcmpl-` prefix (in that order) before adding `msg_`,
                // so the prefixes never stack.
                let message_id = if resp.id.starts_with("msg_") {
                    resp.id.clone()
                } else {
                    let bare_id = resp
                        .id
                        .strip_prefix("resp_")
                        .or_else(|| resp.id.strip_prefix("chatcmpl-"))
                        .unwrap_or(resp.id.as_str());
                    format!("msg_{}", bare_id)
                };

                items.push(OutputItem::Message {
                    id: message_id,
                    status: OutputItemStatus::Completed,
                    role: match choice.message.role {
                        Role::User => "user".to_string(),
                        Role::Assistant => "assistant".to_string(),
                        Role::System => "system".to_string(),
                        Role::Tool => "tool".to_string(),
                    },
                    content,
                });
            }

            // Add tool calls as function call items if present
            if let Some(tool_calls) = &choice.message.tool_calls {
                for tool_call in tool_calls {
                    items.push(OutputItem::FunctionCall {
                        id: format!("func_{}", tool_call.id),
                        status: OutputItemStatus::Completed,
                        call_id: tool_call.id.clone(),
                        name: Some(tool_call.function.name.clone()),
                        arguments: Some(tool_call.function.arguments.clone()),
                    });
                }
            }

            items
        } else {
            vec![]
        };

        // Convert finish_reason to status
        let status = if let Some(choice) = resp.choices.first() {
            match choice.finish_reason {
                Some(FinishReason::Stop) => ResponseStatus::Completed,
                Some(FinishReason::ToolCalls) => ResponseStatus::Completed,
                Some(FinishReason::Length) => ResponseStatus::Incomplete,
                Some(FinishReason::ContentFilter) => ResponseStatus::Failed,
                _ => ResponseStatus::Completed,
            }
        } else {
            ResponseStatus::Completed
        };

        // Convert usage
        let usage = ResponseUsage {
            input_tokens: resp.usage.prompt_tokens as i32,
            output_tokens: resp.usage.completion_tokens as i32,
            total_tokens: resp.usage.total_tokens as i32,
            input_tokens_details: resp.usage.prompt_tokens_details.map(|details| {
                crate::apis::openai_responses::TokenDetails {
                    cached_tokens: details.cached_tokens.unwrap_or(0) as i32,
                }
            }),
            output_tokens_details: resp.usage.completion_tokens_details.map(|details| {
                crate::apis::openai_responses::OutputTokenDetails {
                    reasoning_tokens: details.reasoning_tokens.unwrap_or(0) as i32,
                }
            }),
        };

        // Set incomplete_details if status is incomplete
        let incomplete_details = if matches!(status, ResponseStatus::Incomplete) {
            Some(IncompleteDetails {
                reason: IncompleteReason::MaxOutputTokens,
            })
        } else {
            None
        };

        Ok(ResponsesAPIResponse {
            // Keep an id that is already `resp_`-prefixed; otherwise generate a new one
            id: if resp.id.starts_with("resp_") {
                resp.id
            } else {
                format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", ""))
            },
            object: "response".to_string(),
            created_at: resp.created as i64,
            status,
            background: Some(false),
            error: None,
            incomplete_details,
            instructions: None,
            max_output_tokens: None,
            max_tool_calls: None,
            model: resp.model,
            output,
            usage: Some(usage),
            parallel_tool_calls: true,
            conversation: None,
            previous_response_id: None,
            tools: vec![],
            tool_choice: "auto".to_string(),
            temperature: 1.0,
            top_p: 1.0,
            metadata: resp.metadata.unwrap_or_default(),
            truncation: None,
            reasoning: Some(crate::apis::openai_responses::Reasoning {
                effort: None,
                summary: None,
            }),
            store: None,
            text: None,
            audio: None,
            modalities: None,
            service_tier: resp.service_tier,
            top_logprobs: None,
        })
    }
}
impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
type Error = TransformError;
@ -83,8 +257,7 @@ impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
model: resp.model,
choices: vec![choice],
usage,
system_fingerprint: None,
service_tier: None,
..Default::default()
})
}
}
@ -169,422 +342,11 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
model,
choices: vec![choice],
usage,
system_fingerprint: None,
service_tier: None,
..Default::default()
})
}
}
// ============================================================================
// STREAMING TRANSFORMATIONS
// ============================================================================
impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
    type Error = TransformError;

    /// Translates one Anthropic stream event into one OpenAI streaming chunk.
    ///
    /// Only `MessageStart` carries the real id and model; every other event uses the
    /// placeholders `"stream"` / `"unknown"`. `Ping` yields a chunk with no choices.
    fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
        // Delta with every field unset; arms override what they need.
        let blank_delta = || MessageDelta {
            role: None,
            content: None,
            refusal: None,
            function_call: None,
            tool_calls: None,
        };

        let chunk = match event {
            MessagesStreamEvent::MessageStart { message } => create_openai_chunk(
                &message.id,
                &message.model,
                MessageDelta {
                    role: Some(Role::Assistant),
                    ..blank_delta()
                },
                None,
                None,
            ),
            // Block start and block delta have dedicated converters that may fail.
            MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
                return convert_content_block_start(content_block);
            }
            MessagesStreamEvent::ContentBlockDelta { delta, .. } => {
                return convert_content_delta(delta);
            }
            MessagesStreamEvent::ContentBlockStop { .. } => create_empty_openai_chunk(),
            MessagesStreamEvent::MessageDelta { delta, usage } => {
                let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
                let openai_usage: Option<Usage> = Some(usage.into());
                create_openai_chunk(
                    "stream",
                    "unknown",
                    blank_delta(),
                    finish_reason,
                    openai_usage,
                )
            }
            MessagesStreamEvent::MessageStop => create_openai_chunk(
                "stream",
                "unknown",
                blank_delta(),
                Some(FinishReason::Stop),
                None,
            ),
            MessagesStreamEvent::Ping => ChatCompletionsStreamResponse {
                id: "stream".to_string(),
                object: Some("chat.completion.chunk".to_string()),
                created: current_timestamp(),
                model: "unknown".to_string(),
                choices: vec![],
                usage: None,
                system_fingerprint: None,
                service_tier: None,
            },
        };

        Ok(chunk)
    }
}
impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
type Error = TransformError;
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
match event {
ConverseStreamEvent::MessageStart(start_event) => {
let role = match start_event.role {
crate::apis::amazon_bedrock::ConversationRole::User => Role::User,
crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant,
};
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: Some(role),
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
))
}
ConverseStreamEvent::ContentBlockStart(start_event) => {
use crate::apis::amazon_bedrock::ContentBlockStart;
match start_event.start {
ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: start_event.content_block_index as u32,
id: Some(tool_use.tool_use_id),
call_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some(tool_use.name),
arguments: Some("".to_string()),
}),
}]),
},
None,
None,
)),
}
}
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
use crate::apis::amazon_bedrock::ContentBlockDelta;
match delta_event.delta {
ContentBlockDelta::Text { text } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(text),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
)),
ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: delta_event.content_block_index as u32,
id: None,
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(tool_use.input),
}),
}]),
},
None,
None,
)),
}
}
ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()),
ConverseStreamEvent::MessageStop(stop_event) => {
let finish_reason = match stop_event.stop_reason {
StopReason::EndTurn => FinishReason::Stop,
StopReason::ToolUse => FinishReason::ToolCalls,
StopReason::MaxTokens => FinishReason::Length,
StopReason::StopSequence => FinishReason::Stop,
StopReason::GuardrailIntervened => FinishReason::ContentFilter,
StopReason::ContentFiltered => FinishReason::ContentFilter,
};
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
Some(finish_reason),
None,
))
}
ConverseStreamEvent::Metadata(metadata_event) => {
let usage = Usage {
prompt_tokens: metadata_event.usage.input_tokens,
completion_tokens: metadata_event.usage.output_tokens,
total_tokens: metadata_event.usage.total_tokens,
prompt_tokens_details: None,
completion_tokens_details: None,
};
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
None,
Some(usage),
))
}
// Error events - convert to empty chunks (errors should be handled elsewhere)
ConverseStreamEvent::InternalServerException(_)
| ConverseStreamEvent::ModelStreamErrorException(_)
| ConverseStreamEvent::ServiceUnavailableException(_)
| ConverseStreamEvent::ThrottlingException(_)
| ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()),
}
}
}
/// Converts the start of an Anthropic content block into an OpenAI streaming chunk.
///
/// * Text blocks produce an empty chunk (nothing to emit yet).
/// * `ToolUse`, `ServerToolUse` and `McpToolUse` produce a chunk announcing the tool call
///   with its id and name and an empty arguments string, always at tool-call index 0.
/// * Any other block type is rejected with `TransformError::UnsupportedContent`.
fn convert_content_block_start(
    content_block: MessagesContentBlock,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
    // Resolve the tool identity first; the non-tool cases return immediately.
    let (id, name) = match content_block {
        MessagesContentBlock::Text { .. } => return Ok(create_empty_openai_chunk()),
        MessagesContentBlock::ToolUse { id, name, .. }
        | MessagesContentBlock::ServerToolUse { id, name, .. }
        | MessagesContentBlock::McpToolUse { id, name, .. } => (id, name),
        _ => {
            return Err(TransformError::UnsupportedContent(
                "Unsupported content block type in stream start".to_string(),
            ));
        }
    };

    let opening_call = ToolCallDelta {
        index: 0,
        id: Some(id),
        call_type: Some("function".to_string()),
        function: Some(FunctionCallDelta {
            name: Some(name),
            arguments: Some("".to_string()),
        }),
    };
    let delta = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: Some(vec![opening_call]),
    };

    Ok(create_openai_chunk("stream", "unknown", delta, None, None))
}
/// Converts an Anthropic content delta into an OpenAI streaming chunk.
///
/// * `TextDelta` becomes text content.
/// * `ThinkingDelta` becomes text content prefixed with `thinking: `.
/// * `InputJsonDelta` becomes a tool-call arguments fragment at tool-call index 0.
fn convert_content_delta(
    delta: MessagesContentDelta,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
    // Each variant fills either `content` or `tool_calls`, never both.
    let (content, tool_calls) = match delta {
        MessagesContentDelta::TextDelta { text } => (Some(text), None),
        MessagesContentDelta::ThinkingDelta { thinking } => {
            (Some(format!("thinking: {}", thinking)), None)
        }
        MessagesContentDelta::InputJsonDelta { partial_json } => {
            let fragment = ToolCallDelta {
                index: 0,
                id: None,
                call_type: None,
                function: Some(FunctionCallDelta {
                    name: None,
                    arguments: Some(partial_json),
                }),
            };
            (None, Some(vec![fragment]))
        }
    };

    let message_delta = MessageDelta {
        role: None,
        content,
        refusal: None,
        function_call: None,
        tool_calls,
    };

    Ok(create_openai_chunk(
        "stream",
        "unknown",
        message_delta,
        None,
        None,
    ))
}
/// Builds an OpenAI `chat.completion.chunk` holding exactly one choice (index 0).
///
/// `created` is stamped with `current_timestamp()`; `system_fingerprint`, `service_tier`
/// and the choice's `logprobs` are always left unset.
fn create_openai_chunk(
    id: &str,
    model: &str,
    delta: MessageDelta,
    finish_reason: Option<FinishReason>,
    usage: Option<Usage>,
) -> ChatCompletionsStreamResponse {
    let only_choice = StreamChoice {
        index: 0,
        delta,
        finish_reason,
        logprobs: None,
    };

    ChatCompletionsStreamResponse {
        id: id.to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: current_timestamp(),
        model: model.to_string(),
        choices: vec![only_choice],
        usage,
        system_fingerprint: None,
        service_tier: None,
    }
}
/// Builds a placeholder chunk (id `"stream"`, model `"unknown"`) whose single choice has
/// an all-`None` delta, no finish reason and no usage.
fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
    let nothing = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    create_openai_chunk("stream", "unknown", nothing, None, None)
}
/// Flattens Anthropic content blocks into a single OpenAI text content value.
///
/// Text blocks are taken verbatim and thinking blocks are prefixed with `thinking: `;
/// all other block types are ignored. The kept pieces are joined with newlines, so an
/// input without text or thinking blocks yields an empty string.
fn convert_anthropic_content_to_openai(
    content: &[MessagesContentBlock],
) -> Result<MessageContent, TransformError> {
    let joined = content
        .iter()
        .filter_map(|block| match block {
            MessagesContentBlock::Text { text, .. } => Some(text.clone()),
            MessagesContentBlock::Thinking { thinking, .. } => {
                Some(format!("thinking: {}", thinking))
            }
            // Other content types have no basic text form.
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    Ok(MessageContent::Text(joined))
}
// Stop Reason Conversions
/// Maps an Anthropic stop reason to the closest OpenAI finish reason.
///
/// Implemented as `From` rather than `Into`; the standard blanket impl still provides
/// `Into<FinishReason>` for existing callers.
impl From<MessagesStopReason> for FinishReason {
    fn from(reason: MessagesStopReason) -> FinishReason {
        match reason {
            MessagesStopReason::EndTurn => FinishReason::Stop,
            MessagesStopReason::MaxTokens => FinishReason::Length,
            MessagesStopReason::StopSequence => FinishReason::Stop,
            MessagesStopReason::ToolUse => FinishReason::ToolCalls,
            MessagesStopReason::PauseTurn => FinishReason::Stop,
            MessagesStopReason::Refusal => FinishReason::ContentFilter,
        }
    }
}
/// Convert Bedrock Message to OpenAI content and tool calls
/// This function extracts text content and tool calls from a Bedrock message
fn convert_bedrock_message_to_openai(
@ -629,6 +391,31 @@ fn convert_bedrock_message_to_openai(
Ok((content, tool_calls))
}
/// Flattens Anthropic content blocks into a single OpenAI text content value.
///
/// Text blocks are taken verbatim and thinking blocks are prefixed with `thinking: `;
/// all other block types are ignored. The kept pieces are joined with newlines, so an
/// input without text or thinking blocks yields an empty string.
fn convert_anthropic_content_to_openai(
    content: &[MessagesContentBlock],
) -> Result<MessageContent, TransformError> {
    let joined = content
        .iter()
        .filter_map(|block| match block {
            MessagesContentBlock::Text { text, .. } => Some(text.clone()),
            MessagesContentBlock::Thinking { thinking, .. } => {
                Some(format!("thinking: {}", thinking))
            }
            // Other content types have no basic text form.
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    Ok(MessageContent::Text(joined))
}
#[cfg(test)]
mod tests {
use super::*;
@ -1168,4 +955,214 @@ mod tests {
assert!(content.contains("Here's the analysis:"));
// Note: Image blocks are not converted to text in the current implementation
}
/// A plain text completion whose id is already `resp_`-prefixed converts to a single
/// assistant message, keeps its id, and carries usage and service tier across.
#[test]
fn test_chat_completions_to_responses_api_basic() {
    use crate::apis::openai_responses::{OutputContent, OutputItem, ResponsesAPIResponse};

    let chat_response = ChatCompletionsResponse {
        id: "resp_6de5512800cf4375a329a473a4f02879".to_string(),
        object: Some("chat.completion".to_string()),
        created: 1677652288,
        model: "gpt-4".to_string(),
        choices: vec![Choice {
            index: 0,
            message: crate::apis::openai::ResponseMessage {
                role: Role::Assistant,
                content: Some("Hello! How can I help you?".to_string()),
                refusal: None,
                annotations: None,
                audio: None,
                function_call: None,
                tool_calls: None,
            },
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        usage: Usage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        },
        system_fingerprint: None,
        service_tier: Some("default".to_string()),
        metadata: None,
    };

    let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();

    // The input id already carries the resp_ prefix, so it must be passed through
    // unchanged rather than regenerated.
    assert_eq!(
        responses_api.id, "resp_6de5512800cf4375a329a473a4f02879",
        "An existing resp_ id should be preserved"
    );
    assert_eq!(responses_api.object, "response");
    assert_eq!(responses_api.model, "gpt-4");
    assert_eq!(responses_api.service_tier.as_deref(), Some("default"));

    // Check usage conversion
    let usage = responses_api.usage.unwrap();
    assert_eq!(usage.input_tokens, 10);
    assert_eq!(usage.output_tokens, 20);
    assert_eq!(usage.total_tokens, 30);

    // Check output items
    assert_eq!(responses_api.output.len(), 1);
    match &responses_api.output[0] {
        OutputItem::Message {
            id, role, content, ..
        } => {
            // resp_ is swapped for msg_ instead of being stacked.
            assert_eq!(id, "msg_6de5512800cf4375a329a473a4f02879");
            assert_eq!(role, "assistant");
            assert_eq!(content.len(), 1);
            match &content[0] {
                OutputContent::OutputText { text, .. } => {
                    assert_eq!(text, "Hello! How can I help you?");
                }
                _ => panic!("Expected OutputText content"),
            }
        }
        _ => panic!("Expected Message output item"),
    }
}
/// A completion with both text and a tool call converts to a message item followed by a
/// function-call item, with a freshly generated response id and prefix-normalized item ids.
#[test]
fn test_chat_completions_to_responses_api_with_tool_calls() {
    use crate::apis::openai::{FunctionCall, ToolCall};
    use crate::apis::openai_responses::{OutputItem, ResponsesAPIResponse};

    let chat_response = ChatCompletionsResponse {
        id: "chatcmpl-456".to_string(),
        object: Some("chat.completion".to_string()),
        created: 1677652300,
        model: "gpt-4".to_string(),
        choices: vec![Choice {
            index: 0,
            message: crate::apis::openai::ResponseMessage {
                role: Role::Assistant,
                content: Some("Let me check the weather.".to_string()),
                refusal: None,
                annotations: None,
                audio: None,
                function_call: None,
                tool_calls: Some(vec![ToolCall {
                    id: "call_abc123".to_string(),
                    call_type: "function".to_string(),
                    function: FunctionCall {
                        name: "get_weather".to_string(),
                        arguments: r#"{"location":"San Francisco"}"#.to_string(),
                    },
                }]),
            },
            finish_reason: Some(FinishReason::ToolCalls),
            logprobs: None,
        }],
        usage: Usage {
            prompt_tokens: 15,
            completion_tokens: 25,
            total_tokens: 40,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        },
        system_fingerprint: None,
        service_tier: None,
        metadata: None,
    };

    let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();

    // The chatcmpl- id is not reused: a new resp_ + 32 hex char id is generated.
    assert!(
        responses_api.id.starts_with("resp_"),
        "Response ID should start with 'resp_'"
    );
    assert_eq!(
        responses_api.id.len(),
        37,
        "Response ID should be resp_ + 32 char UUID"
    );

    // Should have 2 output items: message + function call
    assert_eq!(responses_api.output.len(), 2);

    // Check message item
    match &responses_api.output[0] {
        OutputItem::Message { id, content, .. } => {
            // The chatcmpl- prefix is stripped before msg_ is added.
            assert_eq!(id, "msg_456");
            assert_eq!(content.len(), 1);
        }
        _ => panic!("Expected Message output item"),
    }

    // Check function call item
    match &responses_api.output[1] {
        OutputItem::FunctionCall {
            id,
            call_id,
            name,
            arguments,
            ..
        } => {
            assert_eq!(id, "func_call_abc123");
            assert_eq!(call_id, "call_abc123");
            assert_eq!(name.as_ref().unwrap(), "get_weather");
            assert!(arguments.as_ref().unwrap().contains("San Francisco"));
        }
        _ => panic!("Expected FunctionCall output item"),
    }
}
// Real-world shape: `content` is null and the assistant answers only with a
// tool call. The conversion must not emit an empty message item.
#[test]
fn test_chat_completions_to_responses_api_tool_calls_only() {
    use crate::apis::openai::{FunctionCall, ToolCall};
    use crate::apis::openai_responses::{OutputItem, ResponsesAPIResponse};

    // Fixture: the single tool call returned by the model.
    let weather_call = ToolCall {
        id: "call_oJBtqTJmRfBGlFS55QhMfUUV".to_string(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: "get_weather".to_string(),
            arguments: r#"{"location":"San Francisco, CA"}"#.to_string(),
        },
    };

    // Fixture: assistant message without any text content.
    let assistant_message = crate::apis::openai::ResponseMessage {
        role: Role::Assistant,
        content: None, // No text content, only tool calls
        refusal: None,
        annotations: None,
        audio: None,
        function_call: None,
        tool_calls: Some(vec![weather_call]),
    };

    let chat_response = ChatCompletionsResponse {
        id: "chatcmpl-789".to_string(),
        object: Some("chat.completion".to_string()),
        created: 1764023939,
        model: "gpt-4o-2024-08-06".to_string(),
        choices: vec![Choice {
            index: 0,
            message: assistant_message,
            finish_reason: Some(FinishReason::ToolCalls),
            logprobs: None,
        }],
        usage: Usage {
            prompt_tokens: 84,
            completion_tokens: 17,
            total_tokens: 101,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        },
        system_fingerprint: Some("fp_7eeb46f068".to_string()),
        service_tier: Some("default".to_string()),
        metadata: None,
    };

    let converted: ResponsesAPIResponse = chat_response.try_into().unwrap();

    // Only the function call should be present (no empty message item).
    assert_eq!(converted.output.len(), 1);

    if let OutputItem::FunctionCall {
        call_id,
        name,
        arguments,
        ..
    } = &converted.output[0]
    {
        assert_eq!(call_id, "call_oJBtqTJmRfBGlFS55QhMfUUV");
        assert_eq!(name.as_ref().unwrap(), "get_weather");
        assert!(arguments.as_ref().unwrap().contains("San Francisco, CA"));
    } else {
        panic!("Expected FunctionCall output item as first item");
    }

    // A tool_calls finish reason must still map to a Completed status.
    assert!(matches!(
        converted.status,
        crate::apis::openai_responses::ResponseStatus::Completed
    ));
}
}

View file

@ -0,0 +1,2 @@
pub mod to_anthropic_streaming;
pub mod to_openai_streaming;

View file

@ -0,0 +1,281 @@
use crate::apis::amazon_bedrock::{
ContentBlockDelta, ConverseStreamEvent,
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta,
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
};
use crate::apis::openai::{ ChatCompletionsStreamResponse, ToolCallDelta,
};
use crate::clients::TransformError;
use serde_json::Value;
impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
    type Error = TransformError;

    /// Maps one OpenAI chat-completions stream chunk to exactly one Anthropic
    /// stream event.
    ///
    /// Only the first choice is inspected. Checks run in this order and the
    /// first match wins:
    /// 1. no choices                      -> `Ping`
    /// 2. usage *and* finish reason       -> `MessageDelta` with real usage
    /// 3. non-empty text content          -> `ContentBlockDelta` (index 0)
    /// 4. tool call deltas                -> result of `convert_tool_call_deltas`
    /// 5. finish reason without usage     -> `MessageDelta` with zeroed usage
    /// 6. anything else                   -> `Ping`
    ///
    /// `MessageStart` is intentionally never produced here: a chunk may carry
    /// both role and content but only one event can be returned, so content is
    /// prioritised and the AnthropicMessagesStreamBuffer is expected to inject
    /// `message_start` / `content_block_start` itself.
    fn try_from(resp: ChatCompletionsStreamResponse) -> Result<Self, Self::Error> {
        let first_choice = match resp.choices.first() {
            Some(first_choice) => first_choice,
            None => return Ok(MessagesStreamEvent::Ping),
        };

        // Final chunk: finish reason and usage arrive together.
        if let (Some(usage), Some(finish_reason)) =
            (resp.usage, first_choice.finish_reason.as_ref())
        {
            return Ok(MessagesStreamEvent::MessageDelta {
                delta: MessagesMessageDelta {
                    stop_reason: finish_reason.clone().into(),
                    stop_sequence: None,
                },
                usage: usage.into(),
            });
        }

        // Text content wins even when a role is present in the same chunk.
        match first_choice.delta.content.as_deref() {
            Some(text) if !text.is_empty() => {
                return Ok(MessagesStreamEvent::ContentBlockDelta {
                    index: 0,
                    delta: MessagesContentDelta::TextDelta {
                        text: text.to_string(),
                    },
                });
            }
            _ => {}
        }

        // Tool call fragments (start or argument delta).
        if let Some(tool_calls) = first_choice.delta.tool_calls.as_ref() {
            return convert_tool_call_deltas(tool_calls.clone());
        }

        // A finish reason that reaches this point cannot have usage attached
        // (that combination returned above), so report zeroed usage.
        // MessageStop is produced later, when [DONE] is encountered.
        if let Some(finish_reason) = first_choice.finish_reason.as_ref() {
            return Ok(MessagesStreamEvent::MessageDelta {
                delta: MessagesMessageDelta {
                    stop_reason: finish_reason.clone().into(),
                    stop_sequence: None,
                },
                usage: MessagesUsage {
                    input_tokens: 0,
                    output_tokens: 0,
                    cache_creation_input_tokens: None,
                    cache_read_input_tokens: None,
                },
            });
        }

        // Nothing convertible in this chunk.
        Ok(MessagesStreamEvent::Ping)
    }
}
/// Renders a `MessagesStreamEvent` as one Server-Sent Events frame:
///
/// ```text
/// event: <event type>\n
/// data: <event as JSON>\n
/// \n
/// ```
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl still provides `Into<String> for MessagesStreamEvent`, so existing
/// `.into()` call sites keep working, and `String::from(event)` becomes
/// available as well.
impl From<MessagesStreamEvent> for String {
    fn from(event: MessagesStreamEvent) -> String {
        let event_type = match &event {
            MessagesStreamEvent::MessageStart { .. } => "message_start",
            MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
            MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
            MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
            MessagesStreamEvent::MessageDelta { .. } => "message_delta",
            MessagesStreamEvent::MessageStop => "message_stop",
            MessagesStreamEvent::Ping => "ping",
        };

        // Best effort: if serialization fails the data line is emitted empty
        // instead of failing, because this conversion is infallible.
        let transformed_json = serde_json::to_string(&event).unwrap_or_default();

        format!("event: {}\ndata: {}\n\n", event_type, transformed_json)
    }
}
impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
    type Error = TransformError;

    /// Converts one Bedrock ConverseStream event into the corresponding
    /// Anthropic Messages stream event.
    ///
    /// The conversion looks only at the event it is given. Every arm returns
    /// `Ok`; exception events are currently downgraded to `Ping`.
    fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
        match event {
            // MessageStart - convert to Anthropic MessageStart
            ConverseStreamEvent::MessageStart(start_event) => {
                let role = match start_event.role {
                    crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
                    crate::apis::amazon_bedrock::ConversationRole::Assistant => {
                        MessagesRole::Assistant
                    }
                };

                Ok(MessagesStreamEvent::MessageStart {
                    message: MessagesStreamMessage {
                        // The id is synthesized from the current wall-clock
                        // time in nanoseconds; the model name is a placeholder.
                        id: format!(
                            "bedrock-stream-{}",
                            std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_nanos()
                        ),
                        obj_type: "message".to_string(),
                        role,
                        content: vec![],
                        model: "bedrock-model".to_string(),
                        stop_reason: None,
                        stop_sequence: None,
                        usage: MessagesUsage {
                            input_tokens: 0,
                            output_tokens: 0,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        },
                    },
                })
            }

            // ContentBlockStart - convert to Anthropic ContentBlockStart
            ConverseStreamEvent::ContentBlockStart(start_event) => {
                // Note: Bedrock sends tool_use_id and name at start, with input coming in
                // subsequent deltas. Anthropic expects the same pattern, so we initialize
                // with an empty input object.
                match start_event.start {
                    crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => {
                        Ok(MessagesStreamEvent::ContentBlockStart {
                            index: start_event.content_block_index as u32,
                            content_block: MessagesContentBlock::ToolUse {
                                id: tool_use.tool_use_id,
                                name: tool_use.name,
                                input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas
                                cache_control: None,
                            },
                        })
                    }
                }
            }

            // ContentBlockDelta - convert to Anthropic ContentBlockDelta
            ConverseStreamEvent::ContentBlockDelta(delta_event) => {
                let delta = match delta_event.delta {
                    ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text },
                    ContentBlockDelta::ToolUse { tool_use } => {
                        MessagesContentDelta::InputJsonDelta {
                            partial_json: tool_use.input,
                        }
                    }
                };

                Ok(MessagesStreamEvent::ContentBlockDelta {
                    index: delta_event.content_block_index as u32,
                    delta,
                })
            }

            // ContentBlockStop - convert to Anthropic ContentBlockStop
            ConverseStreamEvent::ContentBlockStop(stop_event) => {
                Ok(MessagesStreamEvent::ContentBlockStop {
                    index: stop_event.content_block_index as u32,
                })
            }

            // MessageStop - convert to Anthropic MessageDelta with stop reason
            // Note: Bedrock sends Metadata separately with usage info, creating a second
            // MessageDelta. The client should merge these or use the final one with
            // complete usage.
            ConverseStreamEvent::MessageStop(stop_event) => {
                let anthropic_stop_reason = match stop_event.stop_reason {
                    crate::apis::amazon_bedrock::StopReason::EndTurn => MessagesStopReason::EndTurn,
                    crate::apis::amazon_bedrock::StopReason::ToolUse => MessagesStopReason::ToolUse,
                    crate::apis::amazon_bedrock::StopReason::MaxTokens => {
                        MessagesStopReason::MaxTokens
                    }
                    // Anthropic has a dedicated stop_sequence reason; report it
                    // instead of collapsing it into EndTurn. The matched
                    // sequence itself is not taken from this event, so
                    // `stop_sequence` below stays None.
                    crate::apis::amazon_bedrock::StopReason::StopSequence => {
                        MessagesStopReason::StopSequence
                    }
                    crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => {
                        MessagesStopReason::Refusal
                    }
                    crate::apis::amazon_bedrock::StopReason::ContentFiltered => {
                        MessagesStopReason::Refusal
                    }
                };

                Ok(MessagesStreamEvent::MessageDelta {
                    delta: MessagesMessageDelta {
                        stop_reason: anthropic_stop_reason,
                        stop_sequence: None,
                    },
                    // Usage is not taken from this event; it is reported by the
                    // Metadata arm below.
                    usage: MessagesUsage {
                        input_tokens: 0,
                        output_tokens: 0,
                        cache_creation_input_tokens: None,
                        cache_read_input_tokens: None,
                    },
                })
            }

            // Metadata - convert usage information to MessageDelta
            // NOTE(review): stop_reason is hard-coded to EndTurn here because this
            // arm only reads usage. A consumer that keeps only the last
            // MessageDelta would see EndTurn even after a ToolUse/MaxTokens stop -
            // confirm the downstream buffer merges the two deltas.
            ConverseStreamEvent::Metadata(metadata_event) => {
                Ok(MessagesStreamEvent::MessageDelta {
                    delta: MessagesMessageDelta {
                        stop_reason: MessagesStopReason::EndTurn,
                        stop_sequence: None,
                    },
                    usage: MessagesUsage {
                        input_tokens: metadata_event.usage.input_tokens,
                        output_tokens: metadata_event.usage.output_tokens,
                        cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens,
                        cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens,
                    },
                })
            }

            // Exception events - convert to Ping (could be enhanced to return error events)
            ConverseStreamEvent::InternalServerException(_)
            | ConverseStreamEvent::ModelStreamErrorException(_)
            | ConverseStreamEvent::ServiceUnavailableException(_)
            | ConverseStreamEvent::ThrottlingException(_)
            | ConverseStreamEvent::ValidationException(_) => {
                // TODO: Consider adding proper error handling/events
                Ok(MessagesStreamEvent::Ping)
            }
        }
    }
}
/// Convert tool call deltas to Anthropic stream events.
///
/// Only one event can be returned per chunk, so the deltas are scanned in
/// order and the first one that yields an event wins:
///
/// - `id` and function `name` both present -> `ContentBlockStart` with a
///   `ToolUse` block whose input is an empty object (arguments follow as
///   deltas);
/// - otherwise, function `arguments` present -> `ContentBlockDelta` carrying
///   an `InputJsonDelta`. This also covers deltas that repeat the call `id`
///   on argument chunks, which would otherwise be dropped.
///
/// Deltas without a `function` payload are skipped. If no delta produces an
/// event, `Ping` is returned.
fn convert_tool_call_deltas(
    tool_calls: Vec<ToolCallDelta>,
) -> Result<MessagesStreamEvent, TransformError> {
    for tool_call in tool_calls {
        let index = tool_call.index;

        let function = match tool_call.function {
            Some(function) => function,
            None => continue,
        };

        // Tool call start: both the call id and the function name are known.
        if let (Some(id), Some(name)) = (tool_call.id, function.name) {
            return Ok(MessagesStreamEvent::ContentBlockStart {
                index,
                content_block: MessagesContentBlock::ToolUse {
                    id,
                    name,
                    input: Value::Object(serde_json::Map::new()),
                    cache_control: None,
                },
            });
        }

        // Tool arguments delta (with or without a repeated id).
        if let Some(arguments) = function.arguments {
            return Ok(MessagesStreamEvent::ContentBlockDelta {
                index,
                delta: MessagesContentDelta::InputJsonDelta {
                    partial_json: arguments,
                },
            });
        }
    }

    // Fallback to ping if no valid tool call found
    Ok(MessagesStreamEvent::Ping)
}

View file

@ -0,0 +1,546 @@
use crate::apis::amazon_bedrock::{ ConverseStreamEvent, StopReason};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent};
use crate::apis::openai::{ ChatCompletionsStreamResponse,FinishReason,
FunctionCallDelta, MessageDelta, Role, StreamChoice, ToolCallDelta, Usage,
};
use crate::apis::openai_responses::ResponsesAPIStreamEvent;
use crate::clients::TransformError;
use crate::transforms::lib::*;
// ============================================================================
// PROVIDER STREAMING TRANSFORMATIONS TO OPENAI FORMAT
// ============================================================================
//
// This module handles business logic for converting streaming events from
// various providers (Anthropic, Bedrock, etc.) into OpenAI's ChatCompletions format.
//
// # Architecture Separation
//
// **Provider Transformations** (this module):
// - Business logic for converting between provider formats
// - Uses Rust traits (TryFrom, Into) for type-safe conversions
// - Stateless event-by-event transformation
// - Example: MessagesStreamEvent → ChatCompletionsStreamResponse
//
// **Wire Format Buffering** (`apis/streaming_shapes/`):
// - SSE protocol handling (data:, event: lines)
// - State accumulation and lifecycle management
// - Buffering for stateful APIs (v1/responses)
// - Example: ChatCompletionsToResponsesTransformer
//
// # Flow
//
// ```text
// Anthropic Event → [Provider Transform] → OpenAI Event → [Wire Buffer]      → SSE Wire Format
//   (business)        (this module)         (protocol)    (streaming_shapes)    (network)
// ```
//
// ============================================================================
impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
    type Error = TransformError;

    /// Converts one Anthropic stream event into one OpenAI chat-completions
    /// chunk.
    ///
    /// Only `MessageStart` has a real id/model to forward; every other chunk
    /// is built with the placeholder id `"stream"` and model `"unknown"`.
    /// `Ping` becomes a chunk with no choices at all.
    fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
        // A delta with every field unset; arms override what they need.
        let blank_delta = || MessageDelta {
            role: None,
            content: None,
            refusal: None,
            function_call: None,
            tool_calls: None,
        };

        match event {
            MessagesStreamEvent::MessageStart { message } => {
                let role_delta = MessageDelta {
                    role: Some(Role::Assistant),
                    ..blank_delta()
                };
                Ok(create_openai_chunk(
                    &message.id,
                    &message.model,
                    role_delta,
                    None,
                    None,
                ))
            }

            MessagesStreamEvent::ContentBlockStart {
                content_block,
                index,
            } => convert_content_block_start(content_block, index),

            MessagesStreamEvent::ContentBlockDelta { delta, index } => {
                convert_content_delta(delta, index)
            }

            MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),

            // Stop reason and usage travel together on the Anthropic side.
            MessagesStreamEvent::MessageDelta { delta, usage } => {
                let finish_reason: FinishReason = delta.stop_reason.into();
                let openai_usage: Usage = usage.into();
                Ok(create_openai_chunk(
                    "stream",
                    "unknown",
                    blank_delta(),
                    Some(finish_reason),
                    Some(openai_usage),
                ))
            }

            // NOTE(review): this always reports FinishReason::Stop, regardless
            // of what an earlier MessageDelta reported - confirm consumers
            // ignore or de-duplicate the second finish reason.
            MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
                "stream",
                "unknown",
                blank_delta(),
                Some(FinishReason::Stop),
                None,
            )),

            // Keep-alive: a chunk with an empty `choices` list.
            MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
                id: "stream".to_string(),
                object: Some("chat.completion.chunk".to_string()),
                created: current_timestamp(),
                model: "unknown".to_string(),
                choices: vec![],
                usage: None,
                system_fingerprint: None,
                service_tier: None,
            }),
        }
    }
}
impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
type Error = TransformError;
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
match event {
ConverseStreamEvent::MessageStart(start_event) => {
let role = match start_event.role {
crate::apis::amazon_bedrock::ConversationRole::User => Role::User,
crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant,
};
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: Some(role),
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
))
}
ConverseStreamEvent::ContentBlockStart(start_event) => {
use crate::apis::amazon_bedrock::ContentBlockStart;
match start_event.start {
ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: start_event.content_block_index as u32,
id: Some(tool_use.tool_use_id),
call_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some(tool_use.name),
arguments: Some("".to_string()),
}),
}]),
},
None,
None,
)),
}
}
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
use crate::apis::amazon_bedrock::ContentBlockDelta;
match delta_event.delta {
ContentBlockDelta::Text { text } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: Some(text),
refusal: None,
function_call: None,
tool_calls: None,
},
None,
None,
)),
ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: Some(vec![ToolCallDelta {
index: delta_event.content_block_index as u32,
id: None,
call_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: Some(tool_use.input),
}),
}]),
},
None,
None,
)),
}
}
ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()),
ConverseStreamEvent::MessageStop(stop_event) => {
let finish_reason = match stop_event.stop_reason {
StopReason::EndTurn => FinishReason::Stop,
StopReason::ToolUse => FinishReason::ToolCalls,
StopReason::MaxTokens => FinishReason::Length,
StopReason::StopSequence => FinishReason::Stop,
StopReason::GuardrailIntervened => FinishReason::ContentFilter,
StopReason::ContentFiltered => FinishReason::ContentFilter,
};
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
Some(finish_reason),
None,
))
}
ConverseStreamEvent::Metadata(metadata_event) => {
let usage = Usage {
prompt_tokens: metadata_event.usage.input_tokens,
completion_tokens: metadata_event.usage.output_tokens,
total_tokens: metadata_event.usage.total_tokens,
prompt_tokens_details: None,
completion_tokens_details: None,
};
Ok(create_openai_chunk(
"stream",
"unknown",
MessageDelta {
role: None,
content: None,
refusal: None,
function_call: None,
tool_calls: None,
},
None,
Some(usage),
))
}
// Error events - convert to empty chunks (errors should be handled elsewhere)
ConverseStreamEvent::InternalServerException(_)
| ConverseStreamEvent::ModelStreamErrorException(_)
| ConverseStreamEvent::ServiceUnavailableException(_)
| ConverseStreamEvent::ThrottlingException(_)
| ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()),
}
}
}
/// Convert an Anthropic content block start into an OpenAI chunk.
///
/// - Text blocks produce an empty chunk (nothing to emit until text arrives).
/// - Tool-use blocks (`ToolUse`, `ServerToolUse`, `McpToolUse`) produce a
///   chunk whose `tool_calls` announces the call id and function name with
///   empty arguments, at the given `index`.
/// - Any other block type is rejected with
///   `TransformError::UnsupportedContent`.
fn convert_content_block_start(
    content_block: MessagesContentBlock,
    index: u32,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
    // Resolve the non-tool cases first; only tool blocks fall through.
    let (tool_id, tool_name) = match content_block {
        MessagesContentBlock::Text { .. } => {
            // No immediate output for text block start
            return Ok(create_empty_openai_chunk());
        }
        MessagesContentBlock::ToolUse { id, name, .. }
        | MessagesContentBlock::ServerToolUse { id, name, .. }
        | MessagesContentBlock::McpToolUse { id, name, .. } => (id, name),
        _ => {
            return Err(TransformError::UnsupportedContent(
                "Unsupported content block type in stream start".to_string(),
            ));
        }
    };

    // Tool use start → OpenAI chunk with tool_calls
    let call_start = ToolCallDelta {
        index,
        id: Some(tool_id),
        call_type: Some("function".to_string()),
        function: Some(FunctionCallDelta {
            name: Some(tool_name),
            arguments: Some("".to_string()),
        }),
    };

    Ok(create_openai_chunk(
        "stream",
        "unknown",
        MessageDelta {
            role: None,
            content: None,
            refusal: None,
            function_call: None,
            tool_calls: Some(vec![call_start]),
        },
        None,
        None,
    ))
}
/// Convert an Anthropic content delta into an OpenAI chunk.
///
/// Each delta kind sets at most one field of the outgoing `MessageDelta`:
/// - `TextDelta`      -> `content` is the text as-is;
/// - `ThinkingDelta`  -> `content` is the text prefixed with `"thinking: "`;
/// - `InputJsonDelta` -> `tool_calls` holds one argument fragment at `index`;
/// - `SignatureDelta` -> nothing is set (the signature is verification
///   metadata, not content); an empty chunk keeps the stream continuous.
fn convert_content_delta(
    delta: MessagesContentDelta,
    index: u32,
) -> Result<ChatCompletionsStreamResponse, TransformError> {
    // Work out the only two fields that can vary, then build the chunk once.
    let (content, tool_calls) = match delta {
        MessagesContentDelta::TextDelta { text } => (Some(text), None),
        MessagesContentDelta::ThinkingDelta { thinking } => {
            (Some(format!("thinking: {}", thinking)), None)
        }
        MessagesContentDelta::InputJsonDelta { partial_json } => {
            let argument_fragment = ToolCallDelta {
                index,
                id: None,
                call_type: None,
                function: Some(FunctionCallDelta {
                    name: None,
                    arguments: Some(partial_json),
                }),
            };
            (None, Some(vec![argument_fragment]))
        }
        MessagesContentDelta::SignatureDelta { signature: _ } => (None, None),
    };

    Ok(create_openai_chunk(
        "stream",
        "unknown",
        MessageDelta {
            role: None,
            content,
            refusal: None,
            function_call: None,
            tool_calls,
        },
        None,
        None,
    ))
}
/// Builds a `chat.completion.chunk` response holding exactly one choice.
///
/// The choice always has index 0 and no logprobs; `created` is taken from
/// `current_timestamp()` at call time. `system_fingerprint` and
/// `service_tier` are left unset.
fn create_openai_chunk(
    id: &str,
    model: &str,
    delta: MessageDelta,
    finish_reason: Option<FinishReason>,
    usage: Option<Usage>,
) -> ChatCompletionsStreamResponse {
    let only_choice = StreamChoice {
        index: 0,
        delta,
        finish_reason,
        logprobs: None,
    };

    ChatCompletionsStreamResponse {
        id: id.to_owned(),
        object: Some("chat.completion.chunk".to_owned()),
        created: current_timestamp(),
        model: model.to_owned(),
        choices: vec![only_choice],
        usage,
        system_fingerprint: None,
        service_tier: None,
    }
}
/// Builds a placeholder chunk (id `"stream"`, model `"unknown"`) whose single
/// choice has an all-`None` delta, no finish reason and no usage.
fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
    let nothing_to_report = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: None,
    };

    create_openai_chunk("stream", "unknown", nothing_to_report, None, None)
}
// Stop Reason Conversions

/// Maps an Anthropic stop reason onto the closest OpenAI finish reason.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl still provides `Into<FinishReason> for MessagesStopReason`, so
/// existing `.into()` call sites keep working.
///
/// The mapping is lossy: `EndTurn`, `StopSequence` and `PauseTurn` all become
/// `Stop`.
impl From<MessagesStopReason> for FinishReason {
    fn from(stop_reason: MessagesStopReason) -> FinishReason {
        match stop_reason {
            MessagesStopReason::EndTurn
            | MessagesStopReason::StopSequence
            | MessagesStopReason::PauseTurn => FinishReason::Stop,
            MessagesStopReason::MaxTokens => FinishReason::Length,
            MessagesStopReason::ToolUse => FinishReason::ToolCalls,
            MessagesStopReason::Refusal => FinishReason::ContentFilter,
        }
    }
}
impl TryFrom<ChatCompletionsStreamResponse> for ResponsesAPIStreamEvent {
    type Error = TransformError;

    /// Stateless conversion of one chat-completions chunk into one Responses
    /// API stream event. Item ids and sequence numbers are left as
    /// placeholders (`""` / `0`) for the buffer to fill in.
    ///
    /// Only the first choice, and within it only the first tool call, is
    /// inspected. Checks run in this order and the first match wins:
    /// 1. tool call with arguments          -> arguments delta with those arguments
    /// 2. tool call with a name only        -> arguments delta with an empty string
    ///    (signals the initial tool call event)
    /// 3. non-empty text content            -> output text delta
    /// 4. finish reason or role present     -> empty output text delta
    ///    (completion signal / stream start)
    ///
    /// Anything else (no choices, keep-alive chunks with `delta: {}`) yields
    /// `TransformError::UnsupportedConversion` so the caller can skip it.
    fn try_from(chunk: ChatCompletionsStreamResponse) -> Result<Self, TransformError> {
        if let Some(choice) = chunk.choices.first() {
            let output_index = choice.index as i32;

            // --- Tool call (function name and/or arguments) ---------------
            let first_tool_call = choice
                .delta
                .tool_calls
                .as_ref()
                .and_then(|calls| calls.first());

            if let Some(tool_call) = first_tool_call {
                if let Some(function) = tool_call.function.as_ref() {
                    // Arguments take priority; a bare name produces an empty
                    // delta so the buffer knows to create the item.
                    let arguments_delta = match (&function.arguments, &function.name) {
                        (Some(args), _) => Some(args.clone()),
                        (None, Some(_)) => Some(String::new()),
                        (None, None) => None,
                    };

                    if let Some(delta) = arguments_delta {
                        return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
                            output_index,
                            item_id: String::new(), // Buffer will fill this
                            delta,
                            sequence_number: 0, // Buffer will fill this
                            call_id: tool_call.id.clone(),
                            name: function.name.clone(),
                        });
                    }
                }
            }

            // Every text event differs only in its `delta` payload.
            let text_event = |delta: String| ResponsesAPIStreamEvent::ResponseOutputTextDelta {
                item_id: String::new(), // Buffer will fill this
                output_index,
                content_index: 0,
                delta,
                logprobs: vec![],
                obfuscation: None,
                sequence_number: 0, // Buffer will fill this
            };

            // --- Text content delta ---------------------------------------
            match choice.delta.content.as_ref() {
                Some(content) if !content.is_empty() => {
                    return Ok(text_event(content.clone()));
                }
                _ => {}
            }

            // --- Completion signal or role-only first chunk ---------------
            // Both are reported as an empty text delta; the buffer uses it to
            // detect completion or to initialize its state.
            if choice.finish_reason.is_some() || choice.delta.role.is_some() {
                return Ok(text_event(String::new()));
            }
        }

        // Empty chunk or no convertible content (e.g., keep-alive chunks with delta: {})
        // These are valid in OpenAI streaming; an error is returned so the
        // caller can skip these chunks without warnings.
        Err(TransformError::UnsupportedConversion(
            "Empty or keep-alive chunk with no convertible content".to_string(),
        ))
    }
}

View file

@ -2,26 +2,18 @@ use crate::metrics::Metrics;
use crate::stream_context::StreamContext;
use common::configuration::Configuration;
use common::configuration::Overrides;
use common::consts::OTEL_COLLECTOR_HTTP;
use common::consts::OTEL_POST_PATH;
use common::http::CallArgs;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::ratelimit;
use common::stats::Gauge;
use common::tracing::TraceData;
use log::trace;
use log::warn;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::rc::Rc;
use std::time::Duration;
use std::sync::{Arc, Mutex};
/// Per-callout context value stored in `FilterContext::callouts`, keyed by the
/// callout's token id. It currently carries no data.
#[derive(Debug)]
pub struct CallContext {}
@ -31,7 +23,6 @@ pub struct FilterContext {
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, CallContext>>,
llm_providers: Option<Rc<LlmProviders>>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
}
@ -41,7 +32,6 @@ impl FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(Metrics::new()),
llm_providers: None,
traces_queue: Arc::new(Mutex::new(VecDeque::new())),
overrides: Rc::new(None),
}
}
@ -95,7 +85,6 @@ impl RootContext for FilterContext {
.as_ref()
.expect("LLM Providers must exist when Streams are being created"),
),
Arc::clone(&self.traces_queue),
Rc::clone(&self.overrides),
)))
}
@ -108,34 +97,6 @@ impl RootContext for FilterContext {
self.set_tick_period(Duration::from_secs(1));
true
}
fn on_tick(&mut self) {
let _ = self.traces_queue.try_lock().map(|mut traces_queue| {
while let Some(trace) = traces_queue.pop_front() {
let trace_str = serde_json::to_string(&trace).unwrap();
trace!("trace details: {}", trace_str);
let call_args = CallArgs::new(
OTEL_COLLECTOR_HTTP,
OTEL_POST_PATH,
vec![
(":method", http::Method::POST.as_str()),
(":path", OTEL_POST_PATH),
(":authority", OTEL_COLLECTOR_HTTP),
("content-type", "application/json"),
],
Some(trace_str.as_bytes()),
vec![],
Duration::from_secs(60),
);
if let Err(error) = self.http_call(call_args, CallContext {}) {
warn!(
"failed to schedule http call to otel-collector: {:?}",
error
);
}
}
});
}
}
impl Context for FilterContext {

View file

@ -4,10 +4,8 @@ use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::collections::VecDeque;
use std::num::NonZero;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::metrics::Metrics;
@ -20,13 +18,13 @@ use common::errors::ServerError;
use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent};
use hermesllm::apis::sse::{SseEvent, SseStreamIter};
use hermesllm::clients::endpoints::SupportedAPIs;
use hermesllm::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
use hermesllm::apis::streaming_shapes::sse::{SseEvent, SseStreamBuffer, SseStreamBufferTrait};
use hermesllm::apis::streaming_shapes::sse_chunk_processor::SseChunkProcessor;
use hermesllm::clients::endpoints::SupportedAPIsFromClient;
use hermesllm::providers::response::ProviderResponse;
use hermesllm::providers::streaming_response::ProviderStreamResponse;
use hermesllm::{
DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType,
ProviderStreamResponseType,
@ -38,7 +36,7 @@ pub struct StreamContext {
streaming_response: bool,
response_tokens: usize,
/// The API that is requested by the client (before compatibility mapping)
client_api: Option<SupportedAPIs>,
client_api: Option<SupportedAPIsFromClient>,
/// The API that should be used for the upstream provider (after compatibility mapping)
resolved_api: Option<SupportedUpstreamAPIs>,
llm_providers: Rc<LlmProviders>,
@ -49,20 +47,20 @@ pub struct StreamContext {
ttft_time: Option<u128>,
traceparent: Option<String>,
request_body_sent_time: Option<u128>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
user_message: Option<String>,
upstream_status_code: Option<StatusCode>,
binary_frame_decoder: Option<BedrockBinaryFrameDecoder<bytes::BytesMut>>,
http_method: Option<String>,
http_protocol: Option<String>,
sse_buffer: Option<SseStreamBuffer>,
sse_chunk_processor: Option<SseChunkProcessor>,
}
impl StreamContext {
pub fn new(
metrics: Rc<Metrics>,
llm_providers: Rc<LlmProviders>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
) -> Self {
StreamContext {
@ -80,13 +78,14 @@ impl StreamContext {
ttft_duration: None,
traceparent: None,
ttft_time: None,
traces_queue,
request_body_sent_time: None,
user_message: None,
upstream_status_code: None,
binary_frame_decoder: None,
http_method: None,
http_protocol: None,
sse_buffer: None,
sse_chunk_processor: None,
}
}
@ -140,7 +139,7 @@ impl StreamContext {
));
info!(
"[ARCHGW_REQ_ID:{}] PROVIDER_SELECTION: Hint='{}' -> Selected='{}'",
"[PLANO_REQ_ID:{}] PROVIDER_SELECTION: Hint='{}' -> Selected='{}'",
self.request_identifier(),
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.unwrap_or("none".to_string()),
@ -172,7 +171,8 @@ impl StreamContext {
Some(
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
| SupportedUpstreamAPIs::AmazonBedrockConverse(_)
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
)
| None => {
// OpenAI and default: use Authorization Bearer token
@ -224,7 +224,7 @@ impl StreamContext {
let token_count = tokenizer::token_count(model, json_string).unwrap_or(0);
debug!(
"[ARCHGW_REQ_ID:{}] TOKEN_COUNT: model='{}' input_tokens={}",
"[PLANO_REQ_ID:{}] TOKEN_COUNT: model='{}' input_tokens={}",
self.request_identifier(),
model,
token_count
@ -238,7 +238,7 @@ impl StreamContext {
// Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() {
info!(
"[ARCHGW_REQ_ID:{}] RATELIMIT_CHECK: model='{}' selector='{}:{}'",
"[PLANO_REQ_ID:{}] RATELIMIT_CHECK: model='{}' selector='{}:{}'",
self.request_identifier(),
model,
selector.key,
@ -251,7 +251,7 @@ impl StreamContext {
)?;
} else {
debug!(
"[ARCHGW_REQ_ID:{}] RATELIMIT_SKIP: model='{}' (no selector)",
"[PLANO_REQ_ID:{}] RATELIMIT_SKIP: model='{}' (no selector)",
self.request_identifier(),
model
);
@ -270,7 +270,7 @@ impl StreamContext {
Ok(duration) => {
let duration_ms = duration.as_millis();
info!(
"[ARCHGW_REQ_ID:{}] TIME_TO_FIRST_TOKEN: {}ms",
"[PLANO_REQ_ID:{}] TIME_TO_FIRST_TOKEN: {}ms",
self.request_identifier(),
duration_ms
);
@ -279,7 +279,7 @@ impl StreamContext {
}
Err(e) => {
warn!(
"[ARCHGW_REQ_ID:{}] TIME_MEASUREMENT_ERROR: {:?}",
"[PLANO_REQ_ID:{}] TIME_MEASUREMENT_ERROR: {:?}",
self.request_identifier(),
e
);
@ -295,7 +295,7 @@ impl StreamContext {
// Convert the duration to milliseconds
let duration_ms = duration.as_millis();
info!(
"[ARCHGW_REQ_ID:{}] REQUEST_COMPLETE: latency={}ms tokens={}",
"[PLANO_REQ_ID:{}] REQUEST_COMPLETE: latency={}ms tokens={}",
self.request_identifier(),
duration_ms,
self.response_tokens
@ -311,7 +311,7 @@ impl StreamContext {
self.metrics.time_per_output_token.record(tpot);
info!(
"[ARCHGW_REQ_ID:{}] TOKEN_THROUGHPUT: time_per_token={}ms tokens_per_second={}",
"[PLANO_REQ_ID:{}] TOKEN_THROUGHPUT: time_per_token={}ms tokens_per_second={}",
self.request_identifier(),
tpot,
1000 / tpot
@ -328,75 +328,13 @@ impl StreamContext {
self.metrics
.output_sequence_length
.record(self.response_tokens as u64);
// Emit an LLM trace span for this request when a `traceparent` value was captured.
// The span is built from the provider name, request timing and HTTP attributes stored on
// `self`, and is pushed onto `self.traces_queue`.
if let Some(traceparent) = self.traceparent.as_ref() {
    // Take the span end timestamp before parsing so parsing time is not included.
    let span_end_ns = current_time_ns();
    match Traceparent::try_from(traceparent.to_string()) {
        Err(e) => {
            warn!("traceparent header is invalid: {}", e);
        }
        Ok(traceparent) => match self.request_body_sent_time {
            None => {
                // Without a start timestamp no meaningful span can be built. The previous
                // code called `unwrap()` here and would panic instead of skipping.
                warn!("request body sent time is not set, skipping trace span");
            }
            Some(span_start) => {
                // Service name is "archgw" optionally suffixed with the resolved upstream API.
                let service_name = match &self.resolved_api {
                    Some(api) => format!("archgw.{}", api),
                    None => "archgw".to_string(),
                };
                let mut trace_data =
                    common::tracing::TraceData::new_with_service_name(service_name);

                // The provider name is used both as the span name and as the "model" attribute.
                let provider_name = self.llm_provider().name.to_string();
                let mut llm_span = Span::new(
                    provider_name.clone(),
                    Some(traceparent.trace_id),
                    Some(traceparent.parent_id),
                    span_start,
                    span_end_ns,
                );
                llm_span.add_attribute("model".to_string(), provider_name);

                if let Some(user_message) = &self.user_message {
                    llm_span.add_attribute("message".to_string(), user_message.clone());
                }

                // Add HTTP attributes
                if let Some(method) = &self.http_method {
                    llm_span.add_attribute("http.method".to_string(), method.clone());
                }
                if let Some(protocol) = &self.http_protocol {
                    llm_span.add_attribute("http.protocol".to_string(), protocol.clone());
                }
                if let Some(status_code) = &self.upstream_status_code {
                    llm_span.add_attribute(
                        "http.status_code".to_string(),
                        status_code.as_u16().to_string(),
                    );
                }

                // Add request ID attribute
                llm_span.add_attribute("http.request_id".to_string(), self.request_identifier());

                // Record time-to-first-token as a span event when it was measured.
                if let Some(ttft_time) = self.ttft_time {
                    llm_span.add_event(Event::new(
                        "time_to_first_token".to_string(),
                        ttft_time,
                    ));
                }

                trace_data.add_span(llm_span);
                self.traces_queue.lock().unwrap().push_back(trace_data);
            }
        },
    };
}
}
fn read_raw_response_body(&mut self, body_size: usize) -> Result<Vec<u8>, Action> {
if self.streaming_response {
let chunk_size = body_size;
debug!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_CHUNK: streaming=true chunk_size={}",
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_CHUNK: streaming=true chunk_size={}",
self.request_identifier(),
chunk_size
);
@ -404,7 +342,7 @@ impl StreamContext {
Some(chunk) => chunk,
None => {
warn!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_ERROR: empty chunk, size={}",
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_ERROR: empty chunk, size={}",
self.request_identifier(),
chunk_size
);
@ -414,7 +352,7 @@ impl StreamContext {
if streaming_chunk.len() != chunk_size {
warn!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_MISMATCH: expected={} actual={}",
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_MISMATCH: expected={} actual={}",
self.request_identifier(),
chunk_size,
streaming_chunk.len()
@ -426,7 +364,7 @@ impl StreamContext {
return Err(Action::Continue);
}
debug!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_COMPLETE: streaming=false body_size={}",
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_COMPLETE: streaming=false body_size={}",
self.request_identifier(),
body_size
);
@ -446,7 +384,7 @@ impl StreamContext {
provider_id: ProviderId,
) -> Result<Vec<u8>, Action> {
debug!(
"[ARCHGW_REQ_ID:{}] STREAMING_PROCESS: client={:?} provider_id={:?} chunk_size={}",
"[PLANO_REQ_ID:{}] STREAMING_PROCESS: client={:?} provider_id={:?} chunk_size={}",
self.request_identifier(),
self.client_api,
provider_id,
@ -466,30 +404,59 @@ impl StreamContext {
return self.handle_bedrock_binary_stream(body, &client_api, &upstream_api);
}
// Parse body into SSE iterator using TryFrom
let sse_iter: SseStreamIter<std::vec::IntoIter<String>> =
match SseStreamIter::try_from(body) {
Ok(iter) => iter,
// Initialize SSE chunk processor if not present
if self.sse_chunk_processor.is_none() {
self.sse_chunk_processor = Some(SseChunkProcessor::new());
}
// Initialize SSE buffer if not present
if self.sse_buffer.is_none() {
self.sse_buffer = match SseStreamBuffer::try_from((&client_api, &upstream_api))
{
Ok(buffer) => Some(buffer),
Err(e) => {
warn!("Failed to parse body into SSE iterator: {}", e);
warn!("Failed to create SSE buffer: {}", e);
return Err(Action::Continue);
}
};
}
let mut response_buffer = Vec::new();
// Process chunk through SSE processor (handles incomplete events)
let transformed_events = match self.sse_chunk_processor.as_mut() {
Some(processor) => {
let result = processor.process_chunk(body, &client_api, &upstream_api);
let has_buffered = processor.has_buffered_data();
let buffered_size = processor.buffered_size();
// Process each SSE event
for sse_event in sse_iter {
// Transform event if upstream API != client API
let transformed_event: SseEvent =
match SseEvent::try_from((sse_event, &client_api, &upstream_api)) {
Ok(event) => event,
match result {
Ok(events) => {
if has_buffered {
debug!(
"[PLANO_REQ_ID:{}] SSE_INCOMPLETE_BUFFERED: {} bytes buffered for next chunk",
self.request_identifier(),
buffered_size
);
}
events
}
Err(e) => {
warn!("Failed to transform SSE event: {}", e);
warn!(
"[PLANO_REQ_ID:{}] SSE_CHUNK_PROCESS_ERROR: {}",
self.request_identifier(),
e
);
return Err(Action::Continue);
}
};
}
}
None => {
warn!("SSE chunk processor unexpectedly missing");
return Err(Action::Continue);
}
};
// Process each successfully transformed SSE event
for transformed_event in transformed_events {
// Extract ProviderStreamResponse for processing (token counting, etc.)
if !transformed_event.is_done() && !transformed_event.is_event_only() {
match transformed_event.provider_response() {
@ -498,7 +465,7 @@ impl StreamContext {
if provider_response.is_final() {
debug!(
"[ARCHGW_REQ_ID:{}] STREAMING_FINAL_CHUNK: total_tokens={}",
"[PLANO_REQ_ID:{}] STREAMING_FINAL_CHUNK: total_tokens={}",
self.request_identifier(),
self.response_tokens
);
@ -508,7 +475,7 @@ impl StreamContext {
let estimated_tokens = content.len() / 4;
self.response_tokens += estimated_tokens.max(1);
debug!(
"[ARCHGW_REQ_ID:{}] STREAMING_TOKEN_UPDATE: delta_chars={} estimated_tokens={} total_tokens={}",
"[PLANO_REQ_ID:{}] STREAMING_TOKEN_UPDATE: delta_chars={} estimated_tokens={} total_tokens={}",
self.request_identifier(),
content.len(),
estimated_tokens.max(1),
@ -518,7 +485,7 @@ impl StreamContext {
}
Err(e) => {
warn!(
"[ARCHGW_REQ_ID:{}] STREAMING_CHUNK_ERROR: {}",
"[PLANO_REQ_ID:{}] STREAMING_CHUNK_ERROR: {}",
self.request_identifier(),
e
);
@ -527,12 +494,32 @@ impl StreamContext {
}
}
// Add transformed event to response buffer
let bytes: Vec<u8> = transformed_event.into();
response_buffer.extend_from_slice(&bytes);
// Add transformed event to buffer (buffer may inject lifecycle events)
if let Some(buffer) = self.sse_buffer.as_mut() {
buffer.add_transformed_event(transformed_event);
}
}
Ok(response_buffer)
// Get accumulated bytes from buffer and return
match self.sse_buffer.as_mut() {
Some(buffer) => {
let bytes = buffer.into_bytes();
if !bytes.is_empty() {
let content = String::from_utf8_lossy(&bytes);
debug!(
"[PLANO_REQ_ID:{}] UPSTREAM_TRANSFORMED_CLIENT_RESPONSE: size={} content={}",
self.request_identifier(),
bytes.len(),
content
);
}
Ok(bytes)
}
None => {
warn!("SSE buffer unexpectedly missing after initialization");
Err(Action::Continue)
}
}
}
None => {
warn!("Missing client_api for non-streaming response");
@ -544,7 +531,7 @@ impl StreamContext {
fn handle_bedrock_binary_stream(
&mut self,
body: &[u8],
client_api: &SupportedAPIs,
client_api: &SupportedAPIsFromClient,
upstream_api: &SupportedUpstreamAPIs,
) -> Result<Vec<u8>, Action> {
// Initialize decoder if not present
@ -552,87 +539,61 @@ impl StreamContext {
self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[]));
}
// Add incoming bytes to buffer
// Create the SSE buffer on first use; it is built from the (client API, upstream API) pair.
if self.sse_buffer.is_none() {
    match SseStreamBuffer::try_from((client_api, upstream_api)) {
        Ok(new_buffer) => self.sse_buffer = Some(new_buffer),
        Err(init_err) => {
            // Buffer construction failed: log it and bail out of this chunk.
            warn!(
                "[PLANO_REQ_ID:{}] BEDROCK_BUFFER_INIT_ERROR: {}",
                self.request_identifier(),
                init_err
            );
            return Err(Action::Continue);
        }
    }
}
// Add incoming bytes to decoder buffer
let decoder = self.binary_frame_decoder.as_mut().unwrap();
decoder.buffer_mut().extend_from_slice(body);
let mut response_buffer = Vec::new();
// Process all complete frames
loop {
let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame();
match decoded_frame {
Some(DecodedFrame::Complete(ref frame_ref)) => {
let frame = DecodedFrame::Complete(frame_ref.clone());
// Convert frame to provider response type
match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) {
Ok(provider_response) => {
self.record_ttft_if_needed();
// Handle ContentBlockStart and ContentBlockDelta events
match &provider_response {
ProviderStreamResponseType::MessagesStreamEvent(evt) => {
match evt {
MessagesStreamEvent::ContentBlockStart {
index, ..
} => {
// Mark that we've seen ContentBlockStart for this index
self.binary_frame_decoder
.as_mut()
.unwrap()
.set_content_block_start_sent(*index as i32);
debug!(
"[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}",
self.request_identifier(),
*index
);
}
MessagesStreamEvent::ContentBlockDelta {
index, ..
} => {
// Check if ContentBlockStart was sent for this index
let needs_start = !self
.binary_frame_decoder
.as_ref()
.unwrap()
.has_content_block_start_been_sent(*index as i32);
if needs_start {
// Emit empty ContentBlockStart before delta
let content_block_start =
MessagesStreamEvent::ContentBlockStart {
index: *index,
content_block: MessagesContentBlock::Text {
text: String::new(),
cache_control: None,
},
};
let start_sse: String = content_block_start.into();
response_buffer
.extend_from_slice(start_sse.as_bytes());
// Mark that we've now sent it
self.binary_frame_decoder
.as_mut()
.unwrap()
.set_content_block_start_sent(*index as i32);
debug!(
"[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}",
self.request_identifier(),
*index
);
}
}
_ => {}
}
}
_ => {}
// Track token usage
if let Some(content) = provider_response.content_delta() {
let estimated_tokens = content.len() / 4;
self.response_tokens += estimated_tokens.max(1);
debug!(
"[PLANO_REQ_ID:{}] BEDROCK_TOKEN_UPDATE: delta_chars={} estimated_tokens={} total_tokens={}",
self.request_identifier(),
content.len(),
estimated_tokens.max(1),
self.response_tokens
);
}
let sse_string: String = provider_response.into();
response_buffer.extend_from_slice(sse_string.as_bytes());
// Create SseEvent from provider response
let event = SseEvent::from_provider_response(provider_response);
// Add to buffer (buffer handles all shim logic including ContentBlockStart injection)
if let Some(buffer) = self.sse_buffer.as_mut() {
buffer.add_transformed_event(event);
}
}
Err(e) => {
warn!(
"[ARCHGW_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}",
"[PLANO_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}",
self.request_identifier(),
e
);
@ -642,7 +603,7 @@ impl StreamContext {
Some(DecodedFrame::Incomplete) => {
// Incomplete frame - buffer retains partial data, wait for more bytes
debug!(
"[ARCHGW_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data",
"[PLANO_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data",
self.request_identifier()
);
break;
@ -650,7 +611,7 @@ impl StreamContext {
None => {
// Decode error
warn!(
"[ARCHGW_REQ_ID:{}] BEDROCK_DECODE_ERROR",
"[PLANO_REQ_ID:{}] BEDROCK_DECODE_ERROR",
self.request_identifier()
);
return Err(Action::Continue);
@ -658,8 +619,29 @@ impl StreamContext {
}
}
// Return accumulated complete frames (may be empty if all frames incomplete)
Ok(response_buffer)
// Get accumulated bytes from buffer and return
match self.sse_buffer.as_mut() {
Some(buffer) => {
let bytes = buffer.into_bytes();
if !bytes.is_empty() {
let content = String::from_utf8_lossy(&bytes);
debug!(
"[PLANO_REQ_ID:{}] UPSTREAM_TRANSFORMED_CLIENT_RESPONSE: size={} content={}",
self.request_identifier(),
bytes.len(),
content
);
}
Ok(bytes)
}
None => {
warn!(
"[PLANO_REQ_ID:{}] BEDROCK_BUFFER_MISSING",
self.request_identifier()
);
Err(Action::Continue)
}
}
}
fn handle_non_streaming_response(
@ -668,7 +650,7 @@ impl StreamContext {
provider_id: ProviderId,
) -> Result<Vec<u8>, Action> {
debug!(
"[ARCHGW_REQ_ID:{}] NON_STREAMING_PROCESS: provider_id={:?} body_size={}",
"[PLANO_REQ_ID:{}] NON_STREAMING_PROCESS: provider_id={:?} body_size={}",
self.request_identifier(),
provider_id,
body.len()
@ -680,7 +662,7 @@ impl StreamContext {
Ok(response) => response,
Err(e) => {
warn!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_PARSE_ERROR: {} | body: {}",
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_PARSE_ERROR: {} | body: {}",
self.request_identifier(),
e,
String::from_utf8_lossy(body)
@ -695,7 +677,7 @@ impl StreamContext {
}
None => {
warn!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_ERROR: missing client_api",
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_ERROR: missing client_api",
self.request_identifier()
);
return Err(Action::Continue);
@ -707,7 +689,7 @@ impl StreamContext {
response.extract_usage_counts()
{
debug!(
"[ARCHGW_REQ_ID:{}] RESPONSE_USAGE: prompt_tokens={} completion_tokens={} total_tokens={}",
"[PLANO_REQ_ID:{}] RESPONSE_USAGE: prompt_tokens={} completion_tokens={} total_tokens={}",
self.request_identifier(),
prompt_tokens,
completion_tokens,
@ -716,7 +698,7 @@ impl StreamContext {
self.response_tokens = completion_tokens;
} else {
warn!(
"[ARCHGW_REQ_ID:{}] RESPONSE_USAGE: no usage information found",
"[PLANO_REQ_ID:{}] RESPONSE_USAGE: no usage information found",
self.request_identifier()
);
}
@ -724,7 +706,7 @@ impl StreamContext {
match serde_json::to_vec(&response) {
Ok(bytes) => {
debug!(
"[ARCHGW_REQ_ID:{}] CLIENT_RESPONSE_PAYLOAD: {}",
"[PLANO_REQ_ID:{}] CLIENT_RESPONSE_PAYLOAD: {}",
self.request_identifier(),
String::from_utf8_lossy(&bytes)
);
@ -782,13 +764,14 @@ impl HttpContext for StreamContext {
self.select_llm_provider();
// Check if this is a supported API endpoint
if SupportedAPIs::from_endpoint(&request_path).is_none() {
if SupportedAPIsFromClient::from_endpoint(&request_path).is_none() {
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
return Action::Continue;
}
// Get the SupportedApi for routing decisions
let supported_api: Option<SupportedAPIs> = SupportedAPIs::from_endpoint(&request_path);
let supported_api: Option<SupportedAPIsFromClient> =
SupportedAPIsFromClient::from_endpoint(&request_path);
self.client_api = supported_api;
// Debug: log provider, client API, resolved API, and request path
@ -800,7 +783,7 @@ impl HttpContext for StreamContext {
Some(provider_id.compatible_api_for_client(api, self.streaming_response));
debug!(
"[ARCHGW_REQ_ID:{}] ROUTING_INFO: provider='{}' client_api={:?} resolved_api={:?} request_path='{}'",
"[PLANO_REQ_ID:{}] ROUTING_INFO: provider='{}' client_api={:?} resolved_api={:?} request_path='{}'",
self.request_identifier(),
provider.to_provider_id(),
api,
@ -853,7 +836,7 @@ impl HttpContext for StreamContext {
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
"[ARCHGW_REQ_ID:{}] REQUEST_BODY_CHUNK: bytes={} end_stream={}",
"[PLANO_REQ_ID:{}] REQUEST_BODY_CHUNK: bytes={} end_stream={}",
self.request_identifier(),
body_size,
end_of_stream
@ -892,14 +875,14 @@ impl HttpContext for StreamContext {
let mut deserialized_client_request: ProviderRequestType = match self.client_api.as_ref() {
Some(the_client_api) => {
info!(
"[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_RECEIVED: api={:?} body_size={}",
"[PLANO_REQ_ID:{}] CLIENT_REQUEST_RECEIVED: api={:?} body_size={}",
self.request_identifier(),
the_client_api,
body_bytes.len()
);
debug!(
"[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_PAYLOAD: {}",
"[PLANO_REQ_ID:{}] CLIENT_REQUEST_PAYLOAD: {}",
self.request_identifier(),
String::from_utf8_lossy(&body_bytes)
);
@ -908,7 +891,7 @@ impl HttpContext for StreamContext {
Ok(deserialized) => deserialized,
Err(e) => {
warn!(
"[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_PARSE_ERROR: {} | body: {}",
"[PLANO_REQ_ID:{}] CLIENT_REQUEST_PARSE_ERROR: {} | body: {}",
self.request_identifier(),
e,
String::from_utf8_lossy(&body_bytes)
@ -951,7 +934,7 @@ impl HttpContext for StreamContext {
"agent_orchestrator".to_string()
} else {
warn!(
"[ARCHGW_REQ_ID:{}] MODEL_RESOLUTION_ERROR: no model specified | req_model='{}' provider='{}' config_model={:?}",
"[PLANO_REQ_ID:{}] MODEL_RESOLUTION_ERROR: no model specified | req_model='{}' provider='{}' config_model={:?}",
self.request_identifier(),
model_requested,
self.llm_provider().name,
@ -980,7 +963,7 @@ impl HttpContext for StreamContext {
self.user_message = deserialized_client_request.get_recent_user_message();
info!(
"[ARCHGW_REQ_ID:{}] MODEL_RESOLUTION: req_model='{}' -> resolved_model='{}' provider='{}' streaming={}",
"[PLANO_REQ_ID:{}] MODEL_RESOLUTION: req_model='{}' -> resolved_model='{}' provider='{}' streaming={}",
self.request_identifier(),
model_requested,
resolved_model,
@ -1011,14 +994,14 @@ impl HttpContext for StreamContext {
match self.resolved_api.as_ref() {
Some(upstream) => {
info!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_TRANSFORM: client_api={:?} -> upstream_api={:?}",
"[PLANO_REQ_ID:{}] UPSTREAM_TRANSFORM: client_api={:?} -> upstream_api={:?}",
self.request_identifier(), self.client_api, upstream
);
match ProviderRequestType::try_from((deserialized_client_request, upstream)) {
Ok(request) => {
debug!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_REQUEST_PAYLOAD: {}",
"[PLANO_REQ_ID:{}] UPSTREAM_REQUEST_PAYLOAD: {}",
self.request_identifier(),
String::from_utf8_lossy(&request.to_bytes().unwrap_or_default())
);
@ -1069,7 +1052,7 @@ impl HttpContext for StreamContext {
self.upstream_status_code = StatusCode::from_u16(status_code).ok();
debug!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_STATUS: {}",
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_STATUS: {}",
self.request_identifier(),
status_code
);
@ -1096,7 +1079,7 @@ impl HttpContext for StreamContext {
let current_time = get_current_time().unwrap();
if end_of_stream && body_size == 0 {
debug!(
"[ARCHGW_REQ_ID:{}] RESPONSE_BODY_COMPLETE: total_bytes={}",
"[PLANO_REQ_ID:{}] RESPONSE_BODY_COMPLETE: total_bytes={}",
self.request_identifier(),
body_size
);
@ -1108,7 +1091,7 @@ impl HttpContext for StreamContext {
if let Some(status_code) = &self.upstream_status_code {
if status_code.is_client_error() || status_code.is_server_error() {
info!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_ERROR_RESPONSE: status={} body_size={}",
"[PLANO_REQ_ID:{}] UPSTREAM_ERROR_RESPONSE: status={} body_size={}",
self.request_identifier(),
status_code.as_u16(),
body_size
@ -1118,7 +1101,7 @@ impl HttpContext for StreamContext {
if body_size > 0 {
if let Ok(body) = self.read_raw_response_body(body_size) {
debug!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_ERROR_BODY: {}",
"[PLANO_REQ_ID:{}] UPSTREAM_ERROR_BODY: {}",
self.request_identifier(),
String::from_utf8_lossy(&body)
);
@ -1131,15 +1114,16 @@ impl HttpContext for StreamContext {
}
match self.client_api {
Some(SupportedAPIs::OpenAIChatCompletions(_)) => {}
Some(SupportedAPIs::AnthropicMessagesAPI(_)) => {}
Some(SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {}
Some(SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {}
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {}
_ => {
let api_info = match &self.client_api {
Some(api) => format!("{}", api),
None => "None".to_string(),
};
info!(
"[ARCHGW_REQ_ID:{}], UNSUPPORTED API: {}",
"[PLANO_REQ_ID:{}], UNSUPPORTED API: {}",
self.request_identifier(),
api_info
);
@ -1153,7 +1137,7 @@ impl HttpContext for StreamContext {
};
debug!(
"[ARCHGW_REQ_ID:{}] UPSTREAM_RAW_RESPONSE: body_size={} content={}",
"[PLANO_REQ_ID:{}] UPSTREAM_RAW_RESPONSE: body_size={} content={}",
self.request_identifier(),
body.len(),
String::from_utf8_lossy(&body)

Some files were not shown because too many files have changed in this diff Show more