mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Merge branch 'main' into musa/www
This commit is contained in:
commit
a6f9ca3594
189 changed files with 21252 additions and 14516 deletions
2
.github/workflows/e2e_archgw.yml
vendored
2
.github/workflows/e2e_archgw.yml
vendored
|
|
@ -30,7 +30,7 @@ jobs:
|
|||
|
||||
- name: build arch docker image
|
||||
run: |
|
||||
cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.18 -t katanemo/archgw:latest
|
||||
cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22 -t katanemo/archgw:latest
|
||||
|
||||
- name: start archgw
|
||||
env:
|
||||
|
|
|
|||
46
.github/workflows/e2e_model_server.yml
vendored
46
.github/workflows/e2e_model_server.yml
vendored
|
|
@ -1,46 +0,0 @@
|
|||
name: e2e model server tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
e2e_model_server_tests:
|
||||
runs-on: ubuntu-latest-m
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./tests/modelserver
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: "pip" # auto-caches based on requirements files
|
||||
|
||||
- name: install poetry
|
||||
run: |
|
||||
export POETRY_VERSION=2.2.1
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
|
||||
- name: install model server and start it
|
||||
run: |
|
||||
cd ../../model_server/ && poetry install && poetry run archgw_modelserver start
|
||||
|
||||
- name: install test dependencies
|
||||
run: |
|
||||
poetry install
|
||||
|
||||
- name: run tests
|
||||
run: |
|
||||
poetry run pytest
|
||||
|
|
@ -24,7 +24,7 @@ jobs:
|
|||
|
||||
- name: build arch docker image
|
||||
run: |
|
||||
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.18
|
||||
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22
|
||||
|
||||
- name: install poetry
|
||||
run: |
|
||||
|
|
@ -40,11 +40,10 @@ jobs:
|
|||
curl --location --remote-name https://github.com/Orange-OpenSource/hurl/releases/download/4.0.0/hurl_4.0.0_amd64.deb
|
||||
sudo dpkg -i hurl_4.0.0_amd64.deb
|
||||
|
||||
- name: install model server, arch gateway and test dependencies
|
||||
- name: install arch gateway and test dependencies
|
||||
run: |
|
||||
source venv/bin/activate
|
||||
cd model_server/ && echo "installing model server" && poetry install
|
||||
cd ../arch/tools && echo "installing archgw cli" && poetry install
|
||||
cd arch/tools && echo "installing archgw cli" && poetry install
|
||||
cd ../../demos/shared/test_runner && echo "installing test dependencies" && poetry install
|
||||
|
||||
- name: run demo tests
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ jobs:
|
|||
|
||||
- name: build arch docker image
|
||||
run: |
|
||||
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.18
|
||||
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22
|
||||
|
||||
- name: install poetry
|
||||
run: |
|
||||
|
|
@ -40,11 +40,10 @@ jobs:
|
|||
curl --location --remote-name https://github.com/Orange-OpenSource/hurl/releases/download/4.0.0/hurl_4.0.0_amd64.deb
|
||||
sudo dpkg -i hurl_4.0.0_amd64.deb
|
||||
|
||||
- name: install model server, arch gateway and test dependencies
|
||||
- name: install arch gateway and test dependencies
|
||||
run: |
|
||||
source venv/bin/activate
|
||||
cd model_server/ && echo "installing model server" && poetry install
|
||||
cd ../arch/tools && echo "installing archgw cli" && poetry install
|
||||
cd arch/tools && echo "installing archgw cli" && poetry install
|
||||
cd ../../demos/shared/test_runner && echo "installing test dependencies" && poetry install
|
||||
|
||||
- name: run demo tests
|
||||
|
|
|
|||
27
.github/workflows/e2e_tests.yml
vendored
27
.github/workflows/e2e_tests.yml
vendored
|
|
@ -14,6 +14,32 @@ jobs:
|
|||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# --- Disk inspection & cleanup section (added to free space on GitHub runner) ---
|
||||
- name: Check disk usage before cleanup
|
||||
run: |
|
||||
echo "=== Disk usage before cleanup ==="
|
||||
df -h
|
||||
echo "=== Repo size ==="
|
||||
du -sh .
|
||||
|
||||
- name: Free disk space on runner
|
||||
run: |
|
||||
echo "=== Cleaning preinstalled SDKs and toolchains to free space ==="
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
# If you still hit disk issues, uncomment this to free more space.
|
||||
# It just removes cached tool versions; setup-python will re-download what it needs.
|
||||
# sudo rm -rf /opt/hostedtoolcache || true
|
||||
|
||||
echo "=== Docker cleanup (before our builds/compose) ==="
|
||||
docker system prune -af || true
|
||||
docker volume prune -f || true
|
||||
|
||||
echo "=== Disk usage after cleanup ==="
|
||||
df -h
|
||||
# --- End disk cleanup section ---
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
|
|
@ -33,6 +59,7 @@ jobs:
|
|||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
AWS_BEARER_TOKEN_BEDROCK: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }}
|
||||
GROK_API_KEY : ${{ secrets.GROK_API_KEY }}
|
||||
run: |
|
||||
python -mvenv venv
|
||||
source venv/bin/activate && cd tests/e2e && bash run_e2e_tests.sh
|
||||
|
|
|
|||
45
.github/workflows/model-server-tests.yml
vendored
45
.github/workflows/model-server-tests.yml
vendored
|
|
@ -1,45 +0,0 @@
|
|||
name: model server tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main # Run tests on pushes to the main branch
|
||||
pull_request:
|
||||
branches:
|
||||
- main # Run tests on pull requests to the main branch
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Step 1: Check out the code from your repository
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Step 2: Set up Python (specify the version)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
# Step 3: Install Poetry
|
||||
- name: Install Poetry
|
||||
run: |
|
||||
export POETRY_VERSION=2.2.1
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
|
||||
# Step 4: Install dependencies using Poetry
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd model_server
|
||||
poetry install
|
||||
|
||||
# Step 5: Set PYTHONPATH and run tests
|
||||
- name: Run model server tests with pytest
|
||||
env:
|
||||
PYTHONPATH: model_server # Ensure the app's path is available
|
||||
run: |
|
||||
cd model_server
|
||||
poetry run pytest
|
||||
3
.github/workflows/rust_tests.yml
vendored
3
.github/workflows/rust_tests.yml
vendored
|
|
@ -29,3 +29,6 @@ jobs:
|
|||
|
||||
- name: Run unit tests
|
||||
run: cargo test --lib
|
||||
|
||||
- name: Run trace integration tests
|
||||
run: cargo test -p common --features trace-collection traces::tests::trace_integration_test
|
||||
|
|
|
|||
2
.github/workflows/validate_arch_config.yml
vendored
2
.github/workflows/validate_arch_config.yml
vendored
|
|
@ -24,7 +24,7 @@ jobs:
|
|||
|
||||
- name: build arch docker image
|
||||
run: |
|
||||
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.18
|
||||
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22
|
||||
|
||||
- name: validate arch config
|
||||
run: |
|
||||
|
|
|
|||
5
.gitignore
vendored
5
.gitignore
vendored
|
|
@ -111,11 +111,6 @@ venv.bak/
|
|||
arch/tools/config
|
||||
arch/tools/build
|
||||
|
||||
# Archgw - model_server
|
||||
model_server/venv_model_server
|
||||
model_server/build
|
||||
model_server/dist
|
||||
|
||||
# Archgw - Docs
|
||||
docs/build/
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ _Arch is a models-native (edge and service) proxy server for agents._<br><br>
|
|||
</div>
|
||||
|
||||
# About The Latest Release:
|
||||
[0.3.18] [Preference-aware multi LLM routing for Claude Code 2.0](demos/use_cases/claude_code_router/README.md) <br><img src="docs/source/_static/img/claude_code_router.png" alt="high-level network architecture for ArchGW" width="50%">
|
||||
[0.3.20] [Preference-aware multi LLM routing for Claude Code 2.0](demos/use_cases/claude_code_router/README.md) <br><img src="docs/source/_static/img/claude_code_router.png" alt="high-level network architecture for ArchGW" width="50%">
|
||||
|
||||
|
||||
# Overview
|
||||
|
|
@ -87,7 +87,7 @@ Arch's CLI allows you to manage and interact with the Arch gateway efficiently.
|
|||
```console
|
||||
$ python3.12 -m venv venv
|
||||
$ source venv/bin/activate # On Windows, use: venv\Scripts\activate
|
||||
$ pip install archgw==0.3.18
|
||||
$ pip install archgw==0.3.22
|
||||
```
|
||||
|
||||
### Use Arch as an LLM Router
|
||||
|
|
@ -276,7 +276,7 @@ endpoints:
|
|||
```sh
|
||||
|
||||
$ archgw up arch_config.yaml
|
||||
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.3.18
|
||||
2024-12-05 16:56:27,979 - cli.main - INFO - Starting archgw cli version: 0.3.22
|
||||
2024-12-05 16:56:28,485 - cli.utils - INFO - Schema validation successful!
|
||||
2024-12-05 16:56:28,485 - cli.main - INFO - Starting arch model server and arch gateway
|
||||
2024-12-05 16:56:51,647 - cli.core - INFO - Container is healthy!
|
||||
|
|
|
|||
|
|
@ -14,6 +14,38 @@ properties:
|
|||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
url:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- url
|
||||
filters:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
url:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
enum:
|
||||
- mcp
|
||||
transport:
|
||||
type: string
|
||||
enum:
|
||||
- streamable-http
|
||||
tool:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- url
|
||||
listeners:
|
||||
oneOf:
|
||||
- type: array
|
||||
|
|
@ -331,6 +363,31 @@ properties:
|
|||
model:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
state_storage:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum:
|
||||
- memory
|
||||
- postgres
|
||||
connection_string:
|
||||
type: string
|
||||
description: Required when type is postgres. Supports environment variable substitution using $VAR or ${VAR} syntax.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
# Note: connection_string is conditionally required based on type
|
||||
# If type is 'postgres', connection_string must be provided
|
||||
# If type is 'memory', connection_string is not needed
|
||||
allOf:
|
||||
- if:
|
||||
properties:
|
||||
type:
|
||||
const: postgres
|
||||
then:
|
||||
required:
|
||||
- connection_string
|
||||
prompt_guards:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
|||
|
|
@ -22,4 +22,3 @@ services:
|
|||
- OPENAI_API_KEY=${OPENAI_API_KEY:?error}
|
||||
- MISTRAL_API_KEY=${MISTRAL_API_KEY:?error}
|
||||
- OTEL_TRACING_HTTP_ENDPOINT=http://host.docker.internal:4318/v1/traces
|
||||
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-51000}
|
||||
|
|
|
|||
|
|
@ -51,11 +51,11 @@ static_resources:
|
|||
envoy_grpc:
|
||||
cluster_name: opentelemetry_collector
|
||||
timeout: 0.250s
|
||||
service_name: archgw(inbound)
|
||||
service_name: plano(inbound)
|
||||
random_sampling:
|
||||
value: {{ arch_tracing.random_sampling }}
|
||||
{% endif %}
|
||||
stat_prefix: ingress_traffic
|
||||
stat_prefix: plano(inbound)
|
||||
codec_type: AUTO
|
||||
scheme_header_transformation:
|
||||
scheme_to_overwrite: https
|
||||
|
|
@ -95,21 +95,6 @@ static_resources:
|
|||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
|
||||
generate_request_id: true
|
||||
tracing:
|
||||
provider:
|
||||
name: envoy.tracers.opentelemetry
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
|
||||
grpc_service:
|
||||
envoy_grpc:
|
||||
cluster_name: opentelemetry_collector
|
||||
timeout: 0.250s
|
||||
service_name: ingress_traffic
|
||||
random_sampling:
|
||||
value: {{ arch_tracing.random_sampling }}
|
||||
{% endif %}
|
||||
stat_prefix: ingress_traffic
|
||||
codec_type: AUTO
|
||||
scheme_header_transformation:
|
||||
|
|
@ -221,7 +206,7 @@ static_resources:
|
|||
- name: outbound_api_traffic
|
||||
address:
|
||||
socket_address:
|
||||
address: 0.0.0.0
|
||||
address: 127.0.0.1
|
||||
port_value: 11000
|
||||
traffic_direction: OUTBOUND
|
||||
filter_chains:
|
||||
|
|
@ -229,21 +214,21 @@ static_resources:
|
|||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
|
||||
generate_request_id: true
|
||||
tracing:
|
||||
provider:
|
||||
name: envoy.tracers.opentelemetry
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
|
||||
grpc_service:
|
||||
envoy_grpc:
|
||||
cluster_name: opentelemetry_collector
|
||||
timeout: 0.250s
|
||||
service_name: outbound_api_traffic
|
||||
random_sampling:
|
||||
value: {{ arch_tracing.random_sampling }}
|
||||
{% endif %}
|
||||
# {% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
|
||||
# generate_request_id: true
|
||||
# tracing:
|
||||
# provider:
|
||||
# name: envoy.tracers.opentelemetry
|
||||
# typed_config:
|
||||
# "@type": type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig
|
||||
# grpc_service:
|
||||
# envoy_grpc:
|
||||
# cluster_name: opentelemetry_collector
|
||||
# timeout: 0.250s
|
||||
# service_name: tools
|
||||
# random_sampling:
|
||||
# value: {{ arch_tracing.random_sampling }}
|
||||
# {% endif %}
|
||||
stat_prefix: outbound_api_traffic
|
||||
codec_type: AUTO
|
||||
scheme_header_transformation:
|
||||
|
|
@ -262,19 +247,16 @@ static_resources:
|
|||
domains:
|
||||
- "*"
|
||||
routes:
|
||||
{% for internal_cluster in ["arch_fc", "model_server"] %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
headers:
|
||||
- name: "x-arch-upstream"
|
||||
string_match:
|
||||
exact: {{ internal_cluster }}
|
||||
exact: bright_staff
|
||||
route:
|
||||
auto_host_rewrite: true
|
||||
cluster: {{ internal_cluster }}
|
||||
cluster: bright_staff
|
||||
timeout: 300s
|
||||
{% endfor %}
|
||||
|
||||
{% for cluster_name, cluster in arch_clusters.items() %}
|
||||
- match:
|
||||
prefix: "/"
|
||||
|
|
@ -317,7 +299,7 @@ static_resources:
|
|||
envoy_grpc:
|
||||
cluster_name: opentelemetry_collector
|
||||
timeout: 0.250s
|
||||
service_name: arch_gateway
|
||||
service_name: plano(inbound)
|
||||
random_sampling:
|
||||
value: {{ arch_tracing.random_sampling }}
|
||||
{% endif %}
|
||||
|
|
@ -416,7 +398,7 @@ static_resources:
|
|||
envoy_grpc:
|
||||
cluster_name: opentelemetry_collector
|
||||
timeout: 0.250s
|
||||
service_name: archgw(outbound)
|
||||
service_name: plano(outbound)
|
||||
random_sampling:
|
||||
value: {{ arch_tracing.random_sampling }}
|
||||
{% endif %}
|
||||
|
|
@ -487,6 +469,50 @@ static_resources:
|
|||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
|
||||
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
|
||||
- name: otel_collector_proxy
|
||||
address:
|
||||
socket_address:
|
||||
address: 127.0.0.1
|
||||
port_value: 9903
|
||||
traffic_direction: OUTBOUND
|
||||
filter_chains:
|
||||
- filters:
|
||||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
stat_prefix: otel_proxy
|
||||
codec_type: AUTO
|
||||
access_log:
|
||||
- name: envoy.access_loggers.file
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
|
||||
path: "/var/log/access_otel.log"
|
||||
format: |
|
||||
[%START_TIME%] "%REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %PROTOCOL%" %RESPONSE_CODE% %RESPONSE_FLAGS% %BYTES_RECEIVED% %BYTES_SENT% %DURATION% %RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)% "%REQ(X-FORWARDED-FOR)%" "%REQ(USER-AGENT)%" "%REQ(X-REQUEST-ID)%" "%REQ(:AUTHORITY)%" "%UPSTREAM_HOST%" "%UPSTREAM_CLUSTER%"
|
||||
route_config:
|
||||
name: otel_route
|
||||
virtual_hosts:
|
||||
- name: otel_backend
|
||||
domains: ["*"]
|
||||
routes:
|
||||
- match:
|
||||
prefix: "/v1/traces"
|
||||
route:
|
||||
cluster: opentelemetry_collector_http
|
||||
timeout: 5s
|
||||
retry_policy:
|
||||
retry_on: "5xx,connect-failure,refused-stream,reset"
|
||||
num_retries: 3
|
||||
per_try_timeout: 2s
|
||||
host_selection_retry_max_attempts: 5
|
||||
retriable_status_codes: [500, 502, 503, 504]
|
||||
http_filters:
|
||||
- name: envoy.filters.http.router
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
{% endif %}
|
||||
|
||||
- name: egress_traffic_llm
|
||||
address:
|
||||
socket_address:
|
||||
|
|
@ -599,7 +625,7 @@ static_resources:
|
|||
clusters:
|
||||
|
||||
- name: arch
|
||||
connect_timeout: 0.5s
|
||||
connect_timeout: 5s
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
|
|
@ -868,24 +894,6 @@ static_resources:
|
|||
tls_params:
|
||||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
|
||||
{% for internal_cluster in ["arch_fc", "model_server"] %}
|
||||
- name: {{ internal_cluster }}
|
||||
connect_timeout: 0.5s
|
||||
type: STRICT_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: {{ internal_cluster }}
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: host.docker.internal
|
||||
port_value: 51000
|
||||
hostname: {{ internal_cluster }}
|
||||
{% endfor %}
|
||||
- name: mistral_7b_instruct
|
||||
connect_timeout: 0.5s
|
||||
type: STRICT_DNS
|
||||
|
|
@ -1035,7 +1043,6 @@ static_resources:
|
|||
port_value: 12001
|
||||
hostname: arch_listener_llm
|
||||
|
||||
|
||||
{% if "random_sampling" in arch_tracing and arch_tracing["random_sampling"] > 0 %}
|
||||
- name: opentelemetry_collector
|
||||
type: STRICT_DNS
|
||||
|
|
@ -1069,4 +1076,19 @@ static_resources:
|
|||
socket_address:
|
||||
address: host.docker.internal
|
||||
port_value: 4318
|
||||
# Circuit breaker configuration to prevent overwhelming OTEL collector
|
||||
circuit_breakers:
|
||||
thresholds:
|
||||
- priority: DEFAULT
|
||||
max_connections: 100
|
||||
max_pending_requests: 100
|
||||
max_requests: 100
|
||||
max_retries: 3
|
||||
# Health checking and outlier detection
|
||||
outlier_detection:
|
||||
consecutive_5xx: 5
|
||||
interval: 10s
|
||||
base_ejection_time: 30s
|
||||
max_ejection_percent: 50
|
||||
enforcing_consecutive_5xx: 100
|
||||
{% endif %}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
nodaemon=true
|
||||
|
||||
[program:brightstaff]
|
||||
command=sh -c "RUST_LOG=info /app/brightstaff 2>&1 | tee /var/log/brightstaff.log | while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
|
||||
command=sh -c "envsubst < /app/arch_config_rendered.yaml > /app/arch_config_rendered.env_sub.yaml && RUST_LOG=debug ARCH_CONFIG_PATH_RENDERED=/app/arch_config_rendered.env_sub.yaml /app/brightstaff 2>&1 | tee /var/log/brightstaff.log | while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
|
||||
stdout_logfile=/dev/stdout
|
||||
redirect_stderr=true
|
||||
stdout_logfile_maxbytes=0
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ source venv/bin/activate
|
|||
|
||||
### Step 3: Run the build script
|
||||
```bash
|
||||
pip install archgw==0.3.18
|
||||
pip install archgw==0.3.22
|
||||
```
|
||||
|
||||
## Uninstall Instructions: archgw CLI
|
||||
|
|
@ -56,15 +56,8 @@ poetry install
|
|||
archgw build
|
||||
```
|
||||
|
||||
### Step 5: download models
|
||||
This will help download models so model_server can load faster. This should be done once.
|
||||
|
||||
```bash
|
||||
archgw download-models
|
||||
```
|
||||
|
||||
### Logs
|
||||
`archgw` command can also view logs from gateway and model_server. Use following command to view logs,
|
||||
The `archgw` command can also view logs from the gateway. Use the following command to view logs:
|
||||
|
||||
```bash
|
||||
archgw logs --follow
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@ SUPPORTED_PROVIDERS_WITH_BASE_URL = [
|
|||
"ollama",
|
||||
"qwen",
|
||||
"amazon_bedrock",
|
||||
"arch",
|
||||
]
|
||||
|
||||
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
|
||||
"arch",
|
||||
"deepseek",
|
||||
"groq",
|
||||
"mistral",
|
||||
|
|
@ -101,8 +101,17 @@ def validate_and_render_schema():
|
|||
|
||||
# Process agents section and convert to endpoints
|
||||
agents = config_yaml.get("agents", [])
|
||||
for agent in agents:
|
||||
filters = config_yaml.get("filters", [])
|
||||
agents_combined = agents + filters
|
||||
agent_id_keys = set()
|
||||
|
||||
for agent in agents_combined:
|
||||
agent_id = agent.get("id")
|
||||
if agent_id in agent_id_keys:
|
||||
raise Exception(
|
||||
f"Duplicate agent id {agent_id}, please provide unique id for each agent"
|
||||
)
|
||||
agent_id_keys.add(agent_id)
|
||||
agent_endpoint = agent.get("url")
|
||||
|
||||
if agent_id and agent_endpoint:
|
||||
|
|
@ -304,6 +313,16 @@ def validate_and_render_schema():
|
|||
}
|
||||
)
|
||||
|
||||
# Always add arch-function model provider if not already defined
|
||||
if "arch-function" not in model_provider_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "arch-function",
|
||||
"provider_interface": "arch",
|
||||
"model": "Arch-Function",
|
||||
}
|
||||
)
|
||||
|
||||
config_yaml["model_providers"] = deepcopy(updated_model_providers)
|
||||
|
||||
listeners_with_provider = 0
|
||||
|
|
|
|||
|
|
@ -1,13 +1,5 @@
|
|||
import os
|
||||
|
||||
|
||||
KATANEMO_DOCKERHUB_REPO = "katanemo/archgw"
|
||||
KATANEMO_LOCAL_MODEL_LIST = [
|
||||
"katanemo/Arch-Guard",
|
||||
]
|
||||
SERVICE_NAME_ARCHGW = "archgw"
|
||||
SERVICE_NAME_MODEL_SERVER = "model_server"
|
||||
SERVICE_ALL = "all"
|
||||
MODEL_SERVER_LOG_FILE = "~/archgw_logs/modelserver.log"
|
||||
ARCHGW_DOCKER_NAME = "archgw"
|
||||
ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.3.18")
|
||||
ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.3.22")
|
||||
|
|
|
|||
|
|
@ -9,9 +9,7 @@ from cli.utils import convert_legacy_listeners, getLogger
|
|||
from cli.consts import (
|
||||
ARCHGW_DOCKER_IMAGE,
|
||||
ARCHGW_DOCKER_NAME,
|
||||
KATANEMO_LOCAL_MODEL_LIST,
|
||||
)
|
||||
from huggingface_hub import snapshot_download
|
||||
import subprocess
|
||||
from cli.docker_cli import (
|
||||
docker_container_status,
|
||||
|
|
@ -144,49 +142,6 @@ def stop_docker_container(service=ARCHGW_DOCKER_NAME):
|
|||
log.info(f"Failed to shut down services: {str(e)}")
|
||||
|
||||
|
||||
def download_models_from_hf():
|
||||
for model in KATANEMO_LOCAL_MODEL_LIST:
|
||||
log.info(f"Downloading model: {model}")
|
||||
snapshot_download(repo_id=model)
|
||||
|
||||
|
||||
def start_arch_modelserver(foreground):
|
||||
"""
|
||||
Start the model server. This assumes that the archgw_modelserver package is installed locally
|
||||
|
||||
"""
|
||||
try:
|
||||
log.info("archgw_modelserver restart")
|
||||
if foreground:
|
||||
subprocess.run(
|
||||
["archgw_modelserver", "start", "--foreground"],
|
||||
check=True,
|
||||
)
|
||||
else:
|
||||
subprocess.run(
|
||||
["archgw_modelserver", "start"],
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def stop_arch_modelserver():
|
||||
"""
|
||||
Stop the model server. This assumes that the archgw_modelserver package is installed locally
|
||||
|
||||
"""
|
||||
try:
|
||||
subprocess.run(
|
||||
["archgw_modelserver", "stop"],
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def start_cli_agent(arch_config_file=None, settings_json="{}"):
|
||||
"""Start a CLI client connected to Arch."""
|
||||
|
||||
|
|
|
|||
|
|
@ -20,20 +20,14 @@ from cli.utils import (
|
|||
find_config_file,
|
||||
)
|
||||
from cli.core import (
|
||||
start_arch_modelserver,
|
||||
stop_arch_modelserver,
|
||||
start_arch,
|
||||
stop_docker_container,
|
||||
download_models_from_hf,
|
||||
start_cli_agent,
|
||||
)
|
||||
from cli.consts import (
|
||||
ARCHGW_DOCKER_IMAGE,
|
||||
ARCHGW_DOCKER_NAME,
|
||||
KATANEMO_DOCKERHUB_REPO,
|
||||
SERVICE_NAME_ARCHGW,
|
||||
SERVICE_NAME_MODEL_SERVER,
|
||||
SERVICE_ALL,
|
||||
)
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
|
@ -47,9 +41,8 @@ logo = r"""
|
|||
|
||||
"""
|
||||
|
||||
# Command to build archgw and model_server Docker images
|
||||
# Command to build archgw Docker images
|
||||
ARCHGW_DOCKERFILE = "./arch/Dockerfile"
|
||||
MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
|
||||
|
||||
|
||||
def get_version():
|
||||
|
|
@ -60,18 +53,6 @@ def get_version():
|
|||
return "version not found"
|
||||
|
||||
|
||||
def verify_service_name(service):
|
||||
"""Verify if the service name is valid."""
|
||||
if service not in [
|
||||
SERVICE_NAME_ARCHGW,
|
||||
SERVICE_NAME_MODEL_SERVER,
|
||||
SERVICE_ALL,
|
||||
]:
|
||||
print(f"Error: Invalid service {service}. Exiting")
|
||||
sys.exit(1)
|
||||
return True
|
||||
|
||||
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.option("--version", is_flag=True, help="Show the archgw cli version and exit.")
|
||||
@click.pass_context
|
||||
|
|
@ -89,17 +70,11 @@ def main(ctx, version):
|
|||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--service",
|
||||
default=SERVICE_ALL,
|
||||
help="Optional parameter to specify which service to build. Options are model_server, archgw",
|
||||
)
|
||||
def build(service):
|
||||
def build():
|
||||
"""Build Arch from source. Must be in root of cloned repo."""
|
||||
verify_service_name(service)
|
||||
|
||||
# Check if /arch/Dockerfile exists
|
||||
if service == SERVICE_NAME_ARCHGW or service == SERVICE_ALL:
|
||||
if os.path.exists(ARCHGW_DOCKERFILE):
|
||||
if os.path.exists(ARCHGW_DOCKERFILE):
|
||||
click.echo("Building archgw image...")
|
||||
try:
|
||||
|
|
@ -110,8 +85,6 @@ def build(service):
|
|||
"-f",
|
||||
ARCHGW_DOCKERFILE,
|
||||
"-t",
|
||||
f"{KATANEMO_DOCKERHUB_REPO}:latest",
|
||||
"-t",
|
||||
f"{ARCHGW_DOCKER_IMAGE}",
|
||||
".",
|
||||
"--add-host=host.docker.internal:host-gateway",
|
||||
|
|
@ -128,57 +101,20 @@ def build(service):
|
|||
|
||||
click.echo("archgw image built successfully.")
|
||||
|
||||
"""Install the model server dependencies using Poetry."""
|
||||
if service == SERVICE_NAME_MODEL_SERVER or service == SERVICE_ALL:
|
||||
# Check if pyproject.toml exists
|
||||
if os.path.exists(MODEL_SERVER_BUILD_FILE):
|
||||
click.echo("Installing model server dependencies with Poetry...")
|
||||
try:
|
||||
subprocess.run(
|
||||
["poetry", "install", "--no-cache"],
|
||||
cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE),
|
||||
check=True,
|
||||
)
|
||||
click.echo("Model server dependencies installed successfully.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"Error installing model server dependencies: {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo(f"Error: pyproject.toml not found in {MODEL_SERVER_BUILD_FILE}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("file", required=False) # Optional file argument
|
||||
@click.option(
|
||||
"--path", default=".", help="Path to the directory containing arch_config.yaml"
|
||||
)
|
||||
@click.option(
|
||||
"--service",
|
||||
default=SERVICE_ALL,
|
||||
help="Service to start. Options are model_server, archgw.",
|
||||
)
|
||||
@click.option(
|
||||
"--foreground",
|
||||
default=False,
|
||||
help="Run Arch in the foreground. Default is False",
|
||||
is_flag=True,
|
||||
)
|
||||
def up(file, path, service, foreground):
|
||||
def up(file, path, foreground):
|
||||
"""Starts Arch."""
|
||||
verify_service_name(service)
|
||||
|
||||
if service == SERVICE_ALL and foreground:
|
||||
# foreground can only be specified when starting individual services
|
||||
log.info("foreground flag is only supported for individual services. Exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
if service == SERVICE_NAME_MODEL_SERVER:
|
||||
log.info("Download models from HuggingFace...")
|
||||
download_models_from_hf()
|
||||
start_arch_modelserver(foreground)
|
||||
return
|
||||
|
||||
# Use the utility function to find config file
|
||||
arch_config_file = find_config_file(path, file)
|
||||
|
||||
|
|
@ -202,7 +138,6 @@ def up(file, path, service, foreground):
|
|||
# Set the ARCH_CONFIG_FILE environment variable
|
||||
env_stage = {
|
||||
"OTEL_TRACING_HTTP_ENDPOINT": "http://host.docker.internal:4318/v1/traces",
|
||||
"MODEL_SERVER_PORT": os.getenv("MODEL_SERVER_PORT", "51000"),
|
||||
}
|
||||
env = os.environ.copy()
|
||||
# Remove PATH variable if present
|
||||
|
|
@ -242,40 +177,13 @@ def up(file, path, service, foreground):
|
|||
env_stage[access_key] = env_file_dict[access_key]
|
||||
|
||||
env.update(env_stage)
|
||||
|
||||
if service == SERVICE_NAME_ARCHGW:
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
else:
|
||||
# Check if ingress_traffic listener is configured before starting model_server
|
||||
if has_ingress_listener(arch_config_file):
|
||||
download_models_from_hf()
|
||||
start_arch_modelserver(foreground)
|
||||
else:
|
||||
log.info(
|
||||
"Skipping model_server startup: no ingress_traffic listener configured in arch_config.yaml"
|
||||
)
|
||||
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--service",
|
||||
default=SERVICE_ALL,
|
||||
help="Service to down. Options are all, model_server, archgw. Default is all",
|
||||
)
|
||||
def down(service):
|
||||
def down():
|
||||
"""Stops Arch."""
|
||||
|
||||
verify_service_name(service)
|
||||
|
||||
if service == SERVICE_NAME_MODEL_SERVER:
|
||||
stop_arch_modelserver()
|
||||
elif service == SERVICE_NAME_ARCHGW:
|
||||
stop_docker_container()
|
||||
else:
|
||||
stop_arch_modelserver()
|
||||
stop_docker_container(SERVICE_NAME_ARCHGW)
|
||||
stop_docker_container()
|
||||
|
||||
|
||||
@click.command()
|
||||
|
|
@ -303,7 +211,7 @@ def generate_prompt_targets(file):
|
|||
@click.command()
|
||||
@click.option(
|
||||
"--debug",
|
||||
help="For detailed debug logs to trace calls from archgw <> model_server <> api_server, etc",
|
||||
help="For detailed debug logs to trace calls from archgw <> api_server, etc",
|
||||
is_flag=True,
|
||||
)
|
||||
@click.option("--follow", help="Follow the logs", is_flag=True)
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@ def convert_legacy_listeners(
|
|||
"timeout": "30s",
|
||||
}
|
||||
|
||||
# Handle None case
|
||||
if listeners is None:
|
||||
return [llm_gateway_listener], llm_gateway_listener, prompt_gateway_listener
|
||||
|
||||
if isinstance(listeners, dict):
|
||||
# legacy listeners
|
||||
# check if type is array or object
|
||||
|
|
@ -148,6 +152,24 @@ def get_llm_provider_access_keys(arch_config_file):
|
|||
if access_key is not None:
|
||||
access_key_list.append(access_key)
|
||||
|
||||
# Extract environment variables from state_storage.connection_string
|
||||
state_storage = arch_config_yaml.get("state_storage_v1_responses")
|
||||
if state_storage:
|
||||
connection_string = state_storage.get("connection_string")
|
||||
if connection_string and isinstance(connection_string, str):
|
||||
# Extract all $VAR and ${VAR} patterns from connection string
|
||||
import re
|
||||
|
||||
# Match both $VAR and ${VAR} patterns
|
||||
pattern = r"\$\{?([A-Z_][A-Z0-9_]*)\}?"
|
||||
matches = re.findall(pattern, connection_string)
|
||||
for var in matches:
|
||||
access_key_list.append(f"${var}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid connection string received in state_storage_v1_responses"
|
||||
)
|
||||
|
||||
return access_key_list
|
||||
|
||||
|
||||
|
|
|
|||
2481
arch/tools/poetry.lock
generated
2481
arch/tools/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,34 +1,28 @@
|
|||
[project]
|
||||
name = "archgw"
|
||||
version = "0.3.18"
|
||||
description = "Python-based CLI tool to manage Arch Gateway."
|
||||
authors = [{ name = "Katanemo Labs, Inc." }]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"archgw_modelserver==0.3.18",
|
||||
"click>=8.1.7,<9.0.0",
|
||||
"jinja2>=3.1.4,<4.0.0",
|
||||
"jsonschema>=4.23.0,<5.0.0",
|
||||
"pyyaml>=6.0.2,<7.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
archgw = "cli.main:main"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.4.1,<9.0.0",
|
||||
]
|
||||
|
||||
[tool.poetry]
|
||||
name = "archgw"
|
||||
version = "0.3.22"
|
||||
description = "Python-based CLI tool to manage Arch Gateway."
|
||||
authors = ["Katanemo Labs, Inc."]
|
||||
readme = "README.md"
|
||||
packages = [{ include = "cli" }]
|
||||
dependencies = { archgw_modelserver = { path = "../../model_server", develop = true } }
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10"
|
||||
click = ">=8.1.7,<9.0.0"
|
||||
jinja2 = ">=3.1.4,<4.0.0"
|
||||
jsonschema = ">=4.23.0,<5.0.0"
|
||||
pyyaml = ">=6.0.2,<7.0.0"
|
||||
requests = ">=2.31.0,<3.0.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = ">=8.4.1,<9.0.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
archgw = "cli.main:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=2.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = ["-v"]
|
||||
|
|
|
|||
|
|
@ -94,21 +94,16 @@ def test_validate_and_render_happy_path_agent_config(monkeypatch):
|
|||
version: v0.3.0
|
||||
|
||||
agents:
|
||||
- name: query_rewriter
|
||||
kind: openai
|
||||
endpoint: http://localhost:10500
|
||||
- name: context_builder
|
||||
kind: openai
|
||||
endpoint: http://localhost:10501
|
||||
- name: response_generator
|
||||
kind: openai
|
||||
endpoint: http://localhost:10502
|
||||
- name: research_agent
|
||||
kind: openai
|
||||
endpoint: http://localhost:10500
|
||||
- name: input_guard_rails
|
||||
kind: openai
|
||||
endpoint: http://localhost:10503
|
||||
- id: query_rewriter
|
||||
url: http://localhost:10500
|
||||
- id: context_builder
|
||||
url: http://localhost:10501
|
||||
- id: response_generator
|
||||
url: http://localhost:10502
|
||||
- id: research_agent
|
||||
url: http://localhost:10500
|
||||
- id: input_guard_rails
|
||||
url: http://localhost:10503
|
||||
|
||||
listeners:
|
||||
- name: tmobile
|
||||
|
|
@ -156,7 +151,7 @@ listeners:
|
|||
mock.mock_open().return_value, # ARCH_CONFIG_FILE_RENDERED (write)
|
||||
]
|
||||
with mock.patch("builtins.open", m_open):
|
||||
with mock.patch("config_generator.Environment"):
|
||||
with mock.patch("cli.config_generator.Environment"):
|
||||
validate_and_render_schema()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,10 +12,6 @@
|
|||
"name": "archgw_cli",
|
||||
"path": "arch/tools"
|
||||
},
|
||||
{
|
||||
"name": "model_server",
|
||||
"path": "model_server"
|
||||
},
|
||||
{
|
||||
"name": "tests_e2e",
|
||||
"path": "tests/e2e"
|
||||
|
|
@ -24,10 +20,6 @@
|
|||
"name": "tests_archgw",
|
||||
"path": "tests/archgw"
|
||||
},
|
||||
{
|
||||
"name": "tests_modelserver",
|
||||
"path": "tests/modelserver"
|
||||
},
|
||||
{
|
||||
"name": "chatbot_ui",
|
||||
"path": "demos/shared/chatbot_ui"
|
||||
|
|
@ -42,6 +34,7 @@
|
|||
"editor.defaultFormatter": "ms-python.black-formatter",
|
||||
"editor.formatOnSave": true
|
||||
},
|
||||
"rust-analyzer.cargo.features": ["trace-collection"]
|
||||
},
|
||||
"extensions": {
|
||||
"recommendations": [
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.2
|
||||
docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.3.22
|
||||
|
|
|
|||
590
crates/Cargo.lock
generated
590
crates/Cargo.lock
generated
|
|
@ -78,6 +78,43 @@ dependencies = [
|
|||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-openai"
|
||||
version = "0.30.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6bf39a15c8d613eb61892dc9a287c02277639ebead41ee611ad23aaa613f1a82"
|
||||
dependencies = [
|
||||
"async-openai-macros",
|
||||
"backoff",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"derive_builder",
|
||||
"eventsource-stream",
|
||||
"futures",
|
||||
"rand 0.9.2",
|
||||
"reqwest",
|
||||
"reqwest-eventsource",
|
||||
"secrecy",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-openai-macros"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.88"
|
||||
|
|
@ -130,6 +167,75 @@ dependencies = [
|
|||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.7.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tower 0.5.2",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"sync_wrapper",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backoff"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"getrandom 0.2.16",
|
||||
"instant",
|
||||
"pin-project-lite",
|
||||
"rand 0.8.5",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.75"
|
||||
|
|
@ -201,10 +307,14 @@ dependencies = [
|
|||
name = "brightstaff"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-openai",
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"common",
|
||||
"eventsource-client",
|
||||
"eventsource-stream",
|
||||
"flate2",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"hermesllm",
|
||||
|
|
@ -219,6 +329,7 @@ dependencies = [
|
|||
"opentelemetry-stdout",
|
||||
"opentelemetry_sdk",
|
||||
"pretty_assertions",
|
||||
"rand 0.9.2",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
@ -227,10 +338,12 @@ dependencies = [
|
|||
"thiserror 2.0.12",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -250,6 +363,12 @@ version = "3.18.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.10.1"
|
||||
|
|
@ -281,6 +400,12 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "cfg_aliases"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.41"
|
||||
|
|
@ -289,8 +414,10 @@ checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
|
|||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
|
|
@ -307,6 +434,7 @@ dependencies = [
|
|||
name = "common"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"derivative",
|
||||
"duration-string",
|
||||
"governor",
|
||||
|
|
@ -316,12 +444,16 @@ dependencies = [
|
|||
"pretty_assertions",
|
||||
"proxy-wasm",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"serde_yaml",
|
||||
"serial_test",
|
||||
"thiserror 1.0.69",
|
||||
"tiktoken-rs",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"url",
|
||||
"urlencoding",
|
||||
]
|
||||
|
|
@ -336,6 +468,16 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation-sys"
|
||||
version = "0.8.7"
|
||||
|
|
@ -426,6 +568,37 @@ dependencies = [
|
|||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder"
|
||||
version = "0.20.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
|
||||
dependencies = [
|
||||
"derive_builder_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder_core"
|
||||
version = "0.20.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder_macro"
|
||||
version = "0.20.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
|
||||
dependencies = [
|
||||
"derive_builder_core",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "diff"
|
||||
version = "0.1.13"
|
||||
|
|
@ -440,6 +613,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
|||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -527,6 +701,12 @@ dependencies = [
|
|||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fallible-iterator"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||
|
||||
[[package]]
|
||||
name = "fancy-regex"
|
||||
version = "0.12.0"
|
||||
|
|
@ -543,6 +723,16 @@ version = "2.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
|
|
@ -650,6 +840,12 @@ version = "0.3.31"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
|
||||
|
||||
[[package]]
|
||||
name = "futures-timer"
|
||||
version = "3.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.31"
|
||||
|
|
@ -685,8 +881,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -696,9 +894,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"r-efi",
|
||||
"wasi 0.14.2+wasi-0.2.4",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -798,10 +998,12 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"aws-smithy-eventstream",
|
||||
"bytes",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"thiserror 2.0.12",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -810,6 +1012,15 @@ version = "0.4.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "hmac"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
|
||||
dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.12"
|
||||
|
|
@ -934,7 +1145,7 @@ dependencies = [
|
|||
"hyper 0.14.32",
|
||||
"log",
|
||||
"rustls 0.21.12",
|
||||
"rustls-native-certs",
|
||||
"rustls-native-certs 0.6.3",
|
||||
"tokio",
|
||||
"tokio-rustls 0.24.1",
|
||||
]
|
||||
|
|
@ -949,6 +1160,7 @@ dependencies = [
|
|||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"rustls 0.23.27",
|
||||
"rustls-native-certs 0.8.2",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls 0.26.2",
|
||||
|
|
@ -1181,6 +1393,15 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "instant"
|
||||
version = "0.1.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.11.0"
|
||||
|
|
@ -1234,6 +1455,17 @@ version = "0.2.172"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
|
||||
|
||||
[[package]]
|
||||
name = "libredox"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.9.4"
|
||||
|
|
@ -1285,6 +1517,12 @@ version = "0.4.27"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
|
||||
|
||||
[[package]]
|
||||
name = "lru-slab"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
|
||||
|
||||
[[package]]
|
||||
name = "matchers"
|
||||
version = "0.1.0"
|
||||
|
|
@ -1294,6 +1532,22 @@ dependencies = [
|
|||
"regex-automata 0.1.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "md-5"
|
||||
version = "0.10.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "md5"
|
||||
version = "0.7.0"
|
||||
|
|
@ -1312,6 +1566,16 @@ version = "0.3.17"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
||||
|
||||
[[package]]
|
||||
name = "mime_guess"
|
||||
version = "2.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
|
||||
dependencies = [
|
||||
"mime",
|
||||
"unicase",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "minimal-lexical"
|
||||
version = "0.2.1"
|
||||
|
|
@ -1325,6 +1589,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a"
|
||||
dependencies = [
|
||||
"adler2",
|
||||
"simd-adler32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1354,7 +1619,7 @@ dependencies = [
|
|||
"hyper 1.6.0",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"rand 0.9.1",
|
||||
"rand 0.9.2",
|
||||
"regex",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
|
|
@ -1374,7 +1639,7 @@ dependencies = [
|
|||
"openssl-probe",
|
||||
"openssl-sys",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
"security-framework 2.11.1",
|
||||
"security-framework-sys",
|
||||
"tempfile",
|
||||
]
|
||||
|
|
@ -1581,7 +1846,7 @@ dependencies = [
|
|||
"glob",
|
||||
"opentelemetry",
|
||||
"percent-encoding",
|
||||
"rand 0.9.1",
|
||||
"rand 0.9.2",
|
||||
"serde_json",
|
||||
"thiserror 2.0.12",
|
||||
"tracing",
|
||||
|
|
@ -1628,6 +1893,24 @@ version = "2.3.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
|
||||
|
||||
[[package]]
|
||||
name = "phf"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
|
||||
dependencies = [
|
||||
"phf_shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf_shared"
|
||||
version = "0.11.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
|
||||
dependencies = [
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.1.10"
|
||||
|
|
@ -1672,6 +1955,37 @@ version = "1.11.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
|
||||
|
||||
[[package]]
|
||||
name = "postgres-protocol"
|
||||
version = "0.6.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fbef655056b916eb868048276cfd5d6a7dea4f81560dfd047f97c8c6fe3fcfd4"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"hmac",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"rand 0.9.2",
|
||||
"sha2",
|
||||
"stringprep",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "postgres-types"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"postgres-protocol",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "potential_utf"
|
||||
version = "0.1.2"
|
||||
|
|
@ -1770,6 +2084,61 @@ dependencies = [
|
|||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.11.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"cfg_aliases",
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustls 0.23.27",
|
||||
"socket2",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-proto"
|
||||
version = "0.11.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"getrandom 0.3.3",
|
||||
"lru-slab",
|
||||
"rand 0.9.2",
|
||||
"ring",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustls 0.23.27",
|
||||
"rustls-pki-types",
|
||||
"slab",
|
||||
"thiserror 2.0.12",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.5.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
|
||||
dependencies = [
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2",
|
||||
"tracing",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.40"
|
||||
|
|
@ -1798,9 +2167,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.9.1"
|
||||
version = "0.9.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
|
||||
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
|
||||
dependencies = [
|
||||
"rand_chacha 0.9.0",
|
||||
"rand_core 0.9.3",
|
||||
|
|
@ -1846,9 +2215,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.12"
|
||||
version = "0.5.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af"
|
||||
checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
]
|
||||
|
|
@ -1941,10 +2310,14 @@ dependencies = [
|
|||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
"native-tls",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"quinn",
|
||||
"rustls 0.23.27",
|
||||
"rustls-native-certs 0.8.2",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
@ -1952,6 +2325,7 @@ dependencies = [
|
|||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-rustls 0.26.2",
|
||||
"tokio-util",
|
||||
"tower 0.5.2",
|
||||
"tower-http",
|
||||
|
|
@ -1963,6 +2337,22 @@ dependencies = [
|
|||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reqwest-eventsource"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde"
|
||||
dependencies = [
|
||||
"eventsource-stream",
|
||||
"futures-core",
|
||||
"futures-timer",
|
||||
"mime",
|
||||
"nom",
|
||||
"pin-project-lite",
|
||||
"reqwest",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.14"
|
||||
|
|
@ -1989,6 +2379,12 @@ version = "1.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.0.7"
|
||||
|
|
@ -2021,6 +2417,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.103.3",
|
||||
"subtle",
|
||||
|
|
@ -2036,7 +2433,19 @@ dependencies = [
|
|||
"openssl-probe",
|
||||
"rustls-pemfile",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
"security-framework 2.11.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9980d917ebb0c0536119ba501e90834767bffc3d60641457fd84a1f3fd337923"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pki-types",
|
||||
"schannel",
|
||||
"security-framework 3.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2054,6 +2463,7 @@ version = "1.12.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79"
|
||||
dependencies = [
|
||||
"web-time",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
|
|
@ -2142,6 +2552,16 @@ version = "3.0.8"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21"
|
||||
|
||||
[[package]]
|
||||
name = "secrecy"
|
||||
version = "0.10.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.11.1"
|
||||
|
|
@ -2149,7 +2569,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"core-foundation",
|
||||
"core-foundation 0.9.4",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"security-framework-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "3.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"core-foundation 0.10.1",
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
"security-framework-sys",
|
||||
|
|
@ -2191,12 +2624,23 @@ version = "1.0.140"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
|
||||
dependencies = [
|
||||
"indexmap 2.9.0",
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.7.1"
|
||||
|
|
@ -2313,12 +2757,24 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simd-adler32"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2"
|
||||
|
||||
[[package]]
|
||||
name = "similar"
|
||||
version = "2.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
|
||||
|
||||
[[package]]
|
||||
name = "siphasher"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.9"
|
||||
|
|
@ -2359,6 +2815,17 @@ version = "1.2.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
||||
|
||||
[[package]]
|
||||
name = "stringprep"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1"
|
||||
dependencies = [
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
"unicode-properties",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
|
|
@ -2420,7 +2887,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"core-foundation",
|
||||
"core-foundation 0.9.4",
|
||||
"system-configuration-sys",
|
||||
]
|
||||
|
||||
|
|
@ -2509,7 +2976,7 @@ dependencies = [
|
|||
"fancy-regex",
|
||||
"lazy_static",
|
||||
"parking_lot",
|
||||
"rustc-hash",
|
||||
"rustc-hash 1.1.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2553,6 +3020,21 @@ dependencies = [
|
|||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
|
||||
dependencies = [
|
||||
"tinyvec_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec_macros"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.45.1"
|
||||
|
|
@ -2602,6 +3084,32 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-postgres"
|
||||
version = "0.7.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"phf",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
"postgres-types",
|
||||
"rand 0.9.2",
|
||||
"socket2",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"whoami",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.24.1"
|
||||
|
|
@ -2705,6 +3213,7 @@ dependencies = [
|
|||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2743,6 +3252,7 @@ version = "0.1.41"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
|
||||
dependencies = [
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
|
|
@ -2829,12 +3339,39 @@ version = "1.18.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
|
||||
|
||||
[[package]]
|
||||
name = "unicase"
|
||||
version = "2.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-normalization"
|
||||
version = "0.1.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8"
|
||||
dependencies = [
|
||||
"tinyvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-properties"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d"
|
||||
|
||||
[[package]]
|
||||
name = "unsafe-libyaml"
|
||||
version = "0.2.11"
|
||||
|
|
@ -2870,6 +3407,18 @@ version = "1.0.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2"
|
||||
dependencies = [
|
||||
"getrandom 0.3.3",
|
||||
"js-sys",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "valuable"
|
||||
version = "0.1.1"
|
||||
|
|
@ -2918,6 +3467,12 @@ dependencies = [
|
|||
"wit-bindgen-rt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasite"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.100"
|
||||
|
|
@ -3022,6 +3577,17 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d"
|
||||
dependencies = [
|
||||
"libredox",
|
||||
"wasite",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["llm_gateway", "prompt_gateway", "common", "brightstaff", "hermesllm"]
|
||||
|
||||
[workspace.metadata.rust-analyzer]
|
||||
# Enable features for better IDE support
|
||||
cargo.features = ["trace-collection"]
|
||||
|
|
|
|||
|
|
@ -4,10 +4,14 @@ version = "0.1.0"
|
|||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
async-openai = "0.30.1"
|
||||
async-trait = "0.1"
|
||||
bytes = "1.10.1"
|
||||
common = { version = "0.1.0", path = "../common" }
|
||||
chrono = "0.4"
|
||||
common = { version = "0.1.0", path = "../common", features = ["trace-collection"] }
|
||||
eventsource-client = "0.15.0"
|
||||
eventsource-stream = "0.2.3"
|
||||
flate2 = "1.0"
|
||||
futures = "0.3.31"
|
||||
futures-util = "0.3.31"
|
||||
hermesllm = { version = "0.1.0", path = "../hermesllm" }
|
||||
|
|
@ -21,6 +25,7 @@ opentelemetry-otlp = {version="0.29.0", features=["trace", "tonic", "grpc-tonic"
|
|||
opentelemetry-stdout = "0.29.0"
|
||||
opentelemetry_sdk = "0.29.0"
|
||||
pretty_assertions = "1.4.1"
|
||||
rand = "0.9.2"
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
|
|
@ -28,10 +33,12 @@ serde_with = "3.13.0"
|
|||
serde_yaml = "0.9.34"
|
||||
thiserror = "2.0.12"
|
||||
tokio = { version = "1.44.2", features = ["full"] }
|
||||
tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] }
|
||||
tokio-stream = "0.1"
|
||||
time = { version = "0.3", features = ["formatting", "macros"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
uuid = { version = "1.0", features = ["v4", "serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
mockito = "1.0"
|
||||
|
|
|
|||
|
|
@ -1,16 +1,24 @@
|
|||
use std::sync::Arc;
|
||||
use std::time::{Instant, SystemTime};
|
||||
|
||||
use bytes::Bytes;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use common::consts::TRACE_PARENT_HEADER;
|
||||
use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id};
|
||||
use hermesllm::apis::OpenAIMessage;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::providers::request::ProviderRequest;
|
||||
use hermesllm::ProviderRequestType;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response};
|
||||
use serde::ser::Error as SerError;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use super::response_handler::ResponseHandler;
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::tracing::{OperationNameBuilder, operation_component, http};
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -33,11 +41,51 @@ pub async fn agent_chat(
|
|||
_: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
trace_collector: Arc<common::traces::TraceCollector>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
match handle_agent_chat(request, router_service, agents_list, listeners).await {
|
||||
match handle_agent_chat(
|
||||
request,
|
||||
router_service,
|
||||
agents_list,
|
||||
listeners,
|
||||
trace_collector,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
// Print detailed error information with full error chain
|
||||
// Check if this is a client error from the pipeline that should be cascaded
|
||||
if let AgentFilterChainError::Pipeline(PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) = &err
|
||||
{
|
||||
warn!(
|
||||
"Client error from agent '{}' (HTTP {}): {}",
|
||||
agent, status, body
|
||||
);
|
||||
|
||||
// Create error response with the original status code and body
|
||||
let error_json = serde_json::json!({
|
||||
"error": "ClientError",
|
||||
"agent": agent,
|
||||
"status": status,
|
||||
"agent_response": body
|
||||
});
|
||||
|
||||
let json_string = error_json.to_string();
|
||||
let mut response = Response::new(ResponseHandler::create_full_body(json_string));
|
||||
*response.status_mut() = hyper::StatusCode::from_u16(*status)
|
||||
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
"application/json".parse().unwrap(),
|
||||
);
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// Print detailed error information with full error chain for other errors
|
||||
let mut error_chain = Vec::new();
|
||||
let mut current_error: &dyn std::error::Error = &err;
|
||||
|
||||
|
|
@ -78,10 +126,11 @@ async fn handle_agent_chat(
|
|||
router_service: Arc<RouterService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
trace_collector: Arc<common::traces::TraceCollector>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
// Initialize services
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let pipeline_processor = PipelineProcessor::default();
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
|
||||
// Extract listener name from headers
|
||||
|
|
@ -101,6 +150,13 @@ async fn handle_agent_chat(
|
|||
info!("Handling request for listener: {}", listener.name);
|
||||
|
||||
// Parse request body
|
||||
let request_path = request
|
||||
.uri()
|
||||
.path()
|
||||
.to_string()
|
||||
.strip_prefix("/agents")
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
|
|
@ -109,61 +165,141 @@ async fn handle_agent_chat(
|
|||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
warn!(
|
||||
"Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
AgentFilterChainError::RequestParsing(err)
|
||||
// Determine the API type from the endpoint
|
||||
let api_type =
|
||||
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
|
||||
let err_msg = format!("Unsupported endpoint: {}", request_path);
|
||||
warn!("{}", err_msg);
|
||||
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
|
||||
})?;
|
||||
|
||||
let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
return Err(AgentFilterChainError::RequestParsing(
|
||||
serde_json::Error::custom(err_msg),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let message: Vec<OpenAIMessage> = client_request.get_messages();
|
||||
|
||||
// let chat_completions_request: ChatCompletionsRequest =
|
||||
// serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
// warn!(
|
||||
// "Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
// err
|
||||
// );
|
||||
// AgentFilterChainError::RequestParsing(err)
|
||||
// })?;
|
||||
|
||||
// Extract trace parent for routing
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
.find(|(key, _)| key.as_str() == "traceparent")
|
||||
.find(|(key, _)| key.as_str() == TRACE_PARENT_HEADER)
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.id);
|
||||
|
||||
// Create agent map for pipeline processing
|
||||
// Create agent map for pipeline processing and agent selection
|
||||
let agent_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = agents.as_ref().unwrap();
|
||||
agent_selector.create_agent_map(agents)
|
||||
};
|
||||
|
||||
// Parse trace parent to get trace_id and parent_span_id
|
||||
let (trace_id, parent_span_id) = if let Some(ref tp) = trace_parent {
|
||||
parse_traceparent(tp)
|
||||
} else {
|
||||
(String::new(), None)
|
||||
};
|
||||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(&message, &listener, trace_parent.clone())
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.id);
|
||||
|
||||
// Record the start time for agent span
|
||||
let agent_start_time = SystemTime::now();
|
||||
let agent_start_instant = Instant::now();
|
||||
// let (span_id, trace_id) = trace_collector.start_span(
|
||||
// trace_parent.clone(),
|
||||
// operation_component::AGENT,
|
||||
// &format!("/agents{}", request_path),
|
||||
// &selected_agent.id,
|
||||
// );
|
||||
|
||||
let span_id = generate_random_span_id();
|
||||
|
||||
// Process the filter chain
|
||||
let processed_messages = pipeline_processor
|
||||
let chat_history = pipeline_processor
|
||||
.process_filter_chain(
|
||||
&chat_completions_request,
|
||||
&message,
|
||||
&selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
Some(&trace_collector),
|
||||
trace_id.clone(),
|
||||
span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Get terminal agent and send final response
|
||||
let terminal_agent_name = selected_agent.id;
|
||||
let terminal_agent_name = selected_agent.id.clone();
|
||||
let terminal_agent = agent_map.get(&terminal_agent_name).unwrap();
|
||||
|
||||
debug!("Processing terminal agent: {}", terminal_agent_name);
|
||||
debug!("Terminal agent details: {:?}", terminal_agent);
|
||||
|
||||
let llm_response = pipeline_processor
|
||||
.invoke_upstream_agent(
|
||||
&processed_messages,
|
||||
&chat_completions_request,
|
||||
.invoke_agent(
|
||||
&chat_history,
|
||||
client_request,
|
||||
terminal_agent,
|
||||
&request_headers,
|
||||
trace_id.clone(),
|
||||
span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Record agent span after processing is complete
|
||||
let agent_end_time = SystemTime::now();
|
||||
let agent_elapsed = agent_start_instant.elapsed();
|
||||
|
||||
// Build full path with /agents prefix
|
||||
let full_path = format!("/agents{}", request_path);
|
||||
|
||||
// Build operation name: POST {full_path} {agent_name}
|
||||
let operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(&full_path)
|
||||
.with_target(&terminal_agent_name)
|
||||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id)
|
||||
.with_kind(SpanKind::Internal)
|
||||
.with_start_time(agent_start_time)
|
||||
.with_end_time(agent_end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, full_path)
|
||||
.with_attribute("agent.name", terminal_agent_name.clone())
|
||||
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id);
|
||||
}
|
||||
if let Some(parent_id) = parent_span_id {
|
||||
span_builder = span_builder.with_parent_span_id(parent_id);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
// Use plano(agent) as service name for the agent processing span
|
||||
trace_collector.record_span(operation_component::AGENT, span);
|
||||
|
||||
// Create streaming response
|
||||
response_handler
|
||||
.create_streaming_response(llm_response)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ pub enum AgentSelectionError {
|
|||
RoutingError(String),
|
||||
#[error("Default agent not found for listener: {0}")]
|
||||
DefaultAgentNotFound(String),
|
||||
#[error("MCP client error: {0}")]
|
||||
McpError(String),
|
||||
}
|
||||
|
||||
/// Service for selecting agents based on routing preferences and listener configuration
|
||||
|
|
@ -29,7 +31,9 @@ pub struct AgentSelector {
|
|||
|
||||
impl AgentSelector {
|
||||
pub fn new(router_service: Arc<RouterService>) -> Self {
|
||||
Self { router_service }
|
||||
Self {
|
||||
router_service,
|
||||
}
|
||||
}
|
||||
|
||||
/// Find listener by name from the request headers
|
||||
|
|
@ -77,7 +81,9 @@ impl AgentSelector {
|
|||
return Ok(agents[0].clone());
|
||||
}
|
||||
|
||||
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
|
||||
let usage_preferences = self
|
||||
.convert_agent_description_to_routing_preferences(agents)
|
||||
.await;
|
||||
debug!(
|
||||
"Agents usage preferences for agent routing str: {}",
|
||||
serde_json::to_string(&usage_preferences).unwrap_or_default()
|
||||
|
|
@ -131,20 +137,23 @@ impl AgentSelector {
|
|||
}
|
||||
|
||||
/// Convert agent descriptions to routing preferences
|
||||
fn convert_agent_description_to_routing_preferences(
|
||||
async fn convert_agent_description_to_routing_preferences(
|
||||
&self,
|
||||
agents: &[AgentFilterChain],
|
||||
) -> Vec<ModelUsagePreference> {
|
||||
agents
|
||||
.iter()
|
||||
.map(|agent| ModelUsagePreference {
|
||||
model: agent.id.clone(),
|
||||
let mut preferences = Vec::new();
|
||||
|
||||
for agent_chain in agents {
|
||||
preferences.push(ModelUsagePreference {
|
||||
model: agent_chain.id.clone(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: agent.id.clone(),
|
||||
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
|
||||
name: agent_chain.id.clone(),
|
||||
description: agent_chain.description.clone().unwrap_or_default(),
|
||||
}],
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
}
|
||||
|
||||
preferences
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -183,8 +192,10 @@ mod tests {
|
|||
fn create_test_agent_struct(name: &str) -> Agent {
|
||||
Agent {
|
||||
id: name.to_string(),
|
||||
kind: Some("test".to_string()),
|
||||
agent_type: Some("test".to_string()),
|
||||
url: "http://localhost:8080".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -240,8 +251,8 @@ mod tests {
|
|||
assert!(agent_map.contains_key("agent2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_agent_description_to_routing_preferences() {
|
||||
#[tokio::test]
|
||||
async fn test_convert_agent_description_to_routing_preferences() {
|
||||
let router_service = create_test_router_service();
|
||||
let selector = AgentSelector::new(router_service);
|
||||
|
||||
|
|
@ -250,7 +261,9 @@ mod tests {
|
|||
create_test_agent("agent2", "Second agent description", false),
|
||||
];
|
||||
|
||||
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
|
||||
let preferences = selector
|
||||
.convert_agent_description_to_routing_preferences(&agents)
|
||||
.await;
|
||||
|
||||
assert_eq!(preferences.len(), 2);
|
||||
assert_eq!(preferences[0].model, "agent1");
|
||||
|
|
|
|||
|
|
@ -1,276 +0,0 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{ModelAlias, ModelUsagePreference};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub async fn chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
"Received request body (raw utf8): {}",
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = model_aliases.get(&model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To{}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
model_alias.target.clone()
|
||||
} else {
|
||||
model_from_request.clone()
|
||||
}
|
||||
} else {
|
||||
model_from_request.clone()
|
||||
};
|
||||
client_request.set_model(resolved_model.clone());
|
||||
|
||||
// Clone metadata for routing and remove archgw_preference_config from original
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!("Removed archgw_preference_config from metadata");
|
||||
}
|
||||
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
|
||||
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
|
||||
match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(
|
||||
hermesllm::apis::OpenAIApi::ChatCompletions,
|
||||
),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(
|
||||
ProviderRequestType::MessagesRequest(_)
|
||||
| ProviderRequestType::BedrockConverse(_)
|
||||
| ProviderRequestType::BedrockConverseStream(_),
|
||||
) => {
|
||||
// This should not happen after conversion to OpenAI format
|
||||
warn!("Unexpected: got MessagesRequest after converting to OpenAI format");
|
||||
let err_msg = "Request conversion failed".to_string();
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Failed to convert request to ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
let err_msg = format!("Failed to convert request: {}", err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[ARCH_ROUTER REQ]: {}",
|
||||
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
|
||||
);
|
||||
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
.find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
let latest_message_for_log = chat_completions_request_for_arch_router
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
|
||||
let truncated: String = latest_message_for_log
|
||||
.chars()
|
||||
.take(MAX_MESSAGE_LENGTH)
|
||||
.collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
latest_message_for_log
|
||||
};
|
||||
|
||||
info!(
|
||||
"request received, request type: chat_completion, usage preferences from request: {}, request path: {}, latest message: {}",
|
||||
usage_preferences.is_some(),
|
||||
request_path,
|
||||
latest_message_for_log
|
||||
);
|
||||
|
||||
debug!("usage preferences from request: {:?}", usage_preferences);
|
||||
|
||||
let model_name = match router_service
|
||||
.determine_route(
|
||||
&chat_completions_request_for_arch_router.messages,
|
||||
trace_parent.clone(),
|
||||
usage_preferences,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(route) => match route {
|
||||
Some((_, model_name)) => model_name,
|
||||
None => {
|
||||
info!(
|
||||
"No route determined, using default model from request: {}",
|
||||
chat_completions_request_for_arch_router.model
|
||||
);
|
||||
chat_completions_request_for_arch_router.model.clone()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to determine route: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
|
||||
full_qualified_llm_provider_url, model_name
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&model_name).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
|
||||
if let Some(trace_parent) = trace_parent {
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static("traceparent"),
|
||||
header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
}
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
// copy over the headers and status code from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
// channel to create async stream
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(16);
|
||||
|
||||
// Spawn a task to send data as it becomes available
|
||||
tokio::spawn(async move {
|
||||
let mut byte_stream = llm_response.bytes_stream();
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let item = match item {
|
||||
Ok(item) => item,
|
||||
Err(err) => {
|
||||
warn!("Error receiving chunk: {:?}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if tx.send(item).await.is_err() {
|
||||
warn!("Receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
match response.body(stream_body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
}
|
||||
1934
crates/brightstaff/src/handlers/function_calling.rs
Normal file
1934
crates/brightstaff/src/handlers/function_calling.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -42,19 +42,23 @@ mod integration_tests {
|
|||
// Setup services
|
||||
let router_service = create_test_router_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let pipeline_processor = PipelineProcessor::default();
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
|
||||
// Create test data
|
||||
let agents = vec![
|
||||
Agent {
|
||||
id: "filter-agent".to_string(),
|
||||
kind: Some("filter".to_string()),
|
||||
agent_type: Some("filter".to_string()),
|
||||
url: "http://localhost:8081".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
},
|
||||
Agent {
|
||||
id: "terminal-agent".to_string(),
|
||||
kind: Some("terminal".to_string()),
|
||||
agent_type: Some("terminal".to_string()),
|
||||
url: "http://localhost:8082".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
},
|
||||
];
|
||||
|
||||
|
|
@ -107,7 +111,15 @@ mod integration_tests {
|
|||
|
||||
let headers = HeaderMap::new();
|
||||
let result = pipeline_processor
|
||||
.process_filter_chain(&request, &test_pipeline, &agent_map, &headers)
|
||||
.process_filter_chain(
|
||||
&request.messages,
|
||||
&test_pipeline,
|
||||
&agent_map,
|
||||
&headers,
|
||||
None,
|
||||
String::new(),
|
||||
String::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
println!("Pipeline processing result: {:?}", result);
|
||||
|
|
|
|||
49
crates/brightstaff/src/handlers/jsonrpc.rs
Normal file
49
crates/brightstaff/src/handlers/jsonrpc.rs
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Protocol version string carried in the `jsonrpc` field of every JSON-RPC message.
pub const JSON_RPC_VERSION: &str = "2.0";
/// JSON-RPC method name for invoking a tool.
pub const TOOL_CALL_METHOD: &str = "tools/call";
/// JSON-RPC method name of the MCP `initialize` request.
pub const MCP_INITIALIZE: &str = "initialize";
/// JSON-RPC method name of the notification that follows a successful
/// `initialize` exchange. The MCP lifecycle specification names this
/// notification `notifications/initialized`.
pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JsonRpcId {
|
||||
String(String),
|
||||
Number(u64),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcNotification {
|
||||
pub jsonrpc: String,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
}
|
||||
462
crates/brightstaff/src/handlers/llm.rs
Normal file
462
crates/brightstaff/src/handlers/llm.rs
Normal file
|
|
@ -0,0 +1,462 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{LlmProvider, ModelAlias};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
|
||||
use common::traces::TraceCollector;
|
||||
use hermesllm::apis::openai_responses::InputParam;
|
||||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
use crate::state::{
|
||||
StateStorage, StateStorageError,
|
||||
extract_input_items, retrieve_and_combine_input
|
||||
};
|
||||
use crate::tracing::operation_component;
|
||||
|
||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub async fn llm_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<Vec<LlmProvider>>>,
|
||||
trace_collector: Arc<TraceCollector>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
let request_id = request_headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// Extract or generate traceparent - this establishes the trace context for all spans
|
||||
let traceparent: String = request_headers
|
||||
.get(TRACE_PARENT_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| {
|
||||
use uuid::Uuid;
|
||||
let trace_id = Uuid::new_v4().to_string().replace("-", "");
|
||||
format!("00-{}-0000000000000000-01", trace_id)
|
||||
});
|
||||
|
||||
let mut request_headers = request_headers;
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | REQUEST_BODY (UTF8): {}",
|
||||
request_id,
|
||||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err);
|
||||
let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
};
|
||||
|
||||
// === v1/responses state management: Extract input items early ===
|
||||
let mut original_input_items = Vec::new();
|
||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
|
||||
let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)));
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let temperature = client_request.get_temperature();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||
|
||||
// Extract tool names and user message preview for span attributes
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request.get_recent_user_message()
|
||||
.map(|msg| truncate_message(&msg, 50));
|
||||
|
||||
client_request.set_model(resolved_model.clone());
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id);
|
||||
}
|
||||
|
||||
// === v1/responses state management: Determine upstream API and combine input if needed ===
|
||||
// Do this BEFORE routing since routing consumes the request
|
||||
// Only process state if state_storage is configured
|
||||
let mut should_manage_state = false;
|
||||
if is_responses_api_client && state_storage.is_some() {
|
||||
if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
|
||||
// Extract original input once
|
||||
original_input_items = extract_input_items(&responses_req.input);
|
||||
|
||||
// Get the upstream path and check if it's ResponsesAPI
|
||||
let upstream_path = get_upstream_path(
|
||||
&llm_providers,
|
||||
&resolved_model,
|
||||
&request_path,
|
||||
&resolved_model,
|
||||
is_streaming_request,
|
||||
).await;
|
||||
|
||||
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
|
||||
|
||||
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
|
||||
should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
|
||||
|
||||
if should_manage_state {
|
||||
// Retrieve and combine conversation history if previous_response_id exists
|
||||
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
|
||||
match retrieve_and_combine_input(
|
||||
state_storage.as_ref().unwrap().clone(),
|
||||
prev_resp_id,
|
||||
original_input_items, // Pass ownership instead of cloning
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(combined_input) => {
|
||||
// Update both the request and original_input_items
|
||||
responses_req.input = InputParam::Items(combined_input.clone());
|
||||
original_input_items = combined_input;
|
||||
info!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Updated request with conversation history ({} items)", request_id, original_input_items.len());
|
||||
}
|
||||
Err(StateStorageError::NotFound(_)) => {
|
||||
// Return 409 Conflict when previous_response_id not found
|
||||
warn!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Previous response_id not found: {}", request_id, prev_resp_id);
|
||||
let err_msg = format!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Conversation state not found for previous_response_id: {}",
|
||||
request_id, prev_resp_id
|
||||
);
|
||||
let mut conflict_response = Response::new(full(err_msg));
|
||||
*conflict_response.status_mut() = StatusCode::CONFLICT;
|
||||
return Ok(conflict_response);
|
||||
}
|
||||
Err(e) => {
|
||||
// Log warning but continue on other storage errors
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to retrieve conversation state for {}: {}",
|
||||
request_id, prev_resp_id, e
|
||||
);
|
||||
// Restore original_input_items since we passed ownership
|
||||
original_input_items = extract_input_items(&responses_req.input);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize request for upstream BEFORE router consumes it
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// Determine routing using the dedicated router_chat module
|
||||
let routing_result = match router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request, // Pass the original request - router_chat will convert it
|
||||
&request_headers,
|
||||
trace_collector.clone(),
|
||||
&traceparent,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
let mut internal_error = Response::new(full(err.message));
|
||||
*internal_error.status_mut() = err.status_code;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
let model_name = routing_result.model_name;
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Resolved Model: {}",
|
||||
request_id, full_qualified_llm_provider_url, model_name
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&model_name).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// Capture start time right before sending request to upstream
|
||||
let request_start_time = std::time::Instant::now();
|
||||
let request_start_system_time = std::time::SystemTime::now();
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
// copy over the headers and status code from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
// Build LLM span with actual status code using constants
|
||||
let byte_stream = llm_response.bytes_stream();
|
||||
|
||||
// Build the LLM span (will be finalized after streaming completes)
|
||||
let llm_span = build_llm_span(
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&resolved_model,
|
||||
&model_name,
|
||||
upstream_status.as_u16(),
|
||||
is_streaming_request,
|
||||
request_start_system_time,
|
||||
tool_names,
|
||||
user_message_preview,
|
||||
temperature,
|
||||
&llm_providers,
|
||||
).await;
|
||||
|
||||
// Create base processor for metrics and tracing
|
||||
let base_processor = ObservableStreamProcessor::new(
|
||||
trace_collector,
|
||||
operation_component::LLM,
|
||||
llm_span,
|
||||
request_start_time,
|
||||
);
|
||||
|
||||
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
|
||||
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
|
||||
let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() {
|
||||
// Extract Content-Encoding header to handle decompression for state parsing
|
||||
let content_encoding = response_headers
|
||||
.get("content-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Wrap with state management processor to store state after response completes
|
||||
let state_processor = ResponsesStateProcessor::new(
|
||||
base_processor,
|
||||
state_storage.unwrap(),
|
||||
original_input_items,
|
||||
resolved_model.clone(),
|
||||
model_name.clone(),
|
||||
is_streaming_request,
|
||||
false, // Not OpenAI upstream since should_manage_state is true
|
||||
content_encoding,
|
||||
request_id.clone(),
|
||||
);
|
||||
create_streaming_response(byte_stream, state_processor, 16)
|
||||
} else {
|
||||
// Use base processor without state management
|
||||
create_streaming_response(byte_stream, base_processor, 16)
|
||||
};
|
||||
|
||||
match response.body(streaming_response.body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to create response: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
Ok(internal_error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves model aliases by looking up the requested model in the model_aliases map.
|
||||
/// Returns the target model if an alias is found, otherwise returns the original model.
|
||||
fn resolve_model_alias(
|
||||
model_from_request: &str,
|
||||
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> String {
|
||||
if let Some(aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = aliases.get(model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To {}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
return model_alias.target.clone();
|
||||
}
|
||||
}
|
||||
model_from_request.to_string()
|
||||
}
|
||||
|
||||
/// Builds the LLM span with all required and optional attributes.
|
||||
async fn build_llm_span(
|
||||
traceparent: &str,
|
||||
request_path: &str,
|
||||
resolved_model: &str,
|
||||
model_name: &str,
|
||||
status_code: u16,
|
||||
is_streaming: bool,
|
||||
start_time: std::time::SystemTime,
|
||||
tool_names: Option<Vec<String>>,
|
||||
user_message_preview: Option<String>,
|
||||
temperature: Option<f32>,
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
) -> common::traces::Span {
|
||||
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
|
||||
use crate::tracing::{http, llm, OperationNameBuilder};
|
||||
|
||||
// Calculate the upstream path based on provider configuration
|
||||
let upstream_path = get_upstream_path(
|
||||
llm_providers,
|
||||
model_name,
|
||||
request_path,
|
||||
resolved_model,
|
||||
is_streaming,
|
||||
).await;
|
||||
|
||||
// Build operation name showing path transformation if different
|
||||
let operation_name = if request_path != upstream_path {
|
||||
OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(&format!("{} >> {}", request_path, upstream_path))
|
||||
.with_target(resolved_model)
|
||||
.build()
|
||||
} else {
|
||||
OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(request_path)
|
||||
.with_target(resolved_model)
|
||||
.build()
|
||||
};
|
||||
|
||||
let (trace_id, parent_span_id) = parse_traceparent(traceparent);
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_trace_id(&trace_id)
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::STATUS_CODE, status_code.to_string())
|
||||
.with_attribute(http::TARGET, request_path.to_string())
|
||||
.with_attribute(http::UPSTREAM_TARGET, upstream_path)
|
||||
.with_attribute(llm::MODEL_NAME, resolved_model.to_string())
|
||||
.with_attribute(llm::IS_STREAMING, is_streaming.to_string());
|
||||
|
||||
// Only set parent span ID if it exists (not a root span)
|
||||
if let Some(parent) = parent_span_id {
|
||||
span_builder = span_builder.with_parent_span_id(&parent);
|
||||
}
|
||||
|
||||
// Add optional attributes
|
||||
if let Some(temp) = temperature {
|
||||
span_builder = span_builder.with_attribute(llm::TEMPERATURE, temp.to_string());
|
||||
}
|
||||
|
||||
if let Some(tools) = tool_names {
|
||||
let formatted_tools = tools.iter()
|
||||
.map(|name| format!("{}(...)", name))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
span_builder = span_builder.with_attribute(llm::TOOLS, formatted_tools);
|
||||
}
|
||||
|
||||
if let Some(preview) = user_message_preview {
|
||||
span_builder = span_builder.with_attribute(llm::USER_MESSAGE_PREVIEW, preview);
|
||||
}
|
||||
|
||||
span_builder.build()
|
||||
}
|
||||
|
||||
/// Calculates the upstream path for the provider based on the model name.
|
||||
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
|
||||
/// then uses target_endpoint_for_provider to calculate the correct upstream path.
|
||||
async fn get_upstream_path(
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
model_name: &str,
|
||||
request_path: &str,
|
||||
resolved_model: &str,
|
||||
is_streaming: bool,
|
||||
) -> String {
|
||||
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
|
||||
|
||||
// Calculate the upstream path using the proper API
|
||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
|
||||
.expect("Should have valid API endpoint");
|
||||
|
||||
client_api.target_endpoint_for_provider(
|
||||
&provider_id,
|
||||
request_path,
|
||||
resolved_model,
|
||||
is_streaming,
|
||||
base_url_path_prefix.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper function to get provider info (ProviderId and base_url_path_prefix)
|
||||
async fn get_provider_info(
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
model_name: &str,
|
||||
) -> (hermesllm::ProviderId, Option<String>) {
|
||||
let providers_lock = llm_providers.read().await;
|
||||
|
||||
// First, try to find by model name or provider name
|
||||
let provider = providers_lock.iter().find(|p| {
|
||||
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|
||||
|| p.name == model_name
|
||||
});
|
||||
|
||||
if let Some(provider) = provider {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
return (provider_id, prefix);
|
||||
}
|
||||
|
||||
let default_provider = providers_lock.iter().find(|p| {
|
||||
p.default.unwrap_or(false)
|
||||
});
|
||||
|
||||
if let Some(provider) = default_provider {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
(provider_id, prefix)
|
||||
} else {
|
||||
// Last resort: use OpenAI as hardcoded fallback
|
||||
warn!("No default provider found, falling back to OpenAI");
|
||||
(hermesllm::ProviderId::OpenAI, None)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,9 +1,13 @@
|
|||
pub mod agent_chat_completions;
|
||||
pub mod agent_selector;
|
||||
pub mod chat_completions;
|
||||
pub mod llm;
|
||||
pub mod router_chat;
|
||||
pub mod models;
|
||||
pub mod function_calling;
|
||||
pub mod pipeline_processor;
|
||||
pub mod response_handler;
|
||||
pub mod utils;
|
||||
pub mod jsonrpc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,24 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use common::configuration::{Agent, AgentFilterChain};
|
||||
use common::consts::{ARCH_UPSTREAM_HOST_HEADER, ENVOY_RETRY_HEADER};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use common::consts::{
|
||||
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::traces::{SpanBuilder, SpanKind, generate_random_span_id};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hyper::header::HeaderMap;
|
||||
use tracing::{debug, warn};
|
||||
use std::time::{Instant, SystemTime};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::tracing::operation_component::{self};
|
||||
use crate::tracing::{http, OperationNameBuilder};
|
||||
|
||||
use crate::handlers::jsonrpc::{
|
||||
JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JSON_RPC_VERSION,
|
||||
MCP_INITIALIZE, MCP_INITIALIZE_NOTIFICATION, TOOL_CALL_METHOD,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Errors that can occur during pipeline processing
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -19,19 +33,41 @@ pub enum PipelineError {
|
|||
NoChoicesInResponse(String),
|
||||
#[error("No content in response from agent '{0}'")]
|
||||
NoContentInResponse(String),
|
||||
#[error("No result in response from agent '{0}'")]
|
||||
NoResultInResponse(String),
|
||||
#[error("No structured content in response from agent '{0}'")]
|
||||
NoStructuredContentInResponse(String),
|
||||
#[error("No messages in response from agent '{0}'")]
|
||||
NoMessagesInResponse(String),
|
||||
#[error("Client error from agent '{agent}' (HTTP {status}): {body}")]
|
||||
ClientError {
|
||||
agent: String,
|
||||
status: u16,
|
||||
body: String,
|
||||
},
|
||||
#[error("Server error from agent '{agent}' (HTTP {status}): {body}")]
|
||||
ServerError {
|
||||
agent: String,
|
||||
status: u16,
|
||||
body: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Service for processing agent pipelines
|
||||
pub struct PipelineProcessor {
|
||||
client: reqwest::Client,
|
||||
url: String,
|
||||
agent_id_session_map: HashMap<String, String>,
|
||||
}
|
||||
|
||||
const ENVOY_API_ROUTER_ADDRESS: &str = "http://localhost:11000";
|
||||
|
||||
impl Default for PipelineProcessor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
url: "http://localhost:11000/v1/chat/completions".to_string(),
|
||||
url: ENVOY_API_ROUTER_ADDRESS.to_string(),
|
||||
agent_id_session_map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -41,18 +77,128 @@ impl PipelineProcessor {
|
|||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
url,
|
||||
agent_id_session_map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a span for filter execution
|
||||
fn record_filter_span(
|
||||
&self,
|
||||
collector: &std::sync::Arc<common::traces::TraceCollector>,
|
||||
agent_name: &str,
|
||||
tool_name: &str,
|
||||
start_time: SystemTime,
|
||||
end_time: SystemTime,
|
||||
elapsed: std::time::Duration,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
span_id: String,
|
||||
) -> String {
|
||||
// let (trace_id, parent_span_id) = self.extract_trace_context();
|
||||
|
||||
// Build operation name: POST /agents/* {filter_name}
|
||||
// Using generic path since we don't have access to specific endpoint here
|
||||
let operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path("/agents/*")
|
||||
.with_target(agent_name)
|
||||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id.clone())
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_time)
|
||||
.with_end_time(end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, "/agents/*")
|
||||
.with_attribute("filter.name", agent_name.to_string())
|
||||
.with_attribute("filter.tool_name", tool_name.to_string())
|
||||
.with_attribute(
|
||||
"duration_ms",
|
||||
format!("{:.2}", elapsed.as_secs_f64() * 1000.0),
|
||||
);
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id);
|
||||
}
|
||||
if !parent_span_id.is_empty() {
|
||||
span_builder = span_builder.with_parent_span_id(parent_span_id);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
// Use plano(filter) as service name for filter execution spans
|
||||
collector.record_span(operation_component::AGENT_FILTER, span);
|
||||
span_id.clone()
|
||||
}
|
||||
|
||||
/// Record a span for MCP protocol interactions
|
||||
fn record_mcp_span(
|
||||
&self,
|
||||
collector: &std::sync::Arc<common::traces::TraceCollector>,
|
||||
operation: &str,
|
||||
agent_id: &str,
|
||||
start_time: SystemTime,
|
||||
end_time: SystemTime,
|
||||
elapsed: std::time::Duration,
|
||||
additional_attrs: Option<HashMap<&str, String>>,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
span_id: Option<String>,
|
||||
) {
|
||||
// let (trace_id, parent_span_id) = self.extract_trace_context();
|
||||
|
||||
// Build operation name: POST /mcp {agent_id}
|
||||
let operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path("/mcp")
|
||||
.with_operation(operation)
|
||||
.with_target(agent_id)
|
||||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id()))
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_time)
|
||||
.with_end_time(end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string()))
|
||||
.with_attribute("mcp.operation", operation.to_string())
|
||||
.with_attribute("mcp.agent_id", agent_id.to_string())
|
||||
.with_attribute(
|
||||
"duration_ms",
|
||||
format!("{:.2}", elapsed.as_secs_f64() * 1000.0),
|
||||
);
|
||||
|
||||
if let Some(attrs) = additional_attrs {
|
||||
for (key, value) in attrs {
|
||||
span_builder = span_builder.with_attribute(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id);
|
||||
}
|
||||
if !parent_span_id.is_empty() {
|
||||
span_builder = span_builder.with_parent_span_id(parent_span_id);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
// MCP spans also use plano(filter) service name as they are part of filter operations
|
||||
collector.record_span(operation_component::AGENT_FILTER, span);
|
||||
}
|
||||
|
||||
/// Process the filter chain of agents (all except the terminal agent)
|
||||
pub async fn process_filter_chain(
|
||||
&self,
|
||||
initial_request: &ChatCompletionsRequest,
|
||||
&mut self,
|
||||
chat_history: &[Message],
|
||||
agent_filter_chain: &AgentFilterChain,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
let mut chat_completions_history = initial_request.messages.clone();
|
||||
let mut chat_history_updated = chat_history.to_vec();
|
||||
|
||||
for agent_name in &agent_filter_chain.filter_chain {
|
||||
debug!("Processing filter agent: {}", agent_name);
|
||||
|
|
@ -61,101 +207,490 @@ impl PipelineProcessor {
|
|||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
debug!("Agent details: {:?}", agent);
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
|
||||
let response_content = self
|
||||
.send_agent_filter_chain_request(
|
||||
&chat_completions_history,
|
||||
initial_request,
|
||||
info!(
|
||||
"executing filter: {}/{}, url: {}, conversation length: {}",
|
||||
agent_name,
|
||||
tool_name,
|
||||
agent.url,
|
||||
chat_history.len()
|
||||
);
|
||||
|
||||
let start_time = SystemTime::now();
|
||||
let start_instant = Instant::now();
|
||||
|
||||
// Generate filter span ID before execution so MCP spans can use it as parent
|
||||
let filter_span_id = generate_random_span_id();
|
||||
|
||||
chat_history_updated = self
|
||||
.execute_filter(
|
||||
&chat_history_updated,
|
||||
agent,
|
||||
request_headers,
|
||||
trace_collector,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!("Received response from filter agent {}", agent_name);
|
||||
let end_time = SystemTime::now();
|
||||
let elapsed = start_instant.elapsed();
|
||||
|
||||
// Parse the response content as new message history
|
||||
chat_completions_history =
|
||||
serde_json::from_str(&response_content).inspect_err(|err| {
|
||||
warn!(
|
||||
"Failed to parse response from agent {}, err: {}, response: {}",
|
||||
agent_name, err, response_content
|
||||
)
|
||||
})?;
|
||||
info!(
|
||||
"Filter '{}' completed in {:.2}ms, updated conversation length: {}",
|
||||
agent_name,
|
||||
elapsed.as_secs_f64() * 1000.0,
|
||||
chat_history_updated.len()
|
||||
);
|
||||
|
||||
// Record span for this filter execution
|
||||
if let Some(collector) = trace_collector {
|
||||
self.record_filter_span(
|
||||
collector,
|
||||
agent_name,
|
||||
tool_name,
|
||||
start_time,
|
||||
end_time,
|
||||
elapsed,
|
||||
trace_id.clone(),
|
||||
parent_span_id.clone(),
|
||||
filter_span_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(chat_completions_history)
|
||||
Ok(chat_history_updated)
|
||||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
async fn send_agent_filter_chain_request(
|
||||
/// Build common MCP headers for requests
|
||||
fn build_mcp_headers(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
original_request: &ChatCompletionsRequest,
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<String, PipelineError> {
|
||||
let mut request = original_request.clone();
|
||||
request.messages = messages.to_vec();
|
||||
agent_id: &str,
|
||||
session_id: Option<&str>,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
) -> Result<HeaderMap, PipelineError> {
|
||||
let trace_parent = format!("00-{}-{}-01", trace_id, parent_span_id);
|
||||
let mut headers = request_headers.clone();
|
||||
headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
let request_body = serde_json::to_string(&request)?;
|
||||
debug!("Sending request to agent {}", agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
headers.remove(TRACE_PARENT_HEADER);
|
||||
headers.insert(
|
||||
TRACE_PARENT_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(agent_id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent_id.to_string()))?,
|
||||
);
|
||||
|
||||
headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
|
||||
headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
|
||||
);
|
||||
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
if let Some(sid) = session_id {
|
||||
headers.insert(
|
||||
"mcp-session-id",
|
||||
hyper::header::HeaderValue::from_str(sid).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Parse SSE formatted response and extract JSON-RPC data
|
||||
fn parse_sse_response(
|
||||
&self,
|
||||
response_bytes: &[u8],
|
||||
agent_id: &str,
|
||||
) -> Result<String, PipelineError> {
|
||||
let response_str = String::from_utf8_lossy(response_bytes);
|
||||
let lines: Vec<&str> = response_str.lines().collect();
|
||||
|
||||
// Validate SSE format: first line should be "event: message"
|
||||
if lines.is_empty() || lines[0] != "event: message" {
|
||||
warn!(
|
||||
"Invalid SSE response format from agent {}: expected 'event: message' as first line, got: {:?}",
|
||||
agent_id,
|
||||
lines.first()
|
||||
);
|
||||
return Err(PipelineError::NoContentInResponse(format!(
|
||||
"Invalid SSE response format from agent {}: expected 'event: message' as first line",
|
||||
agent_id
|
||||
)));
|
||||
}
|
||||
|
||||
// Find the data line
|
||||
let data_lines: Vec<&str> = lines
|
||||
.iter()
|
||||
.filter(|line| line.starts_with("data: "))
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
if data_lines.len() != 1 {
|
||||
warn!(
|
||||
"Expected exactly one 'data:' line from agent {}, found {}",
|
||||
agent_id,
|
||||
data_lines.len()
|
||||
);
|
||||
return Err(PipelineError::NoContentInResponse(format!(
|
||||
"Expected exactly one 'data:' line from agent {}, found {}",
|
||||
agent_id,
|
||||
data_lines.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Skip "data: " prefix
|
||||
Ok(data_lines[0][6..].to_string())
|
||||
}
|
||||
|
||||
/// Send an MCP request and return the response
|
||||
async fn send_mcp_request(
|
||||
&self,
|
||||
json_rpc_request: &JsonRpcRequest,
|
||||
headers: HeaderMap,
|
||||
agent_id: &str,
|
||||
) -> Result<reqwest::Response, PipelineError> {
|
||||
let request_body = serde_json::to_string(json_rpc_request)?;
|
||||
|
||||
debug!(
|
||||
"Sending MCP request to agent {}: {}",
|
||||
agent_id, request_body
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.url)
|
||||
.headers(agent_headers)
|
||||
.post(format!("{}/mcp", self.url))
|
||||
.headers(headers)
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Build a tools/call JSON-RPC request
|
||||
fn build_tool_call_request(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
messages: &[Message],
|
||||
) -> Result<JsonRpcRequest, PipelineError> {
|
||||
let mut arguments = HashMap::new();
|
||||
arguments.insert("messages".to_string(), serde_json::to_value(messages)?);
|
||||
|
||||
let mut params = HashMap::new();
|
||||
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
|
||||
params.insert("arguments".to_string(), serde_json::to_value(arguments)?);
|
||||
|
||||
Ok(JsonRpcRequest {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: JsonRpcId::String(Uuid::new_v4().to_string()),
|
||||
method: TOOL_CALL_METHOD.to_string(),
|
||||
params: Some(params),
|
||||
})
|
||||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
async fn execute_filter(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
|
||||
trace_id: String,
|
||||
filter_span_id: String,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
// Get or create MCP session
|
||||
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
|
||||
session_id.clone()
|
||||
} else {
|
||||
let session_id = self
|
||||
.get_new_session_id(
|
||||
&agent.id,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await;
|
||||
self.agent_id_session_map
|
||||
.insert(agent.id.clone(), session_id.clone());
|
||||
session_id
|
||||
};
|
||||
|
||||
info!(
|
||||
"Using MCP session ID {} for agent {}",
|
||||
mcp_session_id, agent.id
|
||||
);
|
||||
|
||||
// Build JSON-RPC request
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
let json_rpc_request = self.build_tool_call_request(tool_name, messages)?;
|
||||
|
||||
// Generate span ID for this MCP tool call (child of filter span)
|
||||
let mcp_span_id = generate_random_span_id();
|
||||
|
||||
// Build headers
|
||||
let agent_headers =
|
||||
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id), trace_id.clone(), mcp_span_id.clone())?;
|
||||
|
||||
// Send request with tracing
|
||||
let start_time = SystemTime::now();
|
||||
let start_instant = Instant::now();
|
||||
|
||||
let response = self
|
||||
.send_mcp_request(
|
||||
&json_rpc_request,
|
||||
agent_headers,
|
||||
&agent.id,
|
||||
)
|
||||
.await?;
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
// Parse the response as JSON to extract the content
|
||||
let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?;
|
||||
let end_time = SystemTime::now();
|
||||
let elapsed = start_instant.elapsed();
|
||||
|
||||
let content = response_json
|
||||
.get("choices")
|
||||
.and_then(|choices| choices.as_array())
|
||||
.and_then(|choices| choices.first())
|
||||
.and_then(|choice| choice.get("message"))
|
||||
.and_then(|message| message.get("content"))
|
||||
.and_then(|content| content.as_str())
|
||||
.ok_or_else(|| PipelineError::NoContentInResponse(agent.id.clone()))?
|
||||
// Record MCP tool call span
|
||||
if let Some(collector) = trace_collector {
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("mcp.method", "tools/call".to_string());
|
||||
attrs.insert("mcp.tool_name", tool_name.to_string());
|
||||
attrs.insert("mcp.session_id", mcp_session_id.clone());
|
||||
attrs.insert("http.status_code", http_status.as_u16().to_string());
|
||||
|
||||
self.record_mcp_span(
|
||||
collector,
|
||||
"tool_call",
|
||||
&agent.id,
|
||||
start_time,
|
||||
end_time,
|
||||
elapsed,
|
||||
Some(attrs),
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
Some(mcp_span_id),
|
||||
);
|
||||
}
|
||||
|
||||
// Handle HTTP errors
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
} else {
|
||||
PipelineError::ServerError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
info!(
|
||||
"Response from agent {}: {}",
|
||||
agent.id,
|
||||
String::from_utf8_lossy(&response_bytes)
|
||||
);
|
||||
|
||||
// Parse SSE response
|
||||
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
|
||||
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
|
||||
let response_result = response
|
||||
.result
|
||||
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
|
||||
|
||||
// Check if error field is set in response result
|
||||
if response_result
|
||||
.get("isError")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let error_message = response_result
|
||||
.get("content")
|
||||
.and_then(|v| v.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|v| v.get("text"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown_error")
|
||||
.to_string();
|
||||
|
||||
return Err(PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_message,
|
||||
});
|
||||
}
|
||||
|
||||
// Extract structured content and parse messages
|
||||
let response_json = response_result
|
||||
.get("structuredContent")
|
||||
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
|
||||
|
||||
let messages: Vec<Message> = response_json
|
||||
.get("result")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| PipelineError::NoMessagesInResponse(agent.id.clone()))?
|
||||
.iter()
|
||||
.map(|msg_value| serde_json::from_value(msg_value.clone()))
|
||||
.collect::<Result<Vec<Message>, _>>()
|
||||
.map_err(PipelineError::ParseError)?;
|
||||
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
/// Build an initialize JSON-RPC request
|
||||
fn build_initialize_request(&self) -> JsonRpcRequest {
|
||||
JsonRpcRequest {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: JsonRpcId::String(Uuid::new_v4().to_string()),
|
||||
method: MCP_INITIALIZE.to_string(),
|
||||
params: Some({
|
||||
let mut params = HashMap::new();
|
||||
params.insert(
|
||||
"protocolVersion".to_string(),
|
||||
serde_json::Value::String("2024-11-05".to_string()),
|
||||
);
|
||||
params.insert("capabilities".to_string(), serde_json::json!({}));
|
||||
params.insert(
|
||||
"clientInfo".to_string(),
|
||||
serde_json::json!({
|
||||
"name": BRIGHT_STAFF_SERVICE_NAME,
|
||||
"version": "1.0.0"
|
||||
}),
|
||||
);
|
||||
params
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send initialized notification after session creation
|
||||
async fn send_initialized_notification(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
session_id: &str,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
) -> Result<(), PipelineError> {
|
||||
let initialized_notification = JsonRpcNotification {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
method: MCP_INITIALIZE_NOTIFICATION.to_string(),
|
||||
params: None,
|
||||
};
|
||||
|
||||
let notification_body = serde_json::to_string(&initialized_notification)?;
|
||||
debug!("Sending initialized notification for agent {}", agent_id);
|
||||
|
||||
let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, Some(session_id), trace_id.clone(), parent_span_id.clone())?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/mcp", self.url))
|
||||
.headers(headers)
|
||||
.body(notification_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
"Initialized notification response status: {}",
|
||||
response.status()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_new_session_id(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
trace_id: String,
|
||||
parent_span_id: String,
|
||||
) -> String {
|
||||
info!("Initializing MCP session for agent {}", agent_id);
|
||||
|
||||
let initialize_request = self.build_initialize_request();
|
||||
let headers = self
|
||||
.build_mcp_headers(&HeaderMap::new(), agent_id, None, trace_id.clone(), parent_span_id.clone())
|
||||
.expect("Failed to build headers for initialization");
|
||||
|
||||
let response = self
|
||||
.send_mcp_request(&initialize_request, headers, agent_id)
|
||||
.await
|
||||
.expect("Failed to initialize MCP session");
|
||||
|
||||
info!("Initialize response status: {}", response.status());
|
||||
|
||||
let session_id = response
|
||||
.headers()
|
||||
.get("mcp-session-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.expect("No mcp-session-id in response")
|
||||
.to_string();
|
||||
|
||||
Ok(content)
|
||||
info!(
|
||||
"Created new MCP session for agent {}: {}",
|
||||
agent_id, session_id
|
||||
);
|
||||
|
||||
// Send initialized notification
|
||||
self.send_initialized_notification(
|
||||
agent_id,
|
||||
&session_id,
|
||||
trace_id.clone(),
|
||||
parent_span_id.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("Failed to send initialized notification");
|
||||
|
||||
session_id
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
pub async fn invoke_upstream_agent(
|
||||
pub async fn invoke_agent(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
original_request: &ChatCompletionsRequest,
|
||||
mut original_request: ProviderRequestType,
|
||||
terminal_agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
trace_id: String,
|
||||
agent_span_id: String,
|
||||
) -> Result<reqwest::Response, PipelineError> {
|
||||
let mut request = original_request.clone();
|
||||
request.messages = messages.to_vec();
|
||||
// let mut request = original_request.clone();
|
||||
original_request.set_messages(messages);
|
||||
|
||||
let request_body = serde_json::to_string(&request)?;
|
||||
let request_body = ProviderRequestType::to_bytes(&original_request).unwrap();
|
||||
// let request_body = serde_json::to_string(&request)?;
|
||||
debug!("Sending request to terminal agent {}", terminal_agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
// Set traceparent header to make the egress span a child of the agent span
|
||||
if !trace_id.is_empty() && !agent_span_id.is_empty() {
|
||||
let trace_parent = format!("00-{}-{}-01", trace_id, agent_span_id);
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
agent_headers.insert(
|
||||
TRACE_PARENT_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&terminal_agent.id)
|
||||
|
|
@ -169,7 +704,7 @@ impl PipelineProcessor {
|
|||
|
||||
let response = self
|
||||
.client
|
||||
.post(&self.url)
|
||||
.post(format!("{}/v1/chat/completions", self.url))
|
||||
.headers(agent_headers)
|
||||
.body(request_body)
|
||||
.send()
|
||||
|
|
@ -183,6 +718,7 @@ impl PipelineProcessor {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use mockito::Server;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
|
|
@ -206,23 +742,149 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_agent_not_found_error() {
|
||||
let processor = PipelineProcessor::default();
|
||||
let mut processor = PipelineProcessor::default();
|
||||
let agent_map = HashMap::new();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let initial_request = ChatCompletionsRequest {
|
||||
messages: vec![create_test_message(Role::User, "Hello")],
|
||||
model: "test-model".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
let messages = vec![create_test_message(Role::User, "Hello")];
|
||||
|
||||
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
|
||||
|
||||
let result = processor
|
||||
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
|
||||
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers, None, String::new(), String::new())
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
matches!(result.unwrap_err(), PipelineError::AgentNotFound(_));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_filter_http_status_error() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _m = server
|
||||
.mock("POST", "/mcp")
|
||||
.with_status(500)
|
||||
.with_body("boom")
|
||||
.create();
|
||||
|
||||
let server_url = server.url();
|
||||
let mut processor = PipelineProcessor::new(server_url.clone());
|
||||
processor
|
||||
.agent_id_session_map
|
||||
.insert("agent-1".to_string(), "session-1".to_string());
|
||||
|
||||
let agent = Agent {
|
||||
id: "agent-1".to_string(),
|
||||
transport: None,
|
||||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hello")];
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-123".to_string(), "span-123".to_string())
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(PipelineError::ServerError { status, body, .. }) => {
|
||||
assert_eq!(status, 500);
|
||||
assert_eq!(body, "boom");
|
||||
}
|
||||
_ => panic!("Expected server error for 500 status"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_filter_http_client_error() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _m = server
|
||||
.mock("POST", "/mcp")
|
||||
.with_status(400)
|
||||
.with_body("bad request")
|
||||
.create();
|
||||
|
||||
let server_url = server.url();
|
||||
let mut processor = PipelineProcessor::new(server_url.clone());
|
||||
processor
|
||||
.agent_id_session_map
|
||||
.insert("agent-3".to_string(), "session-3".to_string());
|
||||
|
||||
let agent = Agent {
|
||||
id: "agent-3".to_string(),
|
||||
transport: None,
|
||||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Ping")];
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-456".to_string(), "span-456".to_string())
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(PipelineError::ClientError { status, body, .. }) => {
|
||||
assert_eq!(status, 400);
|
||||
assert_eq!(body, "bad request");
|
||||
}
|
||||
_ => panic!("Expected client error for 400 status"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_filter_mcp_error_flag() {
|
||||
let rpc_body = serde_json::json!({
|
||||
"jsonrpc": JSON_RPC_VERSION,
|
||||
"id": "1",
|
||||
"result": {
|
||||
"isError": true,
|
||||
"content": [
|
||||
{ "text": "bad tool call" }
|
||||
]
|
||||
}
|
||||
});
|
||||
|
||||
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string());
|
||||
|
||||
let mut server = Server::new_async().await;
|
||||
let _m = server
|
||||
.mock("POST", "/mcp")
|
||||
.with_status(200)
|
||||
.with_body(sse_body)
|
||||
.create();
|
||||
|
||||
let server_url = server.url();
|
||||
let mut processor = PipelineProcessor::new(server_url.clone());
|
||||
processor
|
||||
.agent_id_session_map
|
||||
.insert("agent-2".to_string(), "session-2".to_string());
|
||||
|
||||
let agent = Agent {
|
||||
id: "agent-2".to_string(),
|
||||
transport: None,
|
||||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hi")];
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-789".to_string(), "span-789".to_string())
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(PipelineError::ClientError { status, body, .. }) => {
|
||||
assert_eq!(status, 200);
|
||||
assert_eq!(body, "bad tool call");
|
||||
}
|
||||
_ => panic!("Expected client error when isError flag is set"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
249
crates/brightstaff/src/handlers/router_chat.rs
Normal file
249
crates/brightstaff/src/handlers/router_chat.rs
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use common::consts::{REQUEST_ID_HEADER};
|
||||
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hyper::StatusCode;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::tracing::{OperationNameBuilder, operation_component, http, routing};
|
||||
|
||||
/// Outcome of a successful routing decision.
pub struct RoutingResult {
    /// Name of the model selected to serve the request.
    pub model_name: String,
}
|
||||
|
||||
pub struct RoutingError {
|
||||
pub message: String,
|
||||
pub status_code: StatusCode,
|
||||
}
|
||||
|
||||
impl RoutingError {
|
||||
pub fn internal_error(message: String) -> Self {
|
||||
Self {
|
||||
message,
|
||||
status_code: StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines the routing decision if
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
|
||||
/// * `Err(RoutingError)` - Contains error details and optional span ID
|
||||
pub async fn router_chat_get_upstream_model(
|
||||
router_service: Arc<RouterService>,
|
||||
client_request: ProviderRequestType,
|
||||
request_headers: &hyper::HeaderMap,
|
||||
trace_collector: Arc<TraceCollector>,
|
||||
traceparent: &str,
|
||||
request_path: &str,
|
||||
) -> Result<RoutingResult, RoutingError> {
|
||||
// Clone metadata for routing before converting (which consumes client_request)
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
let request_id = request_headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
||||
let chat_request = match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(
|
||||
hermesllm::apis::OpenAIApi::ChatCompletions,
|
||||
),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(
|
||||
ProviderRequestType::MessagesRequest(_)
|
||||
| ProviderRequestType::BedrockConverse(_)
|
||||
| ProviderRequestType::BedrockConverseStream(_)
|
||||
| ProviderRequestType::ResponsesAPIRequest(_),
|
||||
) => {
|
||||
warn!("Unexpected: got non-ChatCompletions request after converting to OpenAI format");
|
||||
return Err(RoutingError::internal_error(
|
||||
"Request conversion failed".to_string(),
|
||||
));
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
|
||||
return Err(RoutingError::internal_error(format!(
|
||||
"Failed to convert request: {}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID: {}]: ROUTER_REQ: {}",
|
||||
request_id,
|
||||
&serde_json::to_string(&chat_request).unwrap()
|
||||
);
|
||||
|
||||
// Extract trace_parent from headers
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
.find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
// Extract usage preferences from metadata
|
||||
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
// Prepare log message with latest message from chat request
|
||||
let latest_message_for_log = chat_request
|
||||
.messages
|
||||
.last()
|
||||
.map_or("None".to_string(), |msg| {
|
||||
msg.content.to_string().replace('\n', "\\n")
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
|
||||
let truncated: String = latest_message_for_log
|
||||
.chars()
|
||||
.take(MAX_MESSAGE_LENGTH)
|
||||
.collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
latest_message_for_log
|
||||
};
|
||||
|
||||
info!(
|
||||
"[PLANO_REQ_ID: {}] | ROUTER_REQ | Usage preferences from request: {}, request_path: {}, latest message: {}",
|
||||
request_id,
|
||||
usage_preferences.is_some(),
|
||||
request_path,
|
||||
latest_message_for_log
|
||||
);
|
||||
|
||||
// Capture start time for routing span
|
||||
let routing_start_time = std::time::Instant::now();
|
||||
let routing_start_system_time = std::time::SystemTime::now();
|
||||
|
||||
// Attempt to determine route using the router service
|
||||
let routing_result = router_service
|
||||
.determine_route(&chat_request.messages, trace_parent, usage_preferences)
|
||||
.await;
|
||||
|
||||
match routing_result {
|
||||
Ok(route) => match route {
|
||||
Some((_, model_name)) => {
|
||||
// Record successful routing span
|
||||
let mut attrs: HashMap<String, String> = HashMap::new();
|
||||
attrs.insert("route.selected_model".to_string(), model_name.clone());
|
||||
record_routing_span(
|
||||
trace_collector,
|
||||
traceparent,
|
||||
routing_start_time,
|
||||
routing_start_system_time,
|
||||
attrs,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name
|
||||
})
|
||||
}
|
||||
None => {
|
||||
// No route determined, use default model from request
|
||||
info!(
|
||||
"[PLANO_REQ_ID: {}] | ROUTER_REQ | No route determined, using default model from request: {}",
|
||||
request_id,
|
||||
chat_request.model
|
||||
);
|
||||
|
||||
let default_model = chat_request.model.clone();
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("route.selected_model".to_string(), default_model.clone());
|
||||
record_routing_span(
|
||||
trace_collector,
|
||||
traceparent,
|
||||
routing_start_time,
|
||||
routing_start_system_time,
|
||||
attrs,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name: default_model
|
||||
})
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
// Record failed routing span
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("route.selected_model".to_string(), "unknown".to_string());
|
||||
attrs.insert("error.message".to_string(), err.to_string());
|
||||
record_routing_span(
|
||||
trace_collector,
|
||||
traceparent,
|
||||
routing_start_time,
|
||||
routing_start_system_time,
|
||||
attrs,
|
||||
)
|
||||
.await;
|
||||
|
||||
Err(RoutingError::internal_error(
|
||||
format!("Failed to determine route: {}", err)
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to record a routing span with the given attributes.
|
||||
/// Reduces code duplication across different routing outcomes.
|
||||
async fn record_routing_span(
|
||||
trace_collector: Arc<TraceCollector>,
|
||||
traceparent: &str,
|
||||
start_time: std::time::Instant,
|
||||
start_system_time: std::time::SystemTime,
|
||||
attrs: HashMap<String, String>,
|
||||
) {
|
||||
// The routing always uses OpenAI Chat Completions format internally,
|
||||
// so we log that as the actual API being used for routing
|
||||
let routing_api_path = "/v1/chat/completions";
|
||||
|
||||
let routing_operation_name = OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(routing_api_path)
|
||||
.with_target("Arch-Router-1.5B")
|
||||
.build();
|
||||
|
||||
let (trace_id, parent_span_id) = parse_traceparent(traceparent);
|
||||
|
||||
// Build the routing span directly using constants
|
||||
let mut span_builder = SpanBuilder::new(&routing_operation_name)
|
||||
.with_trace_id(&trace_id)
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_system_time)
|
||||
.with_end_time(std::time::SystemTime::now())
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, routing_api_path.to_string())
|
||||
.with_attribute(routing::ROUTE_DETERMINATION_MS, start_time.elapsed().as_millis().to_string());
|
||||
|
||||
// Only set parent span ID if it exists (not a root span)
|
||||
if let Some(parent) = parent_span_id {
|
||||
span_builder = span_builder.with_parent_span_id(&parent);
|
||||
}
|
||||
|
||||
// Add all custom attributes
|
||||
for (key, value) in attrs {
|
||||
span_builder = span_builder.with_attribute(key, value);
|
||||
}
|
||||
|
||||
let span = span_builder.build();
|
||||
|
||||
// Record the span directly to the collector
|
||||
trace_collector.record_span(operation_component::ROUTING, span);
|
||||
}
|
||||
259
crates/brightstaff/src/handlers/utils.rs
Normal file
259
crates/brightstaff/src/handlers/utils.rs
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
use bytes::Bytes;
|
||||
use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::StreamBody;
|
||||
use hyper::body::Frame;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Instant, SystemTime};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::warn;
|
||||
|
||||
// Import tracing constants
|
||||
use crate::tracing::{llm, error};
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
pub trait StreamProcessor: Send + 'static {
|
||||
/// Process an incoming chunk of bytes
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String>;
|
||||
|
||||
/// Called when the first bytes are received (for time-to-first-token tracking)
|
||||
fn on_first_bytes(&mut self) {}
|
||||
|
||||
/// Called when streaming completes successfully
|
||||
fn on_complete(&mut self) {}
|
||||
|
||||
/// Called when streaming encounters an error
|
||||
fn on_error(&mut self, _error: &str) {}
|
||||
}
|
||||
|
||||
/// A processor that tracks streaming metrics and finalizes the span
|
||||
pub struct ObservableStreamProcessor {
|
||||
collector: Arc<TraceCollector>,
|
||||
service_name: String,
|
||||
span: Span,
|
||||
total_bytes: usize,
|
||||
chunk_count: usize,
|
||||
start_time: Instant,
|
||||
time_to_first_token: Option<u128>,
|
||||
}
|
||||
|
||||
impl ObservableStreamProcessor {
|
||||
/// Create a new passthrough processor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `collector` - The trace collector to record the span to
|
||||
/// * `service_name` - The service name for this span (e.g., "archgw(llm)")
|
||||
/// * `span` - The span to finalize after streaming completes
|
||||
/// * `start_time` - When the request started (for duration calculation)
|
||||
pub fn new(
|
||||
collector: Arc<TraceCollector>,
|
||||
service_name: impl Into<String>,
|
||||
span: Span,
|
||||
start_time: Instant,
|
||||
) -> Self {
|
||||
Self {
|
||||
collector,
|
||||
service_name: service_name.into(),
|
||||
span,
|
||||
total_bytes: 0,
|
||||
chunk_count: 0,
|
||||
start_time,
|
||||
time_to_first_token: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamProcessor for ObservableStreamProcessor {
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||
self.total_bytes += chunk.len();
|
||||
self.chunk_count += 1;
|
||||
Ok(Some(chunk))
|
||||
}
|
||||
|
||||
fn on_first_bytes(&mut self) {
|
||||
// Record time to first token (only for streaming)
|
||||
if self.time_to_first_token.is_none() {
|
||||
self.time_to_first_token = Some(self.start_time.elapsed().as_millis());
|
||||
}
|
||||
}
|
||||
|
||||
fn on_complete(&mut self) {
|
||||
// Update span with streaming metrics and end time
|
||||
let end_time_nanos = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos();
|
||||
|
||||
self.span.end_time_unix_nano = format!("{}", end_time_nanos);
|
||||
|
||||
// Add streaming metrics as attributes using constants
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::RESPONSE_BYTES.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(self.total_bytes.to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::DURATION_MS.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(self.start_time.elapsed().as_millis().to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
// Add time to first token if available (streaming only)
|
||||
if let Some(ttft) = self.time_to_first_token {
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(ttft.to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
// Add time to first token as a span event
|
||||
// Calculate the timestamp by adding ttft duration to span start time
|
||||
if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() {
|
||||
// Convert ttft from milliseconds to nanoseconds and add to start time
|
||||
let event_timestamp = start_time_nanos + (ttft * 1_000_000);
|
||||
let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
|
||||
event.add_attribute(
|
||||
llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
|
||||
ttft.to_string(),
|
||||
);
|
||||
|
||||
// Initialize events vector if needed
|
||||
if self.span.events.is_none() {
|
||||
self.span.events = Some(Vec::new());
|
||||
}
|
||||
|
||||
if let Some(ref mut events) = self.span.events {
|
||||
events.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Record the finalized span
|
||||
self.collector.record_span(&self.service_name, self.span.clone());
|
||||
}
|
||||
|
||||
fn on_error(&mut self, error_msg: &str) {
|
||||
warn!("Stream error in PassthroughProcessor: {}", error_msg);
|
||||
|
||||
// Update span with error info and end time
|
||||
let end_time_nanos = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos();
|
||||
|
||||
self.span.end_time_unix_nano = format!("{}", end_time_nanos);
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: error::ERROR.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some("true".to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: error::MESSAGE.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(error_msg.to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::DURATION_MS.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(self.start_time.elapsed().as_millis().to_string()),
|
||||
},
|
||||
});
|
||||
|
||||
// Record the error span
|
||||
self.collector.record_span(&self.service_name, self.span.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of creating a streaming response
|
||||
pub struct StreamingResponse {
|
||||
pub body: BoxBody<Bytes, hyper::Error>,
|
||||
pub processor_handle: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
pub fn create_streaming_response<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut processor: P,
|
||||
buffer_size: usize,
|
||||
) -> StreamingResponse
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);
|
||||
|
||||
// Spawn a task to process and forward chunks
|
||||
let processor_handle = tokio::spawn(async move {
|
||||
let mut is_first_chunk = true;
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Error receiving chunk: {:?}", err);
|
||||
warn!("{}", err_msg);
|
||||
processor.on_error(&err_msg);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Call on_first_bytes for the first chunk
|
||||
if is_first_chunk {
|
||||
processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Process the chunk
|
||||
match processor.process_chunk(chunk) {
|
||||
Ok(Some(processed_chunk)) => {
|
||||
if tx.send(processed_chunk).await.is_err() {
|
||||
warn!("Receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// Skip this chunk
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Processor error: {}", err);
|
||||
processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
processor.on_complete();
|
||||
});
|
||||
|
||||
// Convert channel receiver to HTTP stream
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
StreamingResponse {
|
||||
body: stream_body,
|
||||
processor_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Shortens `message` to at most `max_length` characters, appending "..."
/// when anything was cut off. Messages that already fit are returned as-is.
///
/// Length is counted in `char`s, not bytes, so multi-byte characters are never split.
pub fn truncate_message(message: &str, max_length: usize) -> String {
    // A character at index `max_length` exists only when the message is too long;
    // its byte offset is exactly where the kept prefix ends.
    match message.char_indices().nth(max_length) {
        Some((cut, _)) => format!("{}...", &message[..cut]),
        None => message.to_string(),
    }
}
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
pub mod handlers;
|
||||
pub mod router;
|
||||
pub mod state;
|
||||
pub mod tracing;
|
||||
pub mod utils;
|
||||
|
|
|
|||
|
|
@ -1,11 +1,16 @@
|
|||
use brightstaff::handlers::agent_chat_completions::agent_chat;
|
||||
use brightstaff::handlers::chat_completions::chat;
|
||||
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::Configuration;
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH};
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
use common::traces::TraceCollector;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::server::conn::http1;
|
||||
|
|
@ -45,10 +50,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let _tracer_provider = init_tracer();
|
||||
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
|
||||
|
||||
info!(
|
||||
"current working directory: {}",
|
||||
env::current_dir().unwrap().display()
|
||||
);
|
||||
// loading arch_config.yaml file
|
||||
let arch_config_path = env::var("ARCH_CONFIG_PATH_RENDERED")
|
||||
.unwrap_or_else(|_| "./arch_config_rendered.yaml".to_string());
|
||||
|
|
@ -62,22 +63,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let arch_config = Arc::new(config);
|
||||
|
||||
// combine agents and filters into a single list of agents
|
||||
let all_agents: Vec<Agent> = arch_config
|
||||
.agents
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.chain(arch_config.filters.as_deref().unwrap_or_default())
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
|
||||
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
|
||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
|
||||
|
||||
debug!(
|
||||
"arch_config: {:?}",
|
||||
&serde_json::to_string(arch_config.as_ref()).unwrap()
|
||||
);
|
||||
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
info!("llm provider url: {}", llm_provider_url);
|
||||
info!("listening on http://{}", bind_address);
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
|
||||
let routing_model_name: String = arch_config
|
||||
.routing
|
||||
.as_ref()
|
||||
|
|
@ -99,18 +101,69 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let model_aliases = Arc::new(arch_config.model_aliases.clone());
|
||||
|
||||
// Initialize trace collector and start background flusher
|
||||
// Tracing is enabled if the tracing config is present in arch_config.yaml
|
||||
// Pass Some(true/false) to override, or None to use env var OTEL_TRACING_ENABLED
|
||||
let tracing_enabled = if arch_config.tracing.is_some() {
|
||||
info!("Tracing configuration found in arch_config.yaml");
|
||||
Some(true)
|
||||
} else {
|
||||
info!(
|
||||
"No tracing configuration in arch_config.yaml, will check OTEL_TRACING_ENABLED env var"
|
||||
);
|
||||
None
|
||||
};
|
||||
let trace_collector = Arc::new(TraceCollector::new(tracing_enabled));
|
||||
let _flusher_handle = trace_collector.clone().start_background_flusher();
|
||||
|
||||
// Initialize conversation state storage for v1/responses
|
||||
// Configurable via arch_config.yaml state_storage section
|
||||
// If not configured, state management is disabled
|
||||
// Environment variables are substituted by envsubst before config is read
|
||||
let state_storage: Option<Arc<dyn StateStorage>> = if let Some(storage_config) = &arch_config.state_storage {
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
info!("Initialized conversation state storage: Memory");
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
.connection_string
|
||||
.as_ref()
|
||||
.expect("connection_string is required for postgres state_storage");
|
||||
|
||||
debug!("Postgres connection string (full): {}", connection_string);
|
||||
info!("Initializing conversation state storage: Postgres");
|
||||
Arc::new(
|
||||
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||
.await
|
||||
.expect("Failed to initialize Postgres state storage"),
|
||||
)
|
||||
}
|
||||
};
|
||||
Some(storage)
|
||||
} else {
|
||||
info!("No state_storage configured - conversation state management disabled");
|
||||
None
|
||||
};
|
||||
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let model_aliases: Arc<
|
||||
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
|
||||
> = Arc::clone(&model_aliases);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = agents_list.clone();
|
||||
let agents_list = combined_agents_filters_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
let trace_collector = trace_collector.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
|
|
@ -119,28 +172,46 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
let trace_collector = trace_collector.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
chat(req, router_service, fully_qualified_url, model_aliases)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/agents/v1/chat/completions") => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
agent_chat(
|
||||
let path = req.uri().path();
|
||||
// Check if path starts with /agents
|
||||
if path.starts_with("/agents") {
|
||||
// Check if it matches one of the agent API paths
|
||||
let stripped_path = path.strip_prefix("/agents").unwrap();
|
||||
if matches!(
|
||||
stripped_path,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path);
|
||||
return agent_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
agents_list,
|
||||
listeners,
|
||||
trace_collector,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
.await;
|
||||
}
|
||||
}
|
||||
match (req.method(), path) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, path);
|
||||
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, "/v1/chat/completions");
|
||||
function_calling_chat_handler(req, fully_qualified_url)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
|
||||
Ok(list_models(llm_providers).await)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
pub mod llm_router;
|
||||
pub mod orchestrator_model;
|
||||
pub mod orchestrator_model_v1;
|
||||
pub mod router_model;
|
||||
pub mod router_model_v1;
|
||||
|
|
|
|||
30
crates/brightstaff/src/router/orchestrator_model.rs
Normal file
30
crates/brightstaff/src/router/orchestrator_model.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
use common::configuration::AgentUsagePreference;
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OrchestratorModelError {
|
||||
#[error("Failed to parse JSON: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OrchestratorModelError>;
|
||||
|
||||
/// OrchestratorModel trait for handling orchestration requests.
|
||||
/// Unlike RouterModel which returns a single route, OrchestratorModel
|
||||
/// can return multiple routes as the model output format is:
|
||||
/// {"route": ["route_name_1", "route_name_2", ...]}
|
||||
pub trait OrchestratorModel: Send + Sync {
|
||||
fn generate_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: &Option<Vec<AgentUsagePreference>>,
|
||||
) -> ChatCompletionsRequest;
|
||||
/// Returns a vector of (route_name, model_name) tuples for all matched routes.
|
||||
fn parse_response(
|
||||
&self,
|
||||
content: &str,
|
||||
usage_preferences: &Option<Vec<AgentUsagePreference>>,
|
||||
) -> Result<Option<Vec<(String, String)>>>;
|
||||
fn get_model_name(&self) -> String;
|
||||
}
|
||||
1097
crates/brightstaff/src/router/orchestrator_model_v1.rs
Normal file
1097
crates/brightstaff/src/router/orchestrator_model_v1.rs
Normal file
File diff suppressed because it is too large
Load diff
611
crates/brightstaff/src/state/memory.rs
Normal file
611
crates/brightstaff/src/state/memory.rs
Normal file
|
|
@ -0,0 +1,611 @@
|
|||
use super::{OpenAIConversationState, StateStorage, StateStorageError};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// In-memory storage backend for conversation state
|
||||
/// Uses a HashMap wrapped in Arc<RwLock<>> for thread-safe access
|
||||
#[derive(Clone)]
|
||||
pub struct MemoryConversationalStorage {
|
||||
storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
|
||||
}
|
||||
|
||||
impl MemoryConversationalStorage {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
storage: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryConversationalStorage {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StateStorage for MemoryConversationalStorage {
|
||||
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
|
||||
let response_id = state.response_id.clone();
|
||||
let mut storage = self.storage.write().await;
|
||||
|
||||
debug!(
|
||||
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Storing conversation state: model={}, provider={}, input_items={}",
|
||||
response_id, state.model, state.provider, state.input_items.len()
|
||||
);
|
||||
|
||||
storage.insert(response_id, state);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
|
||||
let storage = self.storage.read().await;
|
||||
|
||||
match storage.get(response_id) {
|
||||
Some(state) => {
|
||||
debug!(
|
||||
"[PLANO | MEMORY_STORAGE | RESP_ID:{} | Retrieved conversation state: input_items={}",
|
||||
response_id, state.input_items.len()
|
||||
);
|
||||
Ok(state.clone())
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
"[PLANO_RESP_ID:{} | MEMORY_STORAGE | Conversation state not found",
|
||||
response_id
|
||||
);
|
||||
Err(StateStorageError::NotFound(response_id.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
|
||||
let storage = self.storage.read().await;
|
||||
Ok(storage.contains_key(response_id))
|
||||
}
|
||||
|
||||
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
|
||||
let mut storage = self.storage.write().await;
|
||||
|
||||
if storage.remove(response_id).is_some() {
|
||||
debug!(
|
||||
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Deleted conversation state",
|
||||
response_id
|
||||
);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(StateStorageError::NotFound(response_id.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent};
|
||||
|
||||
fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
|
||||
let mut input_items = Vec::new();
|
||||
for i in 0..num_messages {
|
||||
input_items.push(InputItem::Message(InputMessage {
|
||||
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: format!("Message {}", i),
|
||||
}]),
|
||||
}));
|
||||
}
|
||||
|
||||
OpenAIConversationState {
|
||||
response_id: response_id.to_string(),
|
||||
input_items,
|
||||
created_at: 1234567890,
|
||||
model: "claude-3".to_string(),
|
||||
provider: "anthropic".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_put_and_get_success() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
let state: OpenAIConversationState = create_test_state("resp_001", 3);
|
||||
|
||||
// Store
|
||||
storage.put(state.clone()).await.unwrap();
|
||||
|
||||
// Retrieve
|
||||
let retrieved = storage.get("resp_001").await.unwrap();
|
||||
assert_eq!(retrieved.response_id, state.response_id);
|
||||
assert_eq!(retrieved.model, state.model);
|
||||
assert_eq!(retrieved.provider, state.provider);
|
||||
assert_eq!(retrieved.input_items.len(), 3);
|
||||
assert_eq!(retrieved.created_at, state.created_at);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_put_overwrites_existing() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
// First state
|
||||
let state1 = create_test_state("resp_002", 2);
|
||||
storage.put(state1).await.unwrap();
|
||||
|
||||
// Overwrite with new state
|
||||
let state2 = OpenAIConversationState {
|
||||
response_id: "resp_002".to_string(),
|
||||
input_items: vec![],
|
||||
created_at: 9999999999,
|
||||
model: "gpt-4".to_string(),
|
||||
provider: "openai".to_string(),
|
||||
};
|
||||
storage.put(state2.clone()).await.unwrap();
|
||||
|
||||
// Should retrieve the new state
|
||||
let retrieved = storage.get("resp_002").await.unwrap();
|
||||
assert_eq!(retrieved.model, "gpt-4");
|
||||
assert_eq!(retrieved.provider, "openai");
|
||||
assert_eq!(retrieved.input_items.len(), 0);
|
||||
assert_eq!(retrieved.created_at, 9999999999);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_not_found() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
let result = storage.get("nonexistent").await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result.unwrap_err() {
|
||||
StateStorageError::NotFound(id) => {
|
||||
assert_eq!(id, "nonexistent");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exists_returns_false_for_nonexistent() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
assert!(!storage.exists("resp_003").await.unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exists_returns_true_after_put() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
let state = create_test_state("resp_004", 1);
|
||||
|
||||
assert!(!storage.exists("resp_004").await.unwrap());
|
||||
storage.put(state).await.unwrap();
|
||||
assert!(storage.exists("resp_004").await.unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_success() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
let state = create_test_state("resp_005", 2);
|
||||
|
||||
storage.put(state).await.unwrap();
|
||||
assert!(storage.exists("resp_005").await.unwrap());
|
||||
|
||||
// Delete
|
||||
storage.delete("resp_005").await.unwrap();
|
||||
|
||||
// Should no longer exist
|
||||
assert!(!storage.exists("resp_005").await.unwrap());
|
||||
assert!(storage.get("resp_005").await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_not_found() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
let result = storage.delete("nonexistent").await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result.unwrap_err() {
|
||||
StateStorageError::NotFound(id) => {
|
||||
assert_eq!(id, "nonexistent");
|
||||
}
|
||||
_ => panic!("Expected NotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_combines_inputs() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
// Create a previous state with 2 messages
|
||||
let prev_state = create_test_state("resp_006", 2);
|
||||
|
||||
// Create current input with 1 message
|
||||
let current_input = vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "New message".to_string(),
|
||||
}]),
|
||||
})];
|
||||
|
||||
// Merge
|
||||
let merged = storage.merge(&prev_state, current_input);
|
||||
|
||||
// Should have 3 messages total (2 from prev + 1 current)
|
||||
assert_eq!(merged.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_preserves_order() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
// Previous state has messages 0 and 1
|
||||
let prev_state = create_test_state("resp_007", 2);
|
||||
|
||||
// Current input has message 2
|
||||
let current_input = vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Message 2".to_string(),
|
||||
}]),
|
||||
})];
|
||||
|
||||
let merged = storage.merge(&prev_state, current_input);
|
||||
|
||||
// Verify order: prev messages first, then current
|
||||
let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") };
|
||||
match &msg.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert_eq!(text, "Message 0"),
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") };
|
||||
match &msg.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert_eq!(text, "Message 2"),
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_with_empty_current_input() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
let prev_state = create_test_state("resp_008", 3);
|
||||
|
||||
let merged = storage.merge(&prev_state, vec![]);
|
||||
|
||||
// Should just have the previous state's items
|
||||
assert_eq!(merged.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_with_empty_previous_state() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
let prev_state = OpenAIConversationState {
|
||||
response_id: "resp_009".to_string(),
|
||||
input_items: vec![],
|
||||
created_at: 1234567890,
|
||||
model: "gpt-4".to_string(),
|
||||
provider: "openai".to_string(),
|
||||
};
|
||||
|
||||
let current_input = vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Only message".to_string(),
|
||||
}]),
|
||||
})];
|
||||
|
||||
let merged = storage.merge(&prev_state, current_input);
|
||||
|
||||
// Should just have the current input
|
||||
assert_eq!(merged.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_access() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
// Spawn multiple tasks that write concurrently
|
||||
let mut handles = vec![];
|
||||
|
||||
for i in 0..10 {
|
||||
let storage_clone = storage.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let state = create_test_state(&format!("resp_{}", i), i % 3);
|
||||
storage_clone.put(state).await.unwrap();
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// Verify all states were stored
|
||||
for i in 0..10 {
|
||||
assert!(storage.exists(&format!("resp_{}", i)).await.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_operations_on_same_id() {
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
let state = create_test_state("resp_010", 1);
|
||||
|
||||
// Put
|
||||
storage.put(state.clone()).await.unwrap();
|
||||
|
||||
// Get
|
||||
let retrieved = storage.get("resp_010").await.unwrap();
|
||||
assert_eq!(retrieved.response_id, "resp_010");
|
||||
|
||||
// Exists
|
||||
assert!(storage.exists("resp_010").await.unwrap());
|
||||
|
||||
// Put again (overwrite)
|
||||
let new_state = create_test_state("resp_010", 5);
|
||||
storage.put(new_state).await.unwrap();
|
||||
|
||||
// Get updated
|
||||
let updated = storage.get("resp_010").await.unwrap();
|
||||
assert_eq!(updated.input_items.len(), 5);
|
||||
|
||||
// Delete
|
||||
storage.delete("resp_010").await.unwrap();
|
||||
|
||||
// Should not exist
|
||||
assert!(!storage.exists("resp_010").await.unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_with_tool_call_flow() {
|
||||
// This test simulates a realistic tool call conversation flow:
|
||||
// 1. User sends message: "What's the weather?"
|
||||
// 2. Model responds with function call (converted to assistant message)
|
||||
// 3. User sends function call output in next request with previous_response_id
|
||||
// The merge should combine: user message + assistant function call + function output
|
||||
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
// Step 1: Previous state contains the initial exchange
|
||||
// - User message: "What's the weather in SF?"
|
||||
// - Assistant message (converted from FunctionCall): "Called function: get_weather..."
|
||||
let prev_state = OpenAIConversationState {
|
||||
response_id: "resp_tool_001".to_string(),
|
||||
input_items: vec![
|
||||
// Original user message
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "What's the weather in San Francisco?".to_string(),
|
||||
}]),
|
||||
}),
|
||||
// Assistant's function call (converted from OutputItem::FunctionCall)
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::Assistant,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(),
|
||||
}]),
|
||||
}),
|
||||
],
|
||||
created_at: 1234567890,
|
||||
model: "claude-3".to_string(),
|
||||
provider: "anthropic".to_string(),
|
||||
};
|
||||
|
||||
// Step 2: Current request includes function call output
|
||||
let current_input = vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
|
||||
}]),
|
||||
})];
|
||||
|
||||
// Step 3: Merge should combine all conversation history
|
||||
let merged = storage.merge(&prev_state, current_input);
|
||||
|
||||
// Should have 3 items: user question + assistant function call + function output
|
||||
assert_eq!(merged.len(), 3);
|
||||
|
||||
// Verify the order and content
|
||||
let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") };
|
||||
assert!(matches!(msg1.role, MessageRole::User));
|
||||
match &msg1.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => {
|
||||
assert!(text.contains("weather in San Francisco"));
|
||||
}
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") };
|
||||
assert!(matches!(msg2.role, MessageRole::Assistant));
|
||||
match &msg2.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => {
|
||||
assert!(text.contains("get_weather"));
|
||||
}
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") };
|
||||
assert!(matches!(msg3.role, MessageRole::User));
|
||||
match &msg3.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => {
|
||||
assert!(text.contains("Function result"));
|
||||
assert!(text.contains("temperature"));
|
||||
}
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_with_multiple_tool_calls() {
|
||||
// Test a more complex scenario with multiple tool calls
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
// Previous state has: user message + 2 function calls from assistant
|
||||
let prev_state = OpenAIConversationState {
|
||||
response_id: "resp_tool_002".to_string(),
|
||||
input_items: vec![
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "What's the weather and time in SF?".to_string(),
|
||||
}]),
|
||||
}),
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::Assistant,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(),
|
||||
}]),
|
||||
}),
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::Assistant,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(),
|
||||
}]),
|
||||
}),
|
||||
],
|
||||
created_at: 1234567890,
|
||||
model: "gpt-4".to_string(),
|
||||
provider: "openai".to_string(),
|
||||
};
|
||||
|
||||
// Current input: function outputs for both calls
|
||||
let current_input = vec![
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Weather result: {\"temp\": 68}".to_string(),
|
||||
}]),
|
||||
}),
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Time result: {\"time\": \"14:30\"}".to_string(),
|
||||
}]),
|
||||
}),
|
||||
];
|
||||
|
||||
let merged = storage.merge(&prev_state, current_input);
|
||||
|
||||
// Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs
|
||||
assert_eq!(merged.len(), 5);
|
||||
|
||||
// Verify first item is original user message
|
||||
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
|
||||
assert!(matches!(first.role, MessageRole::User));
|
||||
|
||||
// Verify last two are function outputs
|
||||
let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") };
|
||||
assert!(matches!(second_last.role, MessageRole::User));
|
||||
match &second_last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("Weather result")),
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
|
||||
assert!(matches!(last.role, MessageRole::User));
|
||||
match &last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("Time result")),
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_preserves_conversation_context_for_multi_turn() {
|
||||
// Simulate a multi-turn conversation with tool calls
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
// Previous state: full conversation history up to this point
|
||||
let prev_state = OpenAIConversationState {
|
||||
response_id: "resp_tool_003".to_string(),
|
||||
input_items: vec![
|
||||
// Turn 1: User asks about weather
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "What's the weather?".to_string(),
|
||||
}]),
|
||||
}),
|
||||
// Turn 1: Assistant calls get_weather
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::Assistant,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Called function: get_weather".to_string(),
|
||||
}]),
|
||||
}),
|
||||
// Turn 2: User provides function output
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Weather: sunny, 72°F".to_string(),
|
||||
}]),
|
||||
}),
|
||||
// Turn 2: Assistant responds with text
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::Assistant,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "It's sunny and 72°F in San Francisco today!".to_string(),
|
||||
}]),
|
||||
}),
|
||||
],
|
||||
created_at: 1234567890,
|
||||
model: "claude-3".to_string(),
|
||||
provider: "anthropic".to_string(),
|
||||
};
|
||||
|
||||
// Turn 3: User asks follow-up question
|
||||
let current_input = vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Should I bring an umbrella?".to_string(),
|
||||
}]),
|
||||
})];
|
||||
|
||||
let merged = storage.merge(&prev_state, current_input);
|
||||
|
||||
// Should have all 5 messages in order
|
||||
assert_eq!(merged.len(), 5);
|
||||
|
||||
// Verify the entire conversation flow is preserved
|
||||
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
|
||||
match &first.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("What's the weather")),
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
|
||||
match &last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("umbrella")),
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
}
|
||||
}
|
||||
147
crates/brightstaff/src/state/mod.rs
Normal file
147
crates/brightstaff/src/state/mod.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
use async_trait::async_trait;
|
||||
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug};
|
||||
|
||||
pub mod memory;
|
||||
pub mod response_state_processor;
|
||||
pub mod postgresql;
|
||||
|
||||
/// Represents the conversational state for a v1/responses request
|
||||
/// Contains the complete input/output history that can be restored
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIConversationState {
|
||||
/// The response ID this state is associated with
|
||||
pub response_id: String,
|
||||
|
||||
/// The complete input history (original input + accumulated outputs)
|
||||
/// This is what gets prepended to new requests via previous_response_id
|
||||
pub input_items: Vec<InputItem>,
|
||||
|
||||
/// Timestamp when this state was created
|
||||
pub created_at: i64,
|
||||
|
||||
/// Model used for this response
|
||||
pub model: String,
|
||||
|
||||
/// Provider that generated this response (e.g., "anthropic", "openai")
|
||||
pub provider: String,
|
||||
}
|
||||
|
||||
/// Failures that a conversation state storage backend can report.
#[derive(Debug)]
pub enum StateStorageError {
    /// No state is stored for the given response_id (the payload is that id).
    NotFound(String),

    /// The storage backend itself failed (network, database, etc.).
    StorageError(String),

    /// Stored data could not be serialized or deserialized.
    SerializationError(String),
}

impl fmt::Display for StateStorageError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotFound(id) => f.write_fmt(format_args!(
                "Conversation state not found for response_id: {}",
                id
            )),
            Self::StorageError(msg) => f.write_fmt(format_args!("Storage error: {}", msg)),
            Self::SerializationError(msg) => {
                f.write_fmt(format_args!("Serialization error: {}", msg))
            }
        }
    }
}

impl Error for StateStorageError {}
|
||||
|
||||
/// Trait for conversation state storage backends
|
||||
#[async_trait]
|
||||
pub trait StateStorage: Send + Sync {
|
||||
/// Store conversation state for a response
|
||||
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError>;
|
||||
|
||||
/// Retrieve conversation state by response_id
|
||||
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError>;
|
||||
|
||||
/// Check if state exists for a response_id
|
||||
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError>;
|
||||
|
||||
/// Delete state for a response_id (optional, for cleanup)
|
||||
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>;
|
||||
|
||||
fn merge(
|
||||
&self,
|
||||
prev_state: &OpenAIConversationState,
|
||||
current_input: Vec<InputItem>,
|
||||
) -> Vec<InputItem> {
|
||||
// Default implementation: prepend previous input, append current
|
||||
let prev_count = prev_state.input_items.len();
|
||||
let current_count = current_input.len();
|
||||
|
||||
let mut combined_input = prev_state.input_items.clone();
|
||||
combined_input.extend(current_input);
|
||||
|
||||
debug!(
|
||||
"PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}",
|
||||
prev_state.response_id,
|
||||
prev_count,
|
||||
current_count,
|
||||
combined_input.len(),
|
||||
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
|
||||
);
|
||||
|
||||
combined_input
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Identifies which storage backend is in use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageBackend {
    Memory,
    Supabase,
}

impl StorageBackend {
    /// Parses a backend name, ignoring letter case.
    ///
    /// Returns `None` when `s` is neither "memory" nor "supabase".
    pub fn from_str(s: &str) -> Option<Self> {
        // Names are stored lowercase; the input is lowercased before the lookup.
        const KNOWN: [(&str, StorageBackend); 2] = [
            ("memory", StorageBackend::Memory),
            ("supabase", StorageBackend::Supabase),
        ];

        let wanted = s.to_lowercase();
        KNOWN
            .iter()
            .find(|(name, _)| *name == wanted.as_str())
            .map(|(_, backend)| *backend)
    }
}
|
||||
|
||||
// === Utility functions for state management ===
|
||||
|
||||
/// Extract input items from InputParam, converting text to structured format
|
||||
pub fn extract_input_items(input: &InputParam) -> Vec<InputItem> {
|
||||
match input {
|
||||
InputParam::Text(text) => {
|
||||
vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: text.clone(),
|
||||
}]),
|
||||
})]
|
||||
}
|
||||
InputParam::Items(items) => items.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieve previous conversation state and combine with current input
|
||||
/// Returns combined input if previous state found, or original input if not found/error
|
||||
pub async fn retrieve_and_combine_input(
|
||||
storage: Arc<dyn StateStorage>,
|
||||
previous_response_id: &str,
|
||||
current_input: Vec<InputItem>,
|
||||
) -> Result<Vec<InputItem>, StateStorageError> {
|
||||
|
||||
// First get the previous state
|
||||
let prev_state = storage.get(previous_response_id).await?;
|
||||
let combined_input = storage.merge(&prev_state, current_input);
|
||||
Ok(combined_input)
|
||||
}
|
||||
432
crates/brightstaff/src/state/postgresql.rs
Normal file
432
crates/brightstaff/src/state/postgresql.rs
Normal file
|
|
@ -0,0 +1,432 @@
|
|||
use super::{OpenAIConversationState, StateStorage, StateStorageError};
|
||||
use async_trait::async_trait;
|
||||
use serde_json;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::OnceCell;
|
||||
use tokio_postgres::{Client, NoTls};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Supabase/PostgreSQL storage backend for conversation state
|
||||
#[derive(Clone)]
|
||||
pub struct PostgreSQLConversationStorage {
|
||||
client: Arc<Client>,
|
||||
table_verified: Arc<OnceCell<()>>,
|
||||
}
|
||||
|
||||
impl PostgreSQLConversationStorage {
|
||||
/// Creates a new Supabase storage instance with the given connection string
|
||||
pub async fn new(connection_string: String) -> Result<Self, StateStorageError> {
|
||||
let (client, connection) = tokio_postgres::connect(&connection_string, NoTls)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
StateStorageError::StorageError(format!("Failed to connect to database: {}", e))
|
||||
})?;
|
||||
|
||||
// Spawn the connection to run in the background
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
warn!("Database connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
client: Arc::new(client),
|
||||
table_verified: Arc::new(OnceCell::new()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Ensures the conversation_states table exists (checks once, caches result)
|
||||
async fn ensure_ready(&self) -> Result<(), StateStorageError> {
|
||||
self.table_verified
|
||||
.get_or_try_init(|| async {
|
||||
let row = self
|
||||
.client
|
||||
.query_one(
|
||||
"SELECT EXISTS (
|
||||
SELECT FROM pg_tables
|
||||
WHERE tablename = 'conversation_states'
|
||||
)",
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to verify table existence: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
let exists: bool = row.get(0);
|
||||
|
||||
if !exists {
|
||||
return Err(StateStorageError::StorageError(
|
||||
"Table 'conversation_states' does not exist. \
|
||||
Please run the setup SQL from docs/db_setup/conversation_states.sql"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
info!("Conversation state storage table verified");
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StateStorage for PostgreSQLConversationStorage {
|
||||
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
|
||||
self.ensure_ready().await?;
|
||||
|
||||
// Serialize input_items to JSONB
|
||||
let input_items_json = serde_json::to_value(&state.input_items).map_err(|e| {
|
||||
StateStorageError::StorageError(format!("Failed to serialize input_items: {}", e))
|
||||
})?;
|
||||
|
||||
// Upsert the conversation state
|
||||
self.client
|
||||
.execute(
|
||||
r#"
|
||||
INSERT INTO conversation_states
|
||||
(response_id, input_items, created_at, model, provider, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW())
|
||||
ON CONFLICT (response_id)
|
||||
DO UPDATE SET
|
||||
input_items = EXCLUDED.input_items,
|
||||
model = EXCLUDED.model,
|
||||
provider = EXCLUDED.provider,
|
||||
updated_at = NOW()
|
||||
"#,
|
||||
&[
|
||||
&state.response_id,
|
||||
&input_items_json,
|
||||
&state.created_at,
|
||||
&state.model,
|
||||
&state.provider,
|
||||
],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to store conversation state for {}: {}",
|
||||
state.response_id, e
|
||||
))
|
||||
})?;
|
||||
|
||||
debug!("Stored conversation state for {}", state.response_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
|
||||
self.ensure_ready().await?;
|
||||
|
||||
let row = self
|
||||
.client
|
||||
.query_opt(
|
||||
r#"
|
||||
SELECT response_id, input_items, created_at, model, provider
|
||||
FROM conversation_states
|
||||
WHERE response_id = $1
|
||||
"#,
|
||||
&[&response_id],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to fetch conversation state for {}: {}",
|
||||
response_id, e
|
||||
))
|
||||
})?;
|
||||
|
||||
match row {
|
||||
Some(row) => {
|
||||
let response_id: String = row.get("response_id");
|
||||
let input_items_json: serde_json::Value = row.get("input_items");
|
||||
let created_at: i64 = row.get("created_at");
|
||||
let model: String = row.get("model");
|
||||
let provider: String = row.get("provider");
|
||||
|
||||
// Deserialize input_items from JSONB
|
||||
let input_items =
|
||||
serde_json::from_value(input_items_json).map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to deserialize input_items: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(OpenAIConversationState {
|
||||
response_id,
|
||||
input_items,
|
||||
created_at,
|
||||
model,
|
||||
provider,
|
||||
})
|
||||
}
|
||||
None => Err(StateStorageError::NotFound(format!(
|
||||
"Conversation state not found for response_id: {}",
|
||||
response_id
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
|
||||
self.ensure_ready().await?;
|
||||
|
||||
let row = self
|
||||
.client
|
||||
.query_one(
|
||||
"SELECT EXISTS(SELECT 1 FROM conversation_states WHERE response_id = $1)",
|
||||
&[&response_id],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to check existence for {}: {}",
|
||||
response_id, e
|
||||
))
|
||||
})?;
|
||||
|
||||
let exists: bool = row.get(0);
|
||||
Ok(exists)
|
||||
}
|
||||
|
||||
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
|
||||
self.ensure_ready().await?;
|
||||
|
||||
let rows_affected = self
|
||||
.client
|
||||
.execute(
|
||||
"DELETE FROM conversation_states WHERE response_id = $1",
|
||||
&[&response_id],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to delete conversation state for {}: {}",
|
||||
response_id, e
|
||||
))
|
||||
})?;
|
||||
|
||||
if rows_affected == 0 {
|
||||
return Err(StateStorageError::NotFound(format!(
|
||||
"Conversation state not found for response_id: {}",
|
||||
response_id
|
||||
)));
|
||||
}
|
||||
|
||||
debug!("Deleted conversation state for {}", response_id);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
PostgreSQL schema is maintained in docs/db_setup/conversation_states.sql
|
||||
Run that SQL file against your database before using this storage backend.
|
||||
*/
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole};

    /// Builds a message item holding a single text part for the given role.
    fn text_message(role: MessageRole, text: &str) -> InputItem {
        InputItem::Message(InputMessage {
            role,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: text.to_string(),
            }]),
        })
    }

    /// Builds a one-message state stored under `response_id`.
    fn create_test_state(response_id: &str) -> OpenAIConversationState {
        OpenAIConversationState {
            response_id: response_id.to_string(),
            input_items: vec![text_message(MessageRole::User, "Test message")],
            created_at: 1234567890,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        }
    }

    // Note: These tests require a running PostgreSQL database
    // Set TEST_DATABASE_URL environment variable to run integration tests
    // Example: TEST_DATABASE_URL=postgresql://user:pass@localhost/test_db

    /// Connects to the database named by TEST_DATABASE_URL.
    ///
    /// Returns `None` (after printing the reason to stderr) when the variable is
    /// unset or the connection fails, so each test can return early.
    async fn get_test_storage() -> Option<PostgreSQLConversationStorage> {
        let db_url = match std::env::var("TEST_DATABASE_URL") {
            Ok(url) => url,
            Err(_) => {
                eprintln!("TEST_DATABASE_URL not set, skipping Supabase integration tests");
                return None;
            }
        };

        PostgreSQLConversationStorage::new(db_url)
            .await
            .map_err(|e| eprintln!("Failed to create test storage: {}", e))
            .ok()
    }

    #[tokio::test]
    async fn test_supabase_put_and_get_success() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        storage.put(create_test_state("test_resp_001")).await.unwrap();

        let fetched = storage.get("test_resp_001").await.unwrap();
        assert_eq!(fetched.response_id, "test_resp_001");
        assert_eq!(fetched.input_items.len(), 1);
        assert_eq!(fetched.model, "gpt-4");
        assert_eq!(fetched.provider, "openai");

        // Cleanup
        let _ = storage.delete("test_resp_001").await;
    }

    #[tokio::test]
    async fn test_supabase_put_overwrites_existing() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        // First write under the id.
        storage.put(create_test_state("test_resp_002")).await.unwrap();

        // Second write under the same id, with a different model and one more message.
        let mut replacement = create_test_state("test_resp_002");
        replacement.model = "gpt-4-turbo".to_string();
        replacement
            .input_items
            .push(text_message(MessageRole::Assistant, "Response"));
        storage.put(replacement).await.unwrap();

        let fetched = storage.get("test_resp_002").await.unwrap();
        assert_eq!(fetched.model, "gpt-4-turbo");
        assert_eq!(fetched.input_items.len(), 2);

        // Cleanup
        let _ = storage.delete("test_resp_002").await;
    }

    #[tokio::test]
    async fn test_supabase_get_not_found() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let outcome = storage.get("nonexistent_id").await;
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), StateStorageError::NotFound(_)));
    }

    #[tokio::test]
    async fn test_supabase_exists_returns_false() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let found = storage.exists("nonexistent_id").await.unwrap();
        assert!(!found);
    }

    #[tokio::test]
    async fn test_supabase_exists_returns_true_after_put() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        storage.put(create_test_state("test_resp_003")).await.unwrap();

        let found = storage.exists("test_resp_003").await.unwrap();
        assert!(found);

        // Cleanup
        let _ = storage.delete("test_resp_003").await;
    }

    #[tokio::test]
    async fn test_supabase_delete_success() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        storage.put(create_test_state("test_resp_004")).await.unwrap();
        storage.delete("test_resp_004").await.unwrap();

        let found = storage.exists("test_resp_004").await.unwrap();
        assert!(!found);
    }

    #[tokio::test]
    async fn test_supabase_delete_not_found() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let outcome = storage.delete("nonexistent_id").await;
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), StateStorageError::NotFound(_)));
    }

    #[tokio::test]
    async fn test_supabase_merge_works() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let prev_state = create_test_state("test_resp_005");
        let current_input = vec![text_message(MessageRole::User, "New message")];

        let merged = storage.merge(&prev_state, current_input);

        // Should have 2 messages (1 from prev + 1 current)
        assert_eq!(merged.len(), 2);
    }

    #[tokio::test]
    async fn test_supabase_table_verification() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        // First call performs the table verification.
        let first = storage.ensure_ready().await;
        assert!(first.is_ok(), "Table verification should succeed");

        // Second call should use cached result
        let second = storage.ensure_ready().await;
        assert!(second.is_ok(), "Cached verification should succeed");
    }

    #[tokio::test]
    #[ignore] // Run manually with: cargo test test_verify_data_in_supabase -- --ignored
    async fn test_verify_data_in_supabase() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        // Write a record that is intentionally left in place.
        storage
            .put(create_test_state("manual_test_verification"))
            .await
            .unwrap();

        println!("✅ Data written to Supabase!");
        println!("Check your Supabase dashboard:");
        println!("  SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';");
        println!("\nTo cleanup, run:");
        println!("  DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';");

        // DON'T cleanup - leave it for manual verification
    }
}
|
||||
302
crates/brightstaff/src/state/response_state_processor.rs
Normal file
302
crates/brightstaff/src/state/response_state_processor.rs
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
use bytes::Bytes;
|
||||
use flate2::read::GzDecoder;
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputItem, OutputItem, ResponsesAPIStreamEvent,
|
||||
};
|
||||
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
|
||||
use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, debug, warn};
|
||||
|
||||
use crate::handlers::utils::StreamProcessor;
|
||||
use crate::state::{OpenAIConversationState, StateStorage};
|
||||
|
||||
/// Processor that wraps another processor and handles v1/responses state management
|
||||
/// Captures response_id and output from streaming responses, stores state after completion
|
||||
pub struct ResponsesStateProcessor<P: StreamProcessor> {
|
||||
/// The underlying processor (e.g., ObservableStreamProcessor for metrics)
|
||||
inner: P,
|
||||
|
||||
/// State storage backend
|
||||
storage: Arc<dyn StateStorage>,
|
||||
|
||||
/// Original input items from the request
|
||||
original_input: Vec<InputItem>,
|
||||
|
||||
/// Model name
|
||||
model: String,
|
||||
|
||||
/// Provider name
|
||||
provider: String,
|
||||
|
||||
/// Whether this is a streaming request
|
||||
is_streaming: bool,
|
||||
|
||||
/// Whether upstream is OpenAI (skip storage if true)
|
||||
is_openai_upstream: bool,
|
||||
|
||||
/// Content-Encoding header value (e.g., "gzip", "br", None)
|
||||
content_encoding: Option<String>,
|
||||
|
||||
/// Request ID for logging
|
||||
request_id: String,
|
||||
|
||||
/// Buffer for accumulating chunks (needed for non-streaming compressed responses)
|
||||
chunk_buffer: Vec<u8>,
|
||||
|
||||
/// Captured response_id from response.completed event
|
||||
response_id: Option<String>,
|
||||
|
||||
/// Captured output items from response.completed event
|
||||
output_items: Option<Vec<OutputItem>>,
|
||||
}
|
||||
|
||||
impl<P: StreamProcessor> ResponsesStateProcessor<P> {
|
||||
pub fn new(
|
||||
inner: P,
|
||||
storage: Arc<dyn StateStorage>,
|
||||
original_input: Vec<InputItem>,
|
||||
model: String,
|
||||
provider: String,
|
||||
is_streaming: bool,
|
||||
is_openai_upstream: bool,
|
||||
content_encoding: Option<String>,
|
||||
request_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
storage,
|
||||
original_input,
|
||||
model,
|
||||
provider,
|
||||
is_streaming,
|
||||
is_openai_upstream,
|
||||
content_encoding,
|
||||
request_id,
|
||||
chunk_buffer: Vec::new(),
|
||||
response_id: None,
|
||||
output_items: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Decompress accumulated buffer based on Content-Encoding header
|
||||
fn decompress_buffer(&self) -> Vec<u8> {
|
||||
if self.chunk_buffer.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
match self.content_encoding.as_deref() {
|
||||
Some("gzip") => {
|
||||
let mut decoder = GzDecoder::new(self.chunk_buffer.as_slice());
|
||||
let mut decompressed = Vec::new();
|
||||
match decoder.read_to_end(&mut decompressed) {
|
||||
Ok(_) => {
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully decompressed {} bytes to {} bytes",
|
||||
self.request_id,
|
||||
self.chunk_buffer.len(),
|
||||
decompressed.len()
|
||||
);
|
||||
decompressed
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to decompress gzip buffer: {}",
|
||||
self.request_id,
|
||||
e
|
||||
);
|
||||
self.chunk_buffer.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(encoding) => {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Unsupported Content-Encoding: {}. Only gzip is currently supported.",
|
||||
self.request_id,
|
||||
encoding
|
||||
);
|
||||
self.chunk_buffer.clone()
|
||||
}
|
||||
None => self.chunk_buffer.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse response to extract response_id and output
|
||||
/// For streaming: parse SSE events looking for response.completed (per chunk)
|
||||
/// For non-streaming: buffer all chunks, then decompress and parse on completion
|
||||
fn try_parse_response_chunk(&mut self, chunk: &[u8]) {
|
||||
if self.is_streaming {
|
||||
// Streaming: Try to parse SSE events from this chunk
|
||||
// Note: For compressed streaming, we'd need to buffer and decompress first
|
||||
// but most streaming responses aren't compressed since SSE needs to be readable
|
||||
let sse_iter = match SseStreamIter::try_from(chunk) {
|
||||
Ok(iter) => iter,
|
||||
Err(_) => return, // Not valid SSE format, skip
|
||||
};
|
||||
|
||||
// Process each SSE event in the chunk, looking for data lines with response.completed
|
||||
for event in sse_iter {
|
||||
// Only process data lines (skip event-only lines)
|
||||
if let Some(data_str) = &event.data {
|
||||
// Try to parse as ResponsesAPIStreamEvent
|
||||
if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) {
|
||||
// Check if this is a ResponseCompleted event
|
||||
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
|
||||
self.request_id,
|
||||
response.id,
|
||||
response.output.len()
|
||||
);
|
||||
self.response_id = Some(response.id.clone());
|
||||
self.output_items = Some(response.output.clone());
|
||||
return; // Found what we need, exit early
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Non-streaming: Buffer chunks, will decompress and parse on completion
|
||||
self.chunk_buffer.extend_from_slice(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse buffered non-streaming response (called on completion)
|
||||
fn try_parse_buffered_response(&mut self) {
|
||||
if self.is_streaming || self.chunk_buffer.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Decompress if needed
|
||||
let decompressed = self.decompress_buffer();
|
||||
|
||||
// Parse complete JSON response
|
||||
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) {
|
||||
Ok(response) => {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",
|
||||
self.request_id,
|
||||
response.id,
|
||||
response.output.len()
|
||||
);
|
||||
self.response_id = Some(response.id.clone());
|
||||
self.output_items = Some(response.output.clone());
|
||||
}
|
||||
Err(e) => {
|
||||
// Log parse error with chunk preview for debugging
|
||||
let chunk_preview = String::from_utf8_lossy(&decompressed);
|
||||
let preview_len = chunk_preview.len().min(200);
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to parse non-streaming ResponsesAPIResponse: {}. Decompressed preview (first {} bytes): {}",
|
||||
self.request_id,
|
||||
e,
|
||||
preview_len,
|
||||
&chunk_preview[..preview_len]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: StreamProcessor> StreamProcessor for ResponsesStateProcessor<P> {
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||
// Buffer/parse chunk for response extraction
|
||||
self.try_parse_response_chunk(&chunk);
|
||||
|
||||
// Forward to inner processor
|
||||
self.inner.process_chunk(chunk)
|
||||
}
|
||||
|
||||
fn on_first_bytes(&mut self) {
|
||||
self.inner.on_first_bytes();
|
||||
}
|
||||
|
||||
fn on_complete(&mut self) {
|
||||
// For non-streaming, decompress and parse buffered response
|
||||
self.try_parse_buffered_response();
|
||||
|
||||
// First, let the inner processor complete
|
||||
self.inner.on_complete();
|
||||
|
||||
// Skip storage for OpenAI upstream
|
||||
if self.is_openai_upstream {
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Skipping state storage for OpenAI upstream provider",
|
||||
self.request_id
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Store state if we captured response_id and output
|
||||
if let (Some(response_id), Some(output_items)) = (&self.response_id, &self.output_items) {
|
||||
// Convert output items to input items for next request
|
||||
let output_as_inputs = outputs_to_inputs(output_items);
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Converting outputs to inputs: output_items_count={}, converted_input_items_count={}",
|
||||
self.request_id, output_items.len(), output_as_inputs.len()
|
||||
);
|
||||
|
||||
// Combine original input + output as new input history
|
||||
let mut combined_input = self.original_input.clone();
|
||||
combined_input.extend(output_as_inputs);
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Storing state: original_input_count={}, combined_input_count={}, combined_json={}",
|
||||
self.request_id,
|
||||
self.original_input.len(),
|
||||
combined_input.len(),
|
||||
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
|
||||
);
|
||||
|
||||
let state = OpenAIConversationState {
|
||||
response_id: response_id.clone(),
|
||||
input_items: combined_input,
|
||||
created_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs() as i64,
|
||||
model: self.model.clone(),
|
||||
provider: self.provider.clone(),
|
||||
};
|
||||
|
||||
// Store asynchronously (fire and forget with logging)
|
||||
let storage = self.storage.clone();
|
||||
let response_id_clone = response_id.clone();
|
||||
let request_id = self.request_id.clone();
|
||||
let items_count = state.input_items.len();
|
||||
tokio::spawn(async move {
|
||||
match storage.put(state).await {
|
||||
Ok(()) => {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully stored conversation state for response_id: {}, items_count={}",
|
||||
request_id,
|
||||
response_id_clone,
|
||||
items_count
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to store conversation state for response_id {}: {}",
|
||||
request_id,
|
||||
response_id_clone,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | No response_id captured from upstream response - cannot store conversation state. response_id present: {}, output present: {}",
|
||||
self.request_id,
|
||||
self.response_id.is_some(),
|
||||
self.output_items.is_some()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn on_error(&mut self, error: &str) {
|
||||
self.inner.on_error(error);
|
||||
}
|
||||
}
|
||||
335
crates/brightstaff/src/tracing/constants.rs
Normal file
335
crates/brightstaff/src/tracing/constants.rs
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
//! OpenTelemetry Semantic Conventions
//!
//! This module defines standard attribute keys following OTEL semantic conventions.
//! See: <https://opentelemetry.io/docs/specs/semconv/>
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - HTTP
|
||||
// =============================================================================
|
||||
|
||||
/// Attribute keys for the HTTP side of a span.
pub mod http {
    /// Request method, e.g. "GET", "POST", "PUT".
    pub const METHOD: &str = "http.method";

    /// Response status code, e.g. "200", "404", "500".
    pub const STATUS_CODE: &str = "http.status_code";

    /// Full request URL.
    pub const URL: &str = "http.url";

    /// Request target (path + query), e.g. "/v1/chat/completions?stream=true".
    pub const TARGET: &str = "http.target";

    /// Upstream target path after routing transformation,
    /// e.g. "/api/paas/v4/chat/completions" (for Zhipu provider).
    pub const UPSTREAM_TARGET: &str = "http.upstream_target";

    /// Request scheme, e.g. "http", "https".
    pub const SCHEME: &str = "http.scheme";

    /// Value of the User-Agent request header.
    pub const USER_AGENT: &str = "http.user_agent";

    /// Request body size in bytes.
    pub const REQUEST_CONTENT_LENGTH: &str = "http.request_content_length";

    /// Response body size in bytes.
    pub const RESPONSE_CONTENT_LENGTH: &str = "http.response_content_length";
}
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - LLM Specific
|
||||
// =============================================================================
|
||||
|
||||
/// Custom attribute keys for LLM operations, all namespaced under `llm.`.
pub mod llm {
    /// Model being called, e.g. "gpt-4", "claude-3-sonnet", "llama-2-70b".
    pub const MODEL_NAME: &str = "llm.model";

    /// LLM provider, e.g. "openai", "anthropic", "azure-openai".
    pub const PROVIDER: &str = "llm.provider";

    /// Kind of LLM operation, e.g. "chat", "completion", "embedding".
    pub const OPERATION_TYPE: &str = "llm.operation_type";

    /// Whether the request is a streaming request.
    pub const IS_STREAMING: &str = "llm.is_streaming";

    /// Total number of bytes received in the response.
    pub const RESPONSE_BYTES: &str = "llm.response_bytes";

    /// Duration of the LLM call, in milliseconds.
    pub const DURATION_MS: &str = "llm.duration_ms";

    /// Time to first token, in milliseconds (streaming only).
    /// Note: unlike `DURATION_MS`, the key itself carries no `_ms` suffix.
    pub const TIME_TO_FIRST_TOKEN_MS: &str = "llm.time_to_first_token";

    /// Number of prompt tokens used.
    pub const PROMPT_TOKENS: &str = "llm.usage.prompt_tokens";

    /// Number of completion tokens generated.
    pub const COMPLETION_TOKENS: &str = "llm.usage.completion_tokens";

    /// Total tokens used (prompt + completion).
    pub const TOTAL_TOKENS: &str = "llm.usage.total_tokens";

    /// Temperature parameter of the request.
    pub const TEMPERATURE: &str = "llm.temperature";

    /// Max-tokens parameter of the request.
    pub const MAX_TOKENS: &str = "llm.max_tokens";

    /// Top-p parameter of the request.
    pub const TOP_P: &str = "llm.top_p";

    /// Names of the tools provided in the request.
    pub const TOOLS: &str = "llm.tools";

    /// Truncated preview of the user message.
    pub const USER_MESSAGE_PREVIEW: &str = "llm.user_message_preview";
}
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - Routing & Gateway
|
||||
// =============================================================================
|
||||
|
||||
/// Attribute keys describing LLM routing and gateway decisions.
pub mod routing {
    /// Strategy used to pick the LLM endpoint,
    /// e.g. "round-robin", "least-latency", "cost-optimized".
    pub const STRATEGY: &str = "routing.strategy";

    /// The upstream endpoint that was selected.
    pub const UPSTREAM_ENDPOINT: &str = "routing.upstream_endpoint";

    /// Time spent determining the route, in milliseconds.
    pub const ROUTE_DETERMINATION_MS: &str = "routing.determination_ms";

    /// Whether a fallback endpoint was used.
    pub const IS_FALLBACK: &str = "routing.is_fallback";

    /// Why this route was selected.
    pub const SELECTION_REASON: &str = "routing.selection_reason";
}
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - Error Handling
|
||||
// =============================================================================
|
||||
|
||||
/// Attribute keys for error and exception tracking.
pub mod error {
    /// Flag marking that an error occurred.
    pub const ERROR: &str = "error";

    /// Type/class of the error, e.g. "TimeoutError", "AuthenticationError".
    pub const TYPE: &str = "error.type";

    /// Human-readable error message.
    pub const MESSAGE: &str = "error.message";

    /// Stack trace of the error.
    pub const STACK_TRACE: &str = "error.stack_trace";
}
|
||||
|
||||
// =============================================================================
|
||||
// Operation Names
|
||||
// =============================================================================
|
||||
|
||||
/// Canonical operation name components, one per gateway phase.
pub mod operation_component {
    /// Inbound request handling.
    pub const INBOUND: &str = "plano(inbound)";

    /// Routing decision phase.
    pub const ROUTING: &str = "plano(routing)";

    /// Handoff to an upstream service.
    pub const HANDOFF: &str = "plano(handoff)";

    /// Agent filter execution.
    pub const AGENT_FILTER: &str = "plano(filter)";

    /// Agent execution.
    pub const AGENT: &str = "plano(agent)";

    /// LLM call.
    pub const LLM: &str = "plano(llm)";
}
|
||||
|
||||
/// Builder for constructing standardized operation names.
///
/// Format: `{method} {path} ({operation}) {target}`. Every component is
/// optional; components that were not set are left out and the remaining ones
/// are joined with single spaces.
///
/// The operation component (see the `operation_component` constants, e.g.
/// `operation_component::LLM`) is part of the service name, so the operation
/// name focuses on the HTTP request details and target.
///
/// # Examples
/// ```
/// use brightstaff::tracing::OperationNameBuilder;
///
/// // LLM call operation: "POST /v1/chat/completions gpt-4"
/// let op = OperationNameBuilder::new()
///     .with_method("POST")
///     .with_path("/v1/chat/completions")
///     .with_target("gpt-4")
///     .build();
///
/// // Agent filter operation: "POST /agents/v1/chat/completions hallucination-detector"
/// let op = OperationNameBuilder::new()
///     .with_method("POST")
///     .with_path("/agents/v1/chat/completions")
///     .with_target("hallucination-detector")
///     .build();
///
/// // Routing operation: "POST /v1/chat/completions"
/// let op = OperationNameBuilder::new()
///     .with_method("POST")
///     .with_path("/v1/chat/completions")
///     .build();
/// ```
#[derive(Debug, Clone, Default)]
pub struct OperationNameBuilder {
    method: Option<String>,
    path: Option<String>,
    operation: Option<String>,
    target: Option<String>,
}

impl OperationNameBuilder {
    /// Create a new operation name builder with no components set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the HTTP method.
    ///
    /// # Arguments
    /// * `method` - HTTP method (e.g., "GET", "POST", "PUT")
    pub fn with_method(mut self, method: impl Into<String>) -> Self {
        self.method = Some(method.into());
        self
    }

    /// Set the request path.
    ///
    /// # Arguments
    /// * `path` - Request path (e.g., "/v1/chat/completions", "/agents/v1/chat/completions")
    pub fn with_path(mut self, path: impl Into<String>) -> Self {
        self.path = Some(path.into());
        self
    }

    /// Set the operation type (optional, for MCP operations).
    ///
    /// # Arguments
    /// * `operation` - Operation type (e.g., "tool_call", "session_init", "notification")
    pub fn with_operation(mut self, operation: impl Into<String>) -> Self {
        self.operation = Some(operation.into());
        self
    }

    /// Set the target (model name, agent name, or filter name).
    ///
    /// # Arguments
    /// * `target` - Target identifier (e.g., "gpt-4", "my-agent", "hallucination-detector")
    pub fn with_target(mut self, target: impl Into<String>) -> Self {
        self.target = Some(target.into());
        self
    }

    /// Build the operation name string.
    ///
    /// # Format
    /// - With all components: `{method} {path} ({operation}) {target}`
    /// - Without operation: `{method} {path} {target}`
    /// - Without target: `{method} {path}`
    /// - Without path: `{method}`
    /// - Operation but no path: `{method} ({operation}) {target}`
    /// - Empty: returns empty string
    pub fn build(self) -> String {
        // The operation qualifies the path when there is one. Without a path it
        // is still emitted on its own, so a configured operation is never
        // silently dropped.
        let path_part = match (self.path, self.operation) {
            (Some(path), Some(operation)) => Some(format!("{} ({})", path, operation)),
            (Some(path), None) => Some(path),
            (None, Some(operation)) => Some(format!("({})", operation)),
            (None, None) => None,
        };

        // `Option` is an iterator of zero or one item, so unset parts vanish.
        let mut parts: Vec<String> = Vec::with_capacity(3);
        parts.extend(self.method);
        parts.extend(path_part);
        parts.extend(self.target);
        parts.join(" ")
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Method + path + target are joined with single spaces.
    #[test]
    fn test_operation_name_full() {
        let builder = OperationNameBuilder::new()
            .with_method("POST")
            .with_path("/v1/chat/completions")
            .with_target("gpt-4");

        assert_eq!(builder.build(), "POST /v1/chat/completions gpt-4");
    }

    // Leaving the target out leaves no trailing separator.
    #[test]
    fn test_operation_name_no_target() {
        let builder = OperationNameBuilder::new()
            .with_method("POST")
            .with_path("/v1/chat/completions");

        assert_eq!(builder.build(), "POST /v1/chat/completions");
    }

    // Agent filter paths and targets are handled like any other.
    #[test]
    fn test_operation_name_agent_filter() {
        let builder = OperationNameBuilder::new()
            .with_method("POST")
            .with_path("/agents/v1/chat/completions")
            .with_target("content-filter");

        assert_eq!(builder.build(), "POST /agents/v1/chat/completions content-filter");
    }

    // A builder with nothing set produces an empty name.
    #[test]
    fn test_operation_name_minimal() {
        assert_eq!(OperationNameBuilder::new().build(), "");
    }
}
|
||||
3
crates/brightstaff/src/tracing/mod.rs
Normal file
3
crates/brightstaff/src/tracing/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
mod constants;
|
||||
|
||||
pub use constants::{OperationNameBuilder, operation_component, http, llm, error, routing};
|
||||
1
crates/build.sh
Normal file
1
crates/build.sh
Normal file
|
|
@ -0,0 +1 @@
|
|||
#!/bin/sh
# Build the WASM gateways first, then the native brightstaff binary.
# `set -e` stops at the first failing build, as the original `&&` chain did.
set -eu

cargo build --release --target wasm32-wasip1 -p prompt_gateway -p llm_gateway
cargo build --release -p brightstaff
|
||||
|
|
@ -14,13 +14,25 @@ derivative = "2.2.0"
|
|||
thiserror = "1.0.64"
|
||||
tiktoken-rs = "0.5.9"
|
||||
rand = "0.8.5"
|
||||
serde_json = "1.0"
|
||||
serde_json = { version = "1.0", features = ["preserve_order"] }
|
||||
hex = "0.4.3"
|
||||
urlencoding = "2.1.3"
|
||||
url = "2.5.4"
|
||||
hermesllm = { version = "0.1.0", path = "../hermesllm" }
|
||||
serde_with = "3.13.0"
|
||||
|
||||
# Optional dependencies for trace collection (not available in WASM)
|
||||
tokio = { version = "1.44", features = ["sync", "time"], optional = true }
|
||||
reqwest = { version = "0.12", features = ["json"], optional = true }
|
||||
tracing = { version = "0.1", optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
trace-collection = ["tokio", "reqwest", "tracing"]
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = "1.4.1"
|
||||
serde_json = "1.0.64"
|
||||
serial_test = "3.2"
|
||||
axum = "0.7"
|
||||
tokio = { version = "1.44", features = ["sync", "time", "macros", "rt"] }
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ use crate::{
|
|||
};
|
||||
use core::{panic, str};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt::Display,
|
||||
|
|
@ -265,7 +264,7 @@ pub struct ToolCall {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: Option<HashMap<String, Value>>,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
|
|
|
|||
|
|
@ -21,8 +21,11 @@ pub struct ModelAlias {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Agent {
|
||||
pub id: String,
|
||||
pub kind: Option<String>,
|
||||
pub transport: Option<String>,
|
||||
pub tool: Option<String>,
|
||||
pub url: String,
|
||||
#[serde(rename = "type")]
|
||||
pub agent_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -41,6 +44,20 @@ pub struct Listener {
|
|||
pub port: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StateStorageConfig {
|
||||
#[serde(rename = "type")]
|
||||
pub storage_type: StateStorageType,
|
||||
pub connection_string: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum StateStorageType {
|
||||
Memory,
|
||||
Postgres,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
|
|
@ -57,7 +74,9 @@ pub struct Configuration {
|
|||
pub mode: Option<GatewayMode>,
|
||||
pub routing: Option<Routing>,
|
||||
pub agents: Option<Vec<Agent>>,
|
||||
pub filters: Option<Vec<Agent>>,
|
||||
pub listeners: Vec<Listener>,
|
||||
pub state_storage: Option<StateStorageConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
@ -252,6 +271,39 @@ pub struct RoutingPreference {
|
|||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct AgentUsagePreference {
|
||||
pub model: String,
|
||||
pub orchestration_preferences: Vec<OrchestrationPreference>,
|
||||
}
|
||||
|
||||
/// OrchestrationPreference with custom serialization to always include default parameters.
|
||||
/// The parameters field is always serialized as:
|
||||
/// {"type": "object", "properties": {}, "required": []}
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OrchestrationPreference {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
impl serde::Serialize for OrchestrationPreference {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
use serde::ser::SerializeStruct;
|
||||
let mut state = serializer.serialize_struct("OrchestrationPreference", 3)?;
|
||||
state.serialize_field("name", &self.name)?;
|
||||
state.serialize_field("description", &self.description)?;
|
||||
state.serialize_field("parameters", &serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}))?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct LlmProvider {
|
||||
|
|
|
|||
|
|
@ -7,12 +7,13 @@ pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
|
|||
pub const DEFAULT_TARGET_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
|
||||
pub const API_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
|
||||
pub const MODEL_SERVER_REQUEST_TIMEOUT_MS: u64 = 30000; // 30 seconds
|
||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
pub const MODEL_SERVER_NAME: &str = "bright_staff";
|
||||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const MESSAGES_KEY: &str = "messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const ARCH_IS_STREAMING_HEADER: &str = "x-arch-streaming-request";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
pub const HEALTHZ_PATH: &str = "/healthz";
|
||||
pub const X_ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
|
|
@ -31,3 +32,4 @@ pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
|
|||
pub const OTEL_POST_PATH: &str = "/v1/traces";
|
||||
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
|
||||
pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries";
|
||||
pub const BRIGHT_STAFF_SERVICE_NAME : &str = "brightstaff";
|
||||
|
|
|
|||
|
|
@ -11,4 +11,5 @@ pub mod routing;
|
|||
pub mod stats;
|
||||
pub mod tokenizer;
|
||||
pub mod tracing;
|
||||
pub mod traces;
|
||||
pub mod utils;
|
||||
|
|
|
|||
|
|
@ -40,8 +40,14 @@ pub fn get_llm_provider(
|
|||
let mut rng = thread_rng();
|
||||
llm_providers
|
||||
.iter()
|
||||
.filter(|(_, provider)| {
|
||||
provider.model
|
||||
.as_ref()
|
||||
.map(|m| !m.starts_with("Arch"))
|
||||
.unwrap_or(true)
|
||||
})
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one llm provider")
|
||||
.expect("There should always be at least one non-Arch llm provider")
|
||||
.1
|
||||
.clone()
|
||||
}
|
||||
|
|
|
|||
285
crates/common/src/traces/collector.rs
Normal file
285
crates/common/src/traces/collector.rs
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
use super::shapes::Span;
|
||||
use super::resource_span_builder::ResourceSpanBuilder;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::{interval, Duration};
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
/// Parse W3C traceparent header into trace_id and parent_span_id
|
||||
/// Format: "00-{trace_id}-{parent_span_id}-01"
|
||||
///
|
||||
/// Returns (trace_id, Option<parent_span_id>)
|
||||
/// - parent_span_id is None if it's all zeros (0000000000000000), indicating a root span
|
||||
pub fn parse_traceparent(traceparent: &str) -> (String, Option<String>) {
|
||||
let parts: Vec<&str> = traceparent.split('-').collect();
|
||||
if parts.len() == 4 {
|
||||
let trace_id = parts[1].to_string();
|
||||
let parent_span_id = parts[2].to_string();
|
||||
|
||||
// If parent_span_id is all zeros, this is a root span with no parent
|
||||
let parent = if parent_span_id == "0000000000000000" {
|
||||
None
|
||||
} else {
|
||||
Some(parent_span_id)
|
||||
};
|
||||
|
||||
(trace_id, parent)
|
||||
} else {
|
||||
warn!("Invalid traceparent format: {}", traceparent);
|
||||
// Return empty trace ID and None for parent if parsing fails
|
||||
(String::new(), None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Collects and batches spans, flushing them to an OTEL collector
|
||||
///
|
||||
/// Supports multiple services, with each service (e.g., "archgw(routing)", "archgw(llm)")
|
||||
/// maintaining its own span queue. Flushes all services together periodically.
|
||||
///
|
||||
/// Tracing can be enabled/disabled in two ways:
|
||||
/// 1. Via arch_config.yaml: presence of `tracing` configuration section
|
||||
/// 2. Via environment variable: `OTEL_TRACING_ENABLED=true/false`
|
||||
///
|
||||
/// When disabled, span recording and flushing are no-ops.
|
||||
pub struct TraceCollector {
|
||||
/// Spans grouped by service name
|
||||
/// Key: service name (e.g., "archgw(routing)", "archgw(llm)")
|
||||
/// Value: queue of spans for that service
|
||||
spans_by_service: Arc<Mutex<HashMap<String, VecDeque<Span>>>>,
|
||||
flush_interval: Duration,
|
||||
otel_url: String,
|
||||
/// Whether tracing is enabled
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl TraceCollector {
|
||||
/// Create a new trace collector
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `enabled` - Whether tracing is enabled
|
||||
/// - `Some(true)` - Force enable tracing
|
||||
/// - `Some(false)` - Force disable tracing
|
||||
/// - `None` - Check `OTEL_TRACING_ENABLED` env var (defaults to true if not set)
|
||||
///
|
||||
/// Other parameters are read from environment variables:
|
||||
/// - `TRACE_FLUSH_INTERVAL_MS` - Flush interval in milliseconds (default: 1000)
|
||||
/// - `OTEL_COLLECTOR_URL` - OTEL collector endpoint (default: http://localhost:9903/v1/traces)
|
||||
pub fn new(enabled: Option<bool>) -> Self {
|
||||
let flush_interval_ms = std::env::var("TRACE_FLUSH_INTERVAL_MS")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1000);
|
||||
|
||||
let otel_url = std::env::var("OTEL_COLLECTOR_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:9903/v1/traces".to_string());
|
||||
|
||||
// Determine if tracing is enabled:
|
||||
// 1. Use explicit parameter if provided
|
||||
// 2. Otherwise check OTEL_TRACING_ENABLED env var
|
||||
// 3. Default to false if neither is set (tracing opt-in, not opt-out)
|
||||
let enabled = enabled.unwrap_or_else(|| {
|
||||
std::env::var("OTEL_TRACING_ENABLED")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
debug!(
|
||||
"TraceCollector initialized: flush_interval={}ms, url={}, enabled={}",
|
||||
flush_interval_ms, otel_url, enabled
|
||||
);
|
||||
|
||||
Self {
|
||||
spans_by_service: Arc::new(Mutex::new(HashMap::new())),
|
||||
flush_interval: Duration::from_millis(flush_interval_ms),
|
||||
otel_url,
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a span for a specific service
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `service_name` - Name of the service (e.g., "archgw(routing)", "archgw(llm)")
|
||||
/// * `span` - The span to record
|
||||
pub fn record_span(&self, service_name: impl Into<String>, span: Span) {
|
||||
// Skip recording if tracing is disabled
|
||||
if !self.enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
let service_name = service_name.into();
|
||||
|
||||
// Use try_lock to avoid blocking in async contexts
|
||||
// If the lock is held, we skip recording (telemetry shouldn't block the app)
|
||||
if let Ok(mut spans_by_service) = self.spans_by_service.try_lock() {
|
||||
// Get or create the queue for this service
|
||||
let spans = spans_by_service
|
||||
.entry(service_name)
|
||||
.or_insert_with(VecDeque::new);
|
||||
|
||||
spans.push_back(span);
|
||||
} else {
|
||||
// Lock contention - skip recording this span
|
||||
debug!("Skipped span recording due to lock contention");
|
||||
}
|
||||
// Flushing is handled by the periodic background flusher (see `start_background_flusher`).
|
||||
}
|
||||
|
||||
/// Flush all buffered spans to the OTEL collector
|
||||
/// Builds ResourceSpans for each service with spans
|
||||
pub async fn flush(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Skip flushing if tracing is disabled
|
||||
if !self.enabled {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut spans_by_service = self.spans_by_service.lock().await;
|
||||
|
||||
if spans_by_service.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Snapshot and drain all services' spans
|
||||
let service_batches: Vec<(String, Vec<Span>)> = spans_by_service
|
||||
.iter_mut()
|
||||
.filter_map(|(service_name, spans)| {
|
||||
if spans.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some((service_name.clone(), spans.drain(..).collect()))
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
drop(spans_by_service); // Release lock before HTTP call
|
||||
|
||||
if service_batches.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let total_spans: usize = service_batches.iter().map(|(_, spans)| spans.len()).sum();
|
||||
debug!("Flushing {} spans across {} services to OTEL collector", total_spans, service_batches.len());
|
||||
|
||||
// Build canonical OTEL payload structure - one ResourceSpan per service
|
||||
let resource_spans = self.build_resource_spans(service_batches);
|
||||
|
||||
match self.send_to_otel(resource_spans).await {
|
||||
Ok(_) => {
|
||||
debug!("Successfully flushed {} spans", total_spans);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to send spans to OTEL collector: {:?}", e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build OTEL-compliant resource spans from collected spans, one ResourceSpan per service
|
||||
fn build_resource_spans(&self, service_batches: Vec<(String, Vec<Span>)>) -> Vec<super::shapes::ResourceSpan> {
|
||||
service_batches
|
||||
.into_iter()
|
||||
.map(|(service_name, spans)| {
|
||||
ResourceSpanBuilder::new(&service_name)
|
||||
.add_spans(spans)
|
||||
.build()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Send resource spans to OTEL collector
|
||||
/// Serializes as {"resourceSpans": [...]} per OTEL spec
|
||||
async fn send_to_otel(
|
||||
&self,
|
||||
resource_spans: Vec<super::shapes::ResourceSpan>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Create OTEL payload with proper structure
|
||||
let payload = serde_json::json!({
|
||||
"resourceSpans": resource_spans
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(&self.otel_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&payload)
|
||||
.timeout(Duration::from_secs(5))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
warn!(
|
||||
"OTEL collector returned non-success status: {}",
|
||||
response.status()
|
||||
);
|
||||
return Err(format!("OTEL collector error: {}", response.status()).into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start a background task that periodically flushes traces
|
||||
/// Returns a join handle that can be used to stop the flusher
|
||||
pub fn start_background_flusher(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
|
||||
let flush_interval = self.flush_interval;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut ticker = interval(flush_interval);
|
||||
|
||||
loop {
|
||||
ticker.tick().await;
|
||||
|
||||
if let Err(e) = self.flush().await {
|
||||
error!("Background trace flush failed: {:?}", e);
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Get current number of buffered spans across all services (for testing/monitoring)
|
||||
pub async fn buffered_count(&self) -> usize {
|
||||
self.spans_by_service
|
||||
.lock()
|
||||
.await
|
||||
.values()
|
||||
.map(|spans| spans.len())
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traces::SpanBuilder;

    /// Recording one span leaves exactly one span in the buffer.
    #[tokio::test]
    async fn test_collector_basic() {
        let collector = TraceCollector::new(Some(true));

        collector.record_span(
            "test-service",
            SpanBuilder::new("test_operation")
                .with_trace_id("abc123")
                .build(),
        );

        let buffered = collector.buffered_count().await;
        assert_eq!(buffered, 1);
    }

    /// Batch-triggered flushing was removed, so recording spans must only
    /// buffer them: two recorded spans means two buffered spans.
    #[tokio::test]
    async fn test_collector_auto_flush() {
        let collector = Arc::new(TraceCollector::new(Some(true)));

        for name in ["test1", "test2"] {
            collector.record_span("test-service", SpanBuilder::new(name).build());
        }

        // Nothing flushes on its own here, so both spans are still buffered.
        assert_eq!(collector.buffered_count().await, 2);
    }
}
|
||||
27
crates/common/src/traces/constants.rs
Normal file
27
crates/common/src/traces/constants.rs
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
//! OpenTelemetry semantic convention constants for tracing
//!
//! These constants ensure consistency across the codebase and prevent typos
|
||||
|
||||
/// Attribute keys used on the OTEL `resource` object, named after the
/// OpenTelemetry semantic conventions.
pub mod resource {
    /// Key for the logical service name.
    pub const SERVICE_NAME: &str = "service.name";

    /// Key for the service version.
    pub const SERVICE_VERSION: &str = "service.version";

    /// Key for the service namespace / environment.
    pub const SERVICE_NAMESPACE: &str = "service.namespace";

    /// Key for the identifier of one service instance.
    pub const SERVICE_INSTANCE_ID: &str = "service.instance.id";
}
|
||||
|
||||
/// Defaults for the instrumentation scope attached to exported spans.
pub mod scope {
    /// Scope name used when the caller does not supply one.
    pub const DEFAULT_NAME: &str = "brightstaff.tracing";

    /// Scope version used when the caller does not supply one.
    pub const DEFAULT_VERSION: &str = "1.0.0";
}
|
||||
26
crates/common/src/traces/mod.rs
Normal file
26
crates/common/src/traces/mod.rs
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
// Original tracing types (OTEL structures)
|
||||
mod shapes;
|
||||
// New tracing utilities
|
||||
mod span_builder;
|
||||
mod resource_span_builder;
|
||||
mod constants;
|
||||
|
||||
#[cfg(feature = "trace-collection")]
|
||||
mod collector;
|
||||
|
||||
#[cfg(all(test, feature = "trace-collection"))]
|
||||
mod tests;
|
||||
|
||||
// Re-export original types
|
||||
pub use shapes::{
|
||||
Span, Event, Traceparent, TraceparentNewError,
|
||||
ResourceSpan, Resource, ScopeSpan, Scope, Attribute, AttributeValue,
|
||||
};
|
||||
|
||||
// Re-export new utilities
|
||||
pub use span_builder::{SpanBuilder, SpanKind, generate_random_span_id};
|
||||
pub use resource_span_builder::ResourceSpanBuilder;
|
||||
pub use constants::*;
|
||||
|
||||
#[cfg(feature = "trace-collection")]
|
||||
pub use collector::{TraceCollector, parse_traceparent};
|
||||
121
crates/common/src/traces/resource_span_builder.rs
Normal file
121
crates/common/src/traces/resource_span_builder.rs
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
use super::shapes::{ResourceSpan, Resource, ScopeSpan, Scope, Span, Attribute, AttributeValue};
|
||||
use super::constants::{resource, scope};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Builder for creating OTEL ResourceSpan structures
|
||||
///
|
||||
/// Provides a fluent API for building the resource/scope/span hierarchy
|
||||
pub struct ResourceSpanBuilder {
|
||||
service_name: String,
|
||||
resource_attributes: HashMap<String, String>,
|
||||
scope_name: String,
|
||||
scope_version: String,
|
||||
spans: Vec<Span>,
|
||||
}
|
||||
|
||||
impl ResourceSpanBuilder {
|
||||
/// Create a new ResourceSpan builder with service name
|
||||
pub fn new(service_name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
service_name: service_name.into(),
|
||||
resource_attributes: HashMap::new(),
|
||||
scope_name: scope::DEFAULT_NAME.to_string(),
|
||||
scope_version: scope::DEFAULT_VERSION.to_string(),
|
||||
spans: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a resource attribute (e.g., deployment.environment, host.name)
|
||||
pub fn with_resource_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
self.resource_attributes.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the instrumentation scope name
|
||||
pub fn with_scope_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.scope_name = name.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the instrumentation scope version
|
||||
pub fn with_scope_version(mut self, version: impl Into<String>) -> Self {
|
||||
self.scope_version = version.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a single span
|
||||
pub fn add_span(mut self, span: Span) -> Self {
|
||||
self.spans.push(span);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple spans
|
||||
pub fn add_spans(mut self, spans: Vec<Span>) -> Self {
|
||||
self.spans.extend(spans);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the ResourceSpan
|
||||
pub fn build(self) -> ResourceSpan {
|
||||
// Build resource attributes
|
||||
let mut attributes = vec![
|
||||
Attribute {
|
||||
key: resource::SERVICE_NAME.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(self.service_name),
|
||||
},
|
||||
}
|
||||
];
|
||||
|
||||
// Add custom resource attributes
|
||||
for (key, value) in self.resource_attributes {
|
||||
attributes.push(Attribute {
|
||||
key,
|
||||
value: AttributeValue {
|
||||
string_value: Some(value),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let resource = Resource { attributes };
|
||||
|
||||
let scope = Scope {
|
||||
name: self.scope_name,
|
||||
version: self.scope_version,
|
||||
attributes: Vec::new(),
|
||||
};
|
||||
|
||||
let scope_span = ScopeSpan {
|
||||
scope,
|
||||
spans: self.spans,
|
||||
};
|
||||
|
||||
ResourceSpan {
|
||||
resource,
|
||||
scope_spans: vec![scope_span],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traces::SpanBuilder;

    /// A builder with one custom resource attribute and two spans yields two
    /// resource attributes, one scope, and both spans under that scope.
    #[test]
    fn test_resource_span_builder() {
        let base = ResourceSpanBuilder::new("test-service")
            .with_resource_attribute("deployment.environment", "production")
            .with_scope_name("test-scope");

        let resource_span = ["operation1", "operation2"]
            .iter()
            .fold(base, |builder, name| {
                builder.add_span(SpanBuilder::new(*name).build())
            })
            .build();

        // service.name + the custom attribute
        assert_eq!(resource_span.resource.attributes.len(), 2);
        assert_eq!(resource_span.scope_spans.len(), 1);

        let only_scope = &resource_span.scope_spans[0];
        assert_eq!(only_scope.spans.len(), 2);
    }
}
|
||||
123
crates/common/src/traces/shapes.rs
Normal file
123
crates/common/src/traces/shapes.rs
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ResourceSpan {
|
||||
pub resource: Resource,
|
||||
#[serde(rename = "scopeSpans")]
|
||||
pub scope_spans: Vec<ScopeSpan>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Resource {
|
||||
pub attributes: Vec<Attribute>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ScopeSpan {
|
||||
pub scope: Scope,
|
||||
pub spans: Vec<Span>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Scope {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub attributes: Vec<Attribute>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Span {
|
||||
#[serde(rename = "traceId")]
|
||||
pub trace_id: String,
|
||||
#[serde(rename = "spanId")]
|
||||
pub span_id: String,
|
||||
#[serde(rename = "parentSpanId")]
|
||||
pub parent_span_id: Option<String>, // Optional in case there's no parent span
|
||||
pub name: String,
|
||||
#[serde(rename = "startTimeUnixNano")]
|
||||
pub start_time_unix_nano: String,
|
||||
#[serde(rename = "endTimeUnixNano")]
|
||||
pub end_time_unix_nano: String,
|
||||
pub kind: u32,
|
||||
pub attributes: Vec<Attribute>,
|
||||
pub events: Option<Vec<Event>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Event {
|
||||
#[serde(rename = "timeUnixNano")]
|
||||
pub time_unix_nano: String,
|
||||
pub name: String,
|
||||
pub attributes: Vec<Attribute>,
|
||||
}
|
||||
|
||||
impl Event {
|
||||
pub fn new(name: String, time_unix_nano: u128) -> Self {
|
||||
Event {
|
||||
time_unix_nano: format!("{}", time_unix_nano),
|
||||
name,
|
||||
attributes: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_attribute(&mut self, key: String, value: String) {
|
||||
self.attributes.push(Attribute {
|
||||
key,
|
||||
value: AttributeValue {
|
||||
string_value: Some(value),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Attribute {
|
||||
pub key: String,
|
||||
pub value: AttributeValue,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct AttributeValue {
|
||||
#[serde(rename = "stringValue")]
|
||||
pub string_value: Option<String>, // Use Option to handle different value types
|
||||
}
|
||||
|
||||
/// The four fields of a `traceparent` value.
///
/// `Display` renders them joined with `-` in the order
/// `version-trace_id-parent_id-flags`.
pub struct Traceparent {
    pub version: String,
    pub trace_id: String,
    pub parent_id: String,
    pub flags: String,
}

impl std::fmt::Display for Traceparent {
    /// Write the fields as `version-trace_id-parent_id-flags`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Traceparent {
            version,
            trace_id,
            parent_id,
            flags,
        } = self;
        write!(f, "{}-{}-{}-{}", version, trace_id, parent_id, flags)
    }
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum TraceparentNewError {
|
||||
#[error("Invalid traceparent: \'{0}\'")]
|
||||
InvalidTraceparent(String),
|
||||
}
|
||||
|
||||
impl TryFrom<String> for Traceparent {
|
||||
type Error = TraceparentNewError;
|
||||
|
||||
fn try_from(traceparent: String) -> Result<Self, Self::Error> {
|
||||
let traceparent_tokens: Vec<&str> = traceparent.split("-").collect::<Vec<&str>>();
|
||||
if traceparent_tokens.len() != 4 {
|
||||
return Err(TraceparentNewError::InvalidTraceparent(traceparent));
|
||||
}
|
||||
Ok(Traceparent {
|
||||
version: traceparent_tokens[0].to_string(),
|
||||
trace_id: traceparent_tokens[1].to_string(),
|
||||
parent_id: traceparent_tokens[2].to_string(),
|
||||
flags: traceparent_tokens[3].to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
200
crates/common/src/traces/span_builder.rs
Normal file
200
crates/common/src/traces/span_builder.rs
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
use super::shapes::{Span, Attribute, AttributeValue};
|
||||
use std::collections::HashMap;
|
||||
use std::time::SystemTime;
|
||||
|
||||
/// OpenTelemetry span kinds
/// https://opentelemetry.io/docs/specs/otel/trace/api/#spankind
///
/// The discriminants are exported verbatim (`SpanBuilder::build` stores
/// `kind as u32` in the span), so they must match the OTLP `Span.SpanKind`
/// enum: `0` is `SPAN_KIND_UNSPECIFIED`, `1` is `SPAN_KIND_INTERNAL` and
/// `3` is `SPAN_KIND_CLIENT`. Only the kinds used by this crate are listed.
#[derive(Debug, Clone, Copy)]
pub enum SpanKind {
    /// Default value. Indicates that the span represents an internal operation within an application
    Internal = 1,
    /// Indicates that the span describes a request to some remote service
    Client = 3,
}
|
||||
|
||||
/// Builder for creating OTEL-compliant spans with a fluent API
|
||||
///
|
||||
/// This is the recommended way to create spans with proper trace context.
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// use common::traces::{SpanBuilder, SpanKind};
|
||||
/// use std::time::SystemTime;
|
||||
///
|
||||
/// let span = SpanBuilder::new("router_chat")
|
||||
/// .with_trace_id("abc123")
|
||||
/// .with_parent_span_id("parent456")
|
||||
/// .with_kind(SpanKind::Internal)
|
||||
/// .with_attribute("http.method", "POST")
|
||||
/// .with_attribute("http.path", "/v1/chat/completions")
|
||||
/// .build();
|
||||
/// ```
|
||||
pub struct SpanBuilder {
|
||||
name: String,
|
||||
trace_id: Option<String>,
|
||||
parent_span_id: Option<String>,
|
||||
start_time: SystemTime,
|
||||
end_time: Option<SystemTime>,
|
||||
kind: SpanKind,
|
||||
attributes: HashMap<String, String>,
|
||||
span_id: Option<String>,
|
||||
}
|
||||
|
||||
impl SpanBuilder {
|
||||
/// Create a new span builder
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `name` - The operation name for this span (e.g., "router_chat", "determine_route")
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
trace_id: None,
|
||||
parent_span_id: None,
|
||||
start_time: SystemTime::now(),
|
||||
end_time: None,
|
||||
kind: SpanKind::Internal,
|
||||
attributes: HashMap::new(),
|
||||
span_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the trace ID (extracted from traceparent or OpenTelemetry context)
|
||||
pub fn with_trace_id(mut self, trace_id: impl Into<String>) -> Self {
|
||||
self.trace_id = Some(trace_id.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_span_id(mut self, span_id: impl Into<String>) -> Self {
|
||||
self.span_id = Some(span_id.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the parent span ID to link this span to its parent
|
||||
pub fn with_parent_span_id(mut self, parent_span_id: impl Into<String>) -> Self {
|
||||
self.parent_span_id = Some(parent_span_id.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the span kind (defaults to Internal)
|
||||
pub fn with_kind(mut self, kind: SpanKind) -> Self {
|
||||
self.kind = kind;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set explicit start time (defaults to now)
|
||||
pub fn with_start_time(mut self, start_time: SystemTime) -> Self {
|
||||
self.start_time = start_time;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set explicit end time (defaults to build time)
|
||||
pub fn with_end_time(mut self, end_time: SystemTime) -> Self {
|
||||
self.end_time = Some(end_time);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a single attribute to the span
|
||||
pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
self.attributes.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add multiple attributes at once
|
||||
pub fn with_attributes(mut self, attrs: HashMap<String, String>) -> Self {
|
||||
self.attributes.extend(attrs);
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the span, consuming the builder
|
||||
///
|
||||
/// Creates a complete OTEL-compliant span with all provided attributes,
|
||||
/// generating span_id and using provided or random trace_id.
|
||||
pub fn build(self) -> Span {
|
||||
let end_time = self.end_time.unwrap_or_else(SystemTime::now);
|
||||
|
||||
let start_nanos = system_time_to_nanos(self.start_time);
|
||||
let end_nanos = system_time_to_nanos(end_time);
|
||||
|
||||
// Generate trace_id if not provided
|
||||
let trace_id = self.trace_id.unwrap_or_else(|| generate_random_trace_id());
|
||||
|
||||
// Create attributes in OTEL format
|
||||
let attributes: Vec<Attribute> = self.attributes
|
||||
.into_iter()
|
||||
.map(|(key, value)| Attribute {
|
||||
key,
|
||||
value: AttributeValue {
|
||||
string_value: Some(value),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build span directly without going through Span::new()
|
||||
Span {
|
||||
trace_id,
|
||||
span_id: self.span_id.unwrap_or_else(|| generate_random_span_id()),
|
||||
parent_span_id: self.parent_span_id,
|
||||
name: self.name,
|
||||
start_time_unix_nano: format!("{}", start_nanos),
|
||||
end_time_unix_nano: format!("{}", end_nanos),
|
||||
kind: self.kind as u32,
|
||||
attributes,
|
||||
events: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Nanoseconds elapsed between the UNIX epoch and `time`.
///
/// A `time` that lies before the epoch yields `0` rather than an error.
fn system_time_to_nanos(time: SystemTime) -> u128 {
    match time.duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    }
}
|
||||
|
||||
/// Generate a random span ID (16 hex characters = 8 bytes)
|
||||
pub fn generate_random_span_id() -> String {
|
||||
use rand::RngCore;
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut random_bytes = [0u8; 8];
|
||||
rng.fill_bytes(&mut random_bytes);
|
||||
hex::encode(random_bytes)
|
||||
}
|
||||
|
||||
/// Generate a random trace ID (32 hex characters = 16 bytes)
|
||||
fn generate_random_trace_id() -> String {
|
||||
use rand::RngCore;
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut random_bytes = [0u8; 16];
|
||||
rng.fill_bytes(&mut random_bytes);
|
||||
hex::encode(random_bytes)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Explicit name, trace id, parent id and attribute all end up in the span.
    #[test]
    fn test_span_builder_basic() {
        let builder = SpanBuilder::new("test_operation")
            .with_trace_id("abc123")
            .with_parent_span_id("parent123")
            .with_attribute("key", "value");

        let Span {
            name,
            trace_id,
            parent_span_id,
            attributes,
            ..
        } = builder.build();

        assert_eq!(name, "test_operation");
        assert_eq!(trace_id, "abc123");
        assert_eq!(parent_span_id, Some("parent123".to_string()));
        assert_eq!(attributes.len(), 1);
    }

    /// Without `with_parent_span_id` the span has no parent.
    #[test]
    fn test_span_builder_no_parent() {
        let builder = SpanBuilder::new("root_span").with_trace_id("xyz789");

        let Span {
            name,
            trace_id,
            parent_span_id,
            ..
        } = builder.build();

        assert_eq!(name, "root_span");
        assert_eq!(trace_id, "xyz789");
        assert_eq!(parent_span_id, None);
    }
}
|
||||
101
crates/common/src/traces/tests/mock_otel_collector.rs
Normal file
101
crates/common/src/traces/tests/mock_otel_collector.rs
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
//! Mock OTEL Collector for testing trace output
|
||||
//!
|
||||
//! This module provides a simple HTTP server that mimics an OTEL collector.
|
||||
//! It exposes three endpoints:
|
||||
//! - POST /v1/traces: Capture incoming OTLP JSON payloads
|
||||
//! - GET /v1/traces: Return all captured payloads as JSON array
|
||||
//! - DELETE /v1/traces: Clear all captured payloads
|
||||
//!
|
||||
//! Each test creates its own MockOtelCollector instance.
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
routing::{delete, get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
type SharedTraces = Arc<RwLock<Vec<Value>>>;
|
||||
|
||||
/// POST /v1/traces - capture incoming OTLP payload
|
||||
async fn post_traces(
|
||||
State(traces): State<SharedTraces>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> StatusCode {
|
||||
traces.write().await.push(payload);
|
||||
StatusCode::OK
|
||||
}
|
||||
|
||||
/// GET /v1/traces - return all captured payloads
|
||||
async fn get_traces(State(traces): State<SharedTraces>) -> Json<Vec<Value>> {
|
||||
Json(traces.read().await.clone())
|
||||
}
|
||||
|
||||
/// DELETE /v1/traces - clear all captured payloads
|
||||
async fn delete_traces(State(traces): State<SharedTraces>) -> StatusCode {
|
||||
traces.write().await.clear();
|
||||
StatusCode::NO_CONTENT
|
||||
}
|
||||
|
||||
/// Mock OTEL collector server
|
||||
pub struct MockOtelCollector {
|
||||
address: String,
|
||||
client: reqwest::Client,
|
||||
#[allow(dead_code)]
|
||||
server_handle: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl MockOtelCollector {
|
||||
/// Create and start a new mock collector on a random port
|
||||
pub async fn start() -> Self {
|
||||
let traces = Arc::new(RwLock::new(Vec::new()));
|
||||
|
||||
let app = Router::new()
|
||||
.route("/v1/traces", post(post_traces))
|
||||
.route("/v1/traces", get(get_traces))
|
||||
.route("/v1/traces", delete(delete_traces))
|
||||
.with_state(traces.clone());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("Failed to bind to random port");
|
||||
|
||||
let addr = listener.local_addr().expect("Failed to get local address");
|
||||
let address = format!("http://127.0.0.1:{}", addr.port());
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.expect("Server failed");
|
||||
});
|
||||
|
||||
// Give server a moment to start
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||
|
||||
Self {
|
||||
address,
|
||||
client: reqwest::Client::new(),
|
||||
server_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the address of the collector
|
||||
pub fn address(&self) -> &str {
|
||||
&self.address
|
||||
}
|
||||
|
||||
/// GET /v1/traces - fetch all captured payloads
|
||||
pub async fn get_traces(&self) -> Vec<Value> {
|
||||
self.client
|
||||
.get(format!("{}/v1/traces", self.address))
|
||||
.send()
|
||||
.await
|
||||
.expect("Failed to GET traces")
|
||||
.json()
|
||||
.await
|
||||
.expect("Failed to parse traces JSON")
|
||||
}
|
||||
}
|
||||
4
crates/common/src/traces/tests/mod.rs
Normal file
4
crates/common/src/traces/tests/mod.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
mod mock_otel_collector;
|
||||
mod trace_integration_test;
|
||||
|
||||
pub use mock_otel_collector::MockOtelCollector;
|
||||
304
crates/common/src/traces/tests/trace_integration_test.rs
Normal file
304
crates/common/src/traces/tests/trace_integration_test.rs
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
//! Integration tests for OpenTelemetry tracing in router.rs
|
||||
//!
|
||||
//! These tests validate that the spans created for LLM requests contain
|
||||
//! all expected attributes and events by checking the raw JSON payloads
|
||||
//! sent to the mock OTEL collector.
|
||||
//!
|
||||
//! ## Test Design
|
||||
//! Each test creates its own MockOtelCollector and TraceCollector:
|
||||
//! 1. Start MockOtelCollector on random port
|
||||
//! 2. Create TraceCollector with 500ms flush interval
|
||||
//! 3. Record spans using TraceCollector
|
||||
//! 4. Flush explicitly, then wait TOTAL_WAIT_MS (FLUSH_INTERVAL_MS + FLUSH_BUFFER_MS) for spans to arrive
|
||||
//! 5. Get raw JSON payloads (GET /v1/traces) and validate structure
|
||||
//! 6. Test cleanup happens automatically when collectors are dropped
|
||||
//!
|
||||
//! ## Serial Execution
|
||||
//! Tests use the `#[serial]` attribute to run sequentially because they
|
||||
//! use global environment variables (OTEL_COLLECTOR_URL, OTEL_TRACING_ENABLED,
|
||||
//! TRACE_FLUSH_INTERVAL_MS). This ensures test isolation without requiring
|
||||
//! the `--test-threads=1` command line flag.
|
||||
|
||||
const FLUSH_INTERVAL_MS: u64 = 50;
|
||||
const FLUSH_BUFFER_MS: u64 = 50;
|
||||
const TOTAL_WAIT_MS: u64 = FLUSH_INTERVAL_MS + FLUSH_BUFFER_MS;
|
||||
|
||||
use crate::traces::{SpanBuilder, SpanKind, TraceCollector};
|
||||
use serde_json::Value;
|
||||
use serial_test::serial;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::MockOtelCollector;
|
||||
|
||||
/// Helper to extract all spans from OTLP JSON payloads
|
||||
fn extract_spans(payloads: &[Value]) -> Vec<&Value> {
|
||||
let mut spans = Vec::new();
|
||||
for payload in payloads {
|
||||
if let Some(resource_spans) = payload.get("resourceSpans").and_then(|v| v.as_array()) {
|
||||
for resource_span in resource_spans {
|
||||
if let Some(scope_spans) = resource_span.get("scopeSpans").and_then(|v| v.as_array()) {
|
||||
for scope_span in scope_spans {
|
||||
if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array()) {
|
||||
spans.extend(span_list.iter());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
spans
|
||||
}
|
||||
|
||||
/// Helper to get string attribute value from a span
|
||||
fn get_string_attr<'a>(span: &'a Value, key: &str) -> Option<&'a str> {
|
||||
span.get("attributes")
|
||||
.and_then(|attrs| attrs.as_array())
|
||||
.and_then(|attrs| {
|
||||
attrs.iter().find(|attr| {
|
||||
attr.get("key").and_then(|k| k.as_str()) == Some(key)
|
||||
})
|
||||
})
|
||||
.and_then(|attr| attr.get("value"))
|
||||
.and_then(|v| v.get("stringValue"))
|
||||
.and_then(|v| v.as_str())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_llm_span_contains_basic_attributes() {
|
||||
// Start mock OTEL collector
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
|
||||
// Create TraceCollector pointing to mock with 500ms flush intervalc
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
||||
// Create a test span simulating router.rs behavior
|
||||
let span = SpanBuilder::new("POST /v1/chat/completions >> /v1/chat/completions")
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_trace_id("test-trace-123")
|
||||
.with_attribute("http.method", "POST")
|
||||
.with_attribute("http.target", "/v1/chat/completions")
|
||||
.with_attribute("http.upstream_target", "/v1/chat/completions")
|
||||
.with_attribute("llm.model", "gpt-4o")
|
||||
.with_attribute("llm.provider", "openai")
|
||||
.with_attribute("llm.is_streaming", "true")
|
||||
.with_attribute("llm.temperature", "0.7")
|
||||
.build();
|
||||
|
||||
trace_collector.record_span("archgw(llm)", span);
|
||||
|
||||
// Flush and wait for spans to arrive (500ms flush interval + 200ms buffer)
|
||||
trace_collector.flush().await.expect("Failed to flush");
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||
|
||||
let payloads = mock_collector.get_traces().await;
|
||||
let spans = extract_spans(&payloads);
|
||||
|
||||
assert_eq!(spans.len(), 1, "Expected exactly one span");
|
||||
|
||||
let span = spans[0];
|
||||
// Validate HTTP attributes
|
||||
assert_eq!(get_string_attr(span, "http.method"), Some("POST"));
|
||||
assert_eq!(get_string_attr(span, "http.target"), Some("/v1/chat/completions"));
|
||||
|
||||
// Validate LLM attributes
|
||||
assert_eq!(get_string_attr(span, "llm.model"), Some("gpt-4o"));
|
||||
assert_eq!(get_string_attr(span, "llm.provider"), Some("openai"));
|
||||
assert_eq!(get_string_attr(span, "llm.is_streaming"), Some("true"));
|
||||
assert_eq!(get_string_attr(span, "llm.temperature"), Some("0.7"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_llm_span_contains_tool_information() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
||||
let tools_formatted = "get_weather(...)\nsearch_web(...)\ncalculate(...)";
|
||||
|
||||
let span = SpanBuilder::new("POST /v1/chat/completions")
|
||||
.with_trace_id("test-trace-tools")
|
||||
.with_attribute("llm.request.tools", tools_formatted)
|
||||
.with_attribute("llm.model", "gpt-4o")
|
||||
.build();
|
||||
|
||||
trace_collector.record_span("archgw(llm)", span);
|
||||
trace_collector.flush().await.expect("Failed to flush");
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||
|
||||
let payloads = mock_collector.get_traces().await;
|
||||
let spans = extract_spans(&payloads);
|
||||
|
||||
assert!(!spans.is_empty(), "No spans captured");
|
||||
|
||||
let span = spans[0];
|
||||
let tools = get_string_attr(span, "llm.request.tools");
|
||||
|
||||
assert!(tools.is_some(), "Tools attribute missing");
|
||||
assert!(tools.unwrap().contains("get_weather(...)"));
|
||||
assert!(tools.unwrap().contains("search_web(...)"));
|
||||
assert!(tools.unwrap().contains("calculate(...)"));
|
||||
assert!(tools.unwrap().contains('\n'), "Tools should be newline-separated");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_llm_span_contains_user_message_preview() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
||||
let long_message = "This is a very long user message that should be truncated to 50 characters in the span";
|
||||
let preview = if long_message.len() > 50 {
|
||||
format!("{}...", &long_message[..50])
|
||||
} else {
|
||||
long_message.to_string()
|
||||
};
|
||||
|
||||
let span = SpanBuilder::new("POST /v1/messages")
|
||||
.with_trace_id("test-trace-preview")
|
||||
.with_attribute("llm.request.user_message_preview", &preview)
|
||||
.build();
|
||||
|
||||
trace_collector.record_span("archgw(llm)", span);
|
||||
trace_collector.flush().await.expect("Failed to flush");
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||
|
||||
let payloads = mock_collector.get_traces().await;
|
||||
let spans = extract_spans(&payloads);
|
||||
let span = spans[0];
|
||||
|
||||
let message_preview = get_string_attr(span, "llm.request.user_message_preview");
|
||||
|
||||
assert!(message_preview.is_some());
|
||||
assert!(message_preview.unwrap().len() <= 53); // 50 chars + "..."
|
||||
assert!(message_preview.unwrap().contains("..."));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_llm_span_contains_time_to_first_token() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
||||
let ttft_ms = "245"; // milliseconds as string
|
||||
|
||||
let span = SpanBuilder::new("POST /v1/chat/completions")
|
||||
.with_trace_id("test-trace-ttft")
|
||||
.with_attribute("llm.is_streaming", "true")
|
||||
.with_attribute("llm.time_to_first_token_ms", ttft_ms)
|
||||
.build();
|
||||
|
||||
trace_collector.record_span("archgw(llm)", span);
|
||||
trace_collector.flush().await.expect("Failed to flush");
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||
|
||||
let payloads = mock_collector.get_traces().await;
|
||||
let spans = extract_spans(&payloads);
|
||||
let span = spans[0];
|
||||
|
||||
// Check TTFT attribute
|
||||
let ttft_attr = get_string_attr(span, "llm.time_to_first_token_ms");
|
||||
assert_eq!(ttft_attr, Some("245"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_llm_span_contains_upstream_path() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
||||
// Test Zhipu provider with path transformation
|
||||
let span = SpanBuilder::new("POST /v1/chat/completions >> /api/paas/v4/chat/completions")
|
||||
.with_trace_id("test-trace-upstream")
|
||||
.with_attribute("http.upstream_target", "/api/paas/v4/chat/completions")
|
||||
.with_attribute("llm.provider", "zhipu")
|
||||
.with_attribute("llm.model", "glm-4")
|
||||
.build();
|
||||
|
||||
trace_collector.record_span("archgw(llm)", span);
|
||||
trace_collector.flush().await.expect("Failed to flush");
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||
|
||||
let payloads = mock_collector.get_traces().await;
|
||||
let spans = extract_spans(&payloads);
|
||||
let span = spans[0];
|
||||
|
||||
// Operation name should show the transformation
|
||||
let name = span.get("name").and_then(|v| v.as_str());
|
||||
assert!(name.is_some());
|
||||
assert!(name.unwrap().contains(">>"), "Operation name should show path transformation");
|
||||
|
||||
// Check upstream target attribute
|
||||
let upstream = get_string_attr(span, "http.upstream_target");
|
||||
assert_eq!(upstream, Some("/api/paas/v4/chat/completions"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_llm_span_multiple_services() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
||||
// Create spans for different services
|
||||
let llm_span = SpanBuilder::new("LLM Request")
|
||||
.with_trace_id("test-multi")
|
||||
.with_attribute("service", "llm")
|
||||
.build();
|
||||
|
||||
let routing_span = SpanBuilder::new("Routing Decision")
|
||||
.with_trace_id("test-multi")
|
||||
.with_attribute("service", "routing")
|
||||
.build();
|
||||
|
||||
trace_collector.record_span("archgw(llm)", llm_span);
|
||||
trace_collector.record_span("archgw(routing)", routing_span);
|
||||
trace_collector.flush().await.expect("Failed to flush");
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||
|
||||
let payloads = mock_collector.get_traces().await;
|
||||
let all_spans = extract_spans(&payloads);
|
||||
|
||||
assert_eq!(all_spans.len(), 2, "Should have captured both spans");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_tracing_disabled_produces_no_spans() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
|
||||
// Create TraceCollector with tracing DISABLED
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "false");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(false)));
|
||||
|
||||
let span = SpanBuilder::new("Test Span")
|
||||
.with_trace_id("test-disabled")
|
||||
.build();
|
||||
|
||||
trace_collector.record_span("archgw(llm)", span);
|
||||
trace_collector.flush().await.ok(); // Should be no-op when disabled
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(TOTAL_WAIT_MS)).await;
|
||||
|
||||
let payloads = mock_collector.get_traces().await;
|
||||
let all_spans = extract_spans(&payloads);
|
||||
assert_eq!(all_spans.len(), 0, "No spans should be captured when tracing is disabled");
|
||||
}
|
||||
|
|
@ -10,3 +10,5 @@ serde_with = {version = "3.12.0", features = ["base64"]}
|
|||
thiserror = "2.0.12"
|
||||
aws-smithy-eventstream = "0.60"
|
||||
bytes = "1.10"
|
||||
uuid = { version = "1.11", features = ["v4"] }
|
||||
log = "0.4"
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use thiserror::Error;
|
|||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
|
||||
// ============================================================================
|
||||
// AMAZON BEDROCK CONVERSE API ENUMERATION
|
||||
|
|
@ -200,6 +200,17 @@ impl ProviderRequest for ConverseRequest {
|
|||
})
|
||||
}
|
||||
|
||||
fn get_tool_names(&self) -> Option<Vec<String>> {
|
||||
self.tool_config.as_ref()?.tools.as_ref().map(|tools| {
|
||||
tools
|
||||
.iter()
|
||||
.filter_map(|tool| match tool {
|
||||
Tool::ToolSpec { tool_spec } => Some(tool_spec.name.clone()),
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
serde_json::to_vec(self).map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to serialize Bedrock request: {}", e),
|
||||
|
|
@ -218,6 +229,108 @@ impl ProviderRequest for ConverseRequest {
|
|||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn get_temperature(&self) -> Option<f32> {
|
||||
self.inference_config.as_ref()?.temperature
|
||||
}
|
||||
|
||||
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
|
||||
use crate::apis::openai::{Message, MessageContent, Role};
|
||||
|
||||
let mut openai_messages = Vec::new();
|
||||
|
||||
// Add system messages if present
|
||||
if let Some(system) = &self.system {
|
||||
for sys_block in system {
|
||||
match sys_block {
|
||||
SystemContentBlock::Text { text } => {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
_ => {} // Skip other system content types
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert conversation messages
|
||||
if let Some(messages) = &self.messages {
|
||||
for msg in messages {
|
||||
let role = match msg.role {
|
||||
ConversationRole::User => Role::User,
|
||||
ConversationRole::Assistant => Role::Assistant,
|
||||
};
|
||||
|
||||
// Extract text from content blocks
|
||||
let content = msg.content.iter()
|
||||
.filter_map(|block| {
|
||||
if let ContentBlock::Text { text } = block {
|
||||
Some(text.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
openai_messages.push(Message {
|
||||
role,
|
||||
content: MessageContent::Text(content),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
openai_messages
|
||||
}
|
||||
|
||||
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
|
||||
// Convert OpenAI messages to Bedrock format
|
||||
use crate::apis::amazon_bedrock::{ContentBlock, ConversationRole, SystemContentBlock};
|
||||
|
||||
let mut system_blocks = Vec::new();
|
||||
let mut bedrock_messages = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
match msg.role {
|
||||
crate::apis::openai::Role::System => {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
system_blocks.push(SystemContentBlock::Text { text: text.clone() });
|
||||
}
|
||||
}
|
||||
crate::apis::openai::Role::User | crate::apis::openai::Role::Assistant => {
|
||||
let role = match msg.role {
|
||||
crate::apis::openai::Role::User => ConversationRole::User,
|
||||
crate::apis::openai::Role::Assistant => ConversationRole::Assistant,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let content = if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
bedrock_messages.push(crate::apis::amazon_bedrock::Message {
|
||||
role,
|
||||
content,
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if !system_blocks.is_empty() {
|
||||
self.system = Some(system_blocks);
|
||||
}
|
||||
self.messages = Some(bedrock_messages);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ use std::collections::HashMap;
|
|||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse};
|
||||
use crate::providers::response::ProviderResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::MESSAGES_PATH;
|
||||
|
||||
|
|
@ -397,6 +398,8 @@ pub enum MessagesContentDelta {
|
|||
InputJsonDelta { partial_json: String },
|
||||
#[serde(rename = "thinking_delta")]
|
||||
ThinkingDelta { thinking: String },
|
||||
#[serde(rename = "signature_delta")]
|
||||
SignatureDelta { signature: String },
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
|
|
@ -512,6 +515,12 @@ impl ProviderRequest for MessagesRequest {
|
|||
None
|
||||
}
|
||||
|
||||
fn get_tool_names(&self) -> Option<Vec<String>> {
|
||||
self.tools.as_ref().map(|tools| {
|
||||
tools.iter().map(|tool| tool.name.clone()).collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
serde_json::to_vec(self).map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to serialize MessagesRequest: {}", e),
|
||||
|
|
@ -530,6 +539,69 @@ impl ProviderRequest for MessagesRequest {
|
|||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn get_temperature(&self) -> Option<f32> {
|
||||
self.temperature
|
||||
}
|
||||
|
||||
/// Projects this Messages request onto a list of OpenAI-style messages.
///
/// The system prompt, when present, is converted first. Each Anthropic
/// message is then converted in order; a message whose conversion fails is
/// left out of the result.
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
    use crate::apis::openai::Message;

    // System prompt as a leading system message, if there is one.
    let system_message: Option<Message> =
        self.system.as_ref().map(|system| system.clone().into());

    // One Anthropic message may expand into several OpenAI messages.
    let conversation = self
        .messages
        .iter()
        .filter_map(|msg| TryInto::<Vec<Message>>::try_into(msg.clone()).ok())
        .flatten();

    system_message.into_iter().chain(conversation).collect()
}
|
||||
|
||||
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
|
||||
// Convert OpenAI messages to Anthropic format
|
||||
// Separate system messages from regular messages
|
||||
let mut system_messages = Vec::new();
|
||||
let mut regular_messages = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
if msg.role == crate::apis::openai::Role::System {
|
||||
system_messages.push(msg.clone());
|
||||
} else {
|
||||
regular_messages.push(msg.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Set system prompt if there are system messages
|
||||
if !system_messages.is_empty() {
|
||||
// Combine all system messages into one
|
||||
let system_text = system_messages.iter()
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
self.system = Some(crate::apis::anthropic::MessagesSystemPrompt::Single(system_text));
|
||||
}
|
||||
|
||||
// Convert regular messages
|
||||
self.messages = regular_messages.iter()
|
||||
.filter_map(|msg| {
|
||||
msg.clone().try_into().ok()
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
|
||||
impl MessagesResponse {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
pub mod amazon_bedrock;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic;
|
||||
pub mod openai;
|
||||
pub mod sse;
|
||||
pub mod openai_responses;
|
||||
pub mod streaming_shapes;
|
||||
|
||||
// Explicit exports to avoid naming conflicts
|
||||
pub use amazon_bedrock::{AmazonBedrockApi, ConverseRequest, ConverseStreamRequest};
|
||||
|
|
@ -88,8 +88,9 @@ mod tests {
|
|||
fn test_all_variants_method() {
|
||||
// Test that all_variants returns the expected variants
|
||||
let openai_variants = OpenAIApi::all_variants();
|
||||
assert_eq!(openai_variants.len(), 1);
|
||||
assert_eq!(openai_variants.len(), 2);
|
||||
assert!(openai_variants.contains(&OpenAIApi::ChatCompletions));
|
||||
assert!(openai_variants.contains(&OpenAIApi::Responses));
|
||||
|
||||
let anthropic_variants = AnthropicApi::all_variants();
|
||||
assert_eq!(anthropic_variants.len(), 1);
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@ use thiserror::Error;
|
|||
|
||||
use super::ApiDefinition;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage};
|
||||
use crate::providers::response::{ProviderResponse, TokenUsage};
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::CHAT_COMPLETIONS_PATH;
|
||||
use crate::{CHAT_COMPLETIONS_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
|
||||
// ============================================================================
|
||||
// OPENAI API ENUMERATION
|
||||
|
|
@ -19,6 +20,7 @@ use crate::CHAT_COMPLETIONS_PATH;
|
|||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum OpenAIApi {
|
||||
ChatCompletions,
|
||||
Responses,
|
||||
// Future APIs can be added here:
|
||||
// Embeddings,
|
||||
// FineTuning,
|
||||
|
|
@ -29,12 +31,14 @@ impl ApiDefinition for OpenAIApi {
|
|||
fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => CHAT_COMPLETIONS_PATH,
|
||||
OpenAIApi::Responses => OPENAI_RESPONSES_API_PATH,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
match endpoint {
|
||||
CHAT_COMPLETIONS_PATH => Some(OpenAIApi::ChatCompletions),
|
||||
OPENAI_RESPONSES_API_PATH => Some(OpenAIApi::Responses),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -42,23 +46,26 @@ impl ApiDefinition for OpenAIApi {
|
|||
fn supports_streaming(&self) -> bool {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
OpenAIApi::Responses => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
OpenAIApi::Responses => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_vision(&self) -> bool {
|
||||
match self {
|
||||
OpenAIApi::ChatCompletions => true,
|
||||
OpenAIApi::Responses => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn all_variants() -> Vec<Self> {
|
||||
vec![OpenAIApi::ChatCompletions]
|
||||
vec![OpenAIApi::ChatCompletions, OpenAIApi::Responses]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -101,6 +108,12 @@ pub struct ChatCompletionsRequest {
|
|||
pub top_logprobs: Option<u32>,
|
||||
pub user: Option<String>,
|
||||
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
|
||||
|
||||
// VLLM-specific parameters (used by Arch-Function)
|
||||
pub top_k: Option<u32>,
|
||||
pub stop_token_ids: Option<Vec<u32>>,
|
||||
pub continue_final_message: Option<bool>,
|
||||
pub add_generation_prompt: Option<bool>,
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
|
|
@ -385,6 +398,8 @@ pub struct ChatCompletionsResponse {
|
|||
pub usage: Usage,
|
||||
pub system_fingerprint: Option<String>,
|
||||
pub service_tier: Option<String>,
|
||||
// This isn't a standard OpenAI field, but we include it for extensibility
|
||||
pub metadata: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
impl Default for ChatCompletionsResponse {
|
||||
|
|
@ -398,6 +413,7 @@ impl Default for ChatCompletionsResponse {
|
|||
usage: Usage::default(),
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -671,6 +687,32 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
})
|
||||
}
|
||||
|
||||
fn get_tool_names(&self) -> Option<Vec<String>> {
|
||||
// First check the 'tools' field (current API)
|
||||
if let Some(tools) = &self.tools {
|
||||
let names: Vec<String> = tools
|
||||
.iter()
|
||||
.map(|tool| tool.function.name.clone())
|
||||
.collect();
|
||||
if !names.is_empty() {
|
||||
return Some(names);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to 'functions' field (deprecated but still supported)
|
||||
if let Some(functions) = &self.functions {
|
||||
let names: Vec<String> = functions
|
||||
.iter()
|
||||
.map(|func| func.function.name.clone())
|
||||
.collect();
|
||||
if !names.is_empty() {
|
||||
return Some(names);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
|
||||
serde_json::to_vec(&self).map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to serialize OpenAI request: {}", e),
|
||||
|
|
@ -689,6 +731,18 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn get_temperature(&self) -> Option<f32> {
|
||||
self.temperature
|
||||
}
|
||||
|
||||
/// Returns an owned copy of the request's messages.
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
    self.messages.iter().cloned().collect()
}
|
||||
|
||||
/// Replaces the request's messages with a copy of `messages`.
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
    self.messages = messages.iter().cloned().collect();
}
|
||||
}
|
||||
|
||||
/// Implementation of ProviderResponse for ChatCompletionsResponse
|
||||
|
|
@ -1068,8 +1122,9 @@ mod tests {
|
|||
|
||||
// Test all_variants
|
||||
let all_variants = OpenAIApi::all_variants();
|
||||
assert_eq!(all_variants.len(), 1);
|
||||
assert_eq!(all_variants[0], OpenAIApi::ChatCompletions);
|
||||
assert_eq!(all_variants.len(), 2);
|
||||
assert!(all_variants.contains(&OpenAIApi::ChatCompletions));
|
||||
assert!(all_variants.contains(&OpenAIApi::Responses));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
1573
crates/hermesllm/src/apis/openai_responses.rs
Normal file
1573
crates/hermesllm/src/apis/openai_responses.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,7 +1,6 @@
|
|||
use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
use aws_smithy_eventstream::frame::MessageFrameDecoder;
|
||||
use bytes::Buf;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// AWS Event Stream frame decoder wrapper
|
||||
pub struct BedrockBinaryFrameDecoder<B>
|
||||
|
|
@ -10,7 +9,6 @@ where
|
|||
{
|
||||
decoder: MessageFrameDecoder,
|
||||
buffer: B,
|
||||
content_block_start_indices: HashSet<i32>,
|
||||
}
|
||||
|
||||
impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
|
||||
|
|
@ -20,7 +18,6 @@ impl BedrockBinaryFrameDecoder<bytes::BytesMut> {
|
|||
Self {
|
||||
decoder: MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
content_block_start_indices: std::collections::HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -33,7 +30,6 @@ where
|
|||
Self {
|
||||
decoder: MessageFrameDecoder::new(),
|
||||
buffer,
|
||||
content_block_start_indices: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -52,14 +48,4 @@ where
|
|||
pub fn has_remaining(&self) -> bool {
|
||||
self.buffer.has_remaining()
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
pub fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
}
|
||||
|
||||
/// Mark that a content_block_start event has been sent for the given index
|
||||
pub fn set_content_block_start_sent(&mut self, index: i32) {
|
||||
self.content_block_start_indices.insert(index);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,507 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// SSE Stream Buffer for Anthropic Messages API streaming.
|
||||
///
|
||||
/// This buffer manages the wire format for Anthropic Messages API streaming,
|
||||
/// handling the specific event sequencing requirements:
|
||||
/// - MessageStart → ContentBlockStart → ContentBlockDelta(s) → ContentBlockStop → MessageDelta → MessageStop
|
||||
///
|
||||
/// When converting from OpenAI to Anthropic format, this buffer injects the required
|
||||
/// ContentBlockStart and ContentBlockStop events to maintain proper Anthropic protocol.
|
||||
pub struct AnthropicMessagesStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
|
||||
/// Track if we've seen a message_start event
|
||||
message_started: bool,
|
||||
|
||||
/// Track content block indices that have received ContentBlockStart events
|
||||
content_block_start_indices: HashSet<i32>,
|
||||
|
||||
/// Track if we need to inject ContentBlockStop before message_delta
|
||||
needs_content_block_stop: bool,
|
||||
|
||||
/// Track if we've seen a MessageDelta (so we need to send MessageStop at the end)
|
||||
seen_message_delta: bool,
|
||||
|
||||
/// Model name to use when generating message_start events
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
impl AnthropicMessagesStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
message_started: false,
|
||||
content_block_start_indices: HashSet::new(),
|
||||
needs_content_block_stop: false,
|
||||
seen_message_delta: false,
|
||||
model: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
}
|
||||
|
||||
/// Mark that a content_block_start event has been sent for the given index
|
||||
fn set_content_block_start_sent(&mut self, index: i32) {
|
||||
self.content_block_start_indices.insert(index);
|
||||
}
|
||||
|
||||
/// Helper to create and format a ContentBlockStart SSE event
|
||||
fn create_content_block_start_event() -> SseEvent {
|
||||
let content_block_start = MessagesStreamEvent::ContentBlockStart {
|
||||
index: 0,
|
||||
content_block: crate::apis::anthropic::MessagesContentBlock::Text {
|
||||
text: String::new(),
|
||||
cache_control: None,
|
||||
},
|
||||
};
|
||||
let sse_string: String = content_block_start.into();
|
||||
|
||||
SseEvent {
|
||||
data: None,
|
||||
event: Some("content_block_start".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create and format a MessageStart SSE event
|
||||
fn create_message_start_event(model: &str) -> SseEvent {
|
||||
let message_start = MessagesStreamEvent::MessageStart {
|
||||
message: crate::apis::anthropic::MessagesStreamMessage {
|
||||
id: format!("msg_{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
|
||||
obj_type: "message".to_string(),
|
||||
role: crate::apis::anthropic::MessagesRole::Assistant,
|
||||
content: vec![],
|
||||
model: model.to_string(),
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: crate::apis::anthropic::MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
};
|
||||
let sse_string: String = message_start.into();
|
||||
|
||||
SseEvent {
|
||||
data: None,
|
||||
event: Some("message_start".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create and format a ContentBlockStop SSE event
|
||||
fn create_content_block_stop_event() -> SseEvent {
|
||||
let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 };
|
||||
let sse_string: String = content_block_stop.into();
|
||||
|
||||
SseEvent {
|
||||
data: None,
|
||||
event: Some("content_block_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// FIRST: Try to extract model name from the raw event data before transformation
|
||||
// The provider_stream_response has already been transformed to Anthropic format,
|
||||
// so we need to extract the model from the original raw data if available
|
||||
if self.model.is_none() {
|
||||
if let Some(data) = &event.data {
|
||||
// Try to parse as JSON and extract model field
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
|
||||
if let Some(model) = json.get("model").and_then(|m| m.as_str()) {
|
||||
self.model = Some(model.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Match directly on the provider response type to handle event processing
|
||||
// We match on a reference first to determine the type, then move the event
|
||||
match &event.provider_stream_response {
|
||||
Some(ProviderStreamResponseType::MessagesStreamEvent(evt)) => {
|
||||
match evt {
|
||||
MessagesStreamEvent::MessageStart { .. } => {
|
||||
// Add the message_start event
|
||||
self.buffered_events.push(event);
|
||||
self.message_started = true;
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockStart { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
||||
// Add the content_block_start event (from tool calls or other sources)
|
||||
self.buffered_events.push(event);
|
||||
self.set_content_block_start_sent(index);
|
||||
self.needs_content_block_stop = true;
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
||||
// Check if ContentBlockStart was sent for this index
|
||||
if !self.has_content_block_start_been_sent(index) {
|
||||
// Inject ContentBlockStart before delta
|
||||
let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event();
|
||||
self.buffered_events.push(content_block_start);
|
||||
self.set_content_block_start_sent(index);
|
||||
self.needs_content_block_stop = true;
|
||||
}
|
||||
|
||||
// Content deltas are between ContentBlockStart and ContentBlockStop
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageDelta { usage, .. } => {
|
||||
// Inject ContentBlockStop before message_delta
|
||||
if self.needs_content_block_stop {
|
||||
let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event();
|
||||
self.buffered_events.push(content_block_stop);
|
||||
self.needs_content_block_stop = false;
|
||||
}
|
||||
|
||||
// Check if the last event was also a MessageDelta - if so, merge them
|
||||
// This handles Bedrock's split of stop_reason (MessageStop) and usage (Metadata)
|
||||
if let Some(last_event) = self.buffered_events.last_mut() {
|
||||
if let Some(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
MessagesStreamEvent::MessageDelta {
|
||||
usage: last_usage,
|
||||
..
|
||||
}
|
||||
)) = &mut last_event.provider_stream_response {
|
||||
// Merge: take stop_reason from first, usage from second (if non-zero)
|
||||
if usage.input_tokens > 0 || usage.output_tokens > 0 {
|
||||
*last_usage = usage.clone();
|
||||
}
|
||||
// Mark that we've seen MessageDelta (need to send MessageStop later)
|
||||
self.seen_message_delta = true;
|
||||
// Don't push the new event, we've merged it
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// No previous MessageDelta to merge with, add this one
|
||||
self.buffered_events.push(event);
|
||||
self.seen_message_delta = true;
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => {
|
||||
// ContentBlockStop received from upstream (e.g., Bedrock)
|
||||
// Clear the flag so we don't inject another one
|
||||
self.needs_content_block_stop = false;
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageStop => {
|
||||
// MessageStop received from upstream (e.g., OpenAI via [DONE])
|
||||
// Clear the flag so we don't inject another one
|
||||
self.seen_message_delta = false;
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
_ => {
|
||||
// Other Anthropic event types (Ping, etc.), just accumulate
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Non-Anthropic events or events without provider_stream_response, just accumulate
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
|
||||
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
|
||||
|
||||
// Inject MessageStop after MessageDelta if we've seen one
|
||||
// This completes the Anthropic Messages API event sequence
|
||||
if self.seen_message_delta {
|
||||
let message_stop = MessagesStreamEvent::MessageStop;
|
||||
let sse_string: String = message_stop.into();
|
||||
let message_stop_event = SseEvent {
|
||||
data: None,
|
||||
event: Some("message_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
};
|
||||
self.buffered_events.push(message_stop_event);
|
||||
self.seen_message_delta = false;
|
||||
}
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::apis::anthropic::AnthropicApi;
    use crate::apis::openai::OpenAIApi;
    use crate::apis::streaming_shapes::sse::SseStreamIter;
    use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};

    /// Prints the test banner followed by the raw upstream input.
    fn print_input(title: &str, input_heading: &str, raw_input: &str) {
        println!("\n{}", "=".repeat(80));
        println!("{}", title);
        println!("{}", "=".repeat(80));
        println!("{}", input_heading);
        println!("{}", "-".repeat(80));
        println!("{}", raw_input);
    }

    /// Prints the transformed wire output under its heading.
    fn print_output(output: &str) {
        println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
        println!("{}", "-".repeat(80));
        println!("{}", output);
    }

    /// Prints the heading that precedes each test's validation summary.
    fn print_summary_heading() {
        println!("\nVALIDATION SUMMARY:");
        println!("{}", "-".repeat(80));
    }

    /// Parses `raw_input` as an OpenAI ChatCompletions SSE stream, transforms every
    /// event for an Anthropic Messages client, feeds the results through an
    /// `AnthropicMessagesStreamBuffer`, and returns the bytes the buffer produces.
    fn transform_openai_stream(raw_input: &str) -> Vec<u8> {
        // Client wants Anthropic, upstream is OpenAI.
        let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
        let upstream_api =
            SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

        let events = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
        let mut stream_buffer = AnthropicMessagesStreamBuffer::new();
        for raw_event in events {
            let transformed =
                SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
            stream_buffer.add_transformed_event(transformed);
        }
        stream_buffer.into_bytes()
    }

    #[test]
    fn test_openai_to_anthropic_complete_transformation() {
        // OpenAI ChatCompletions input that will be transformed to Anthropic Messages API
        let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: [DONE]"#;

        print_input(
            "TEST 1: OpenAI → Anthropic Messages API Complete Transformation",
            "\nRAW INPUT (OpenAI ChatCompletions):",
            raw_input,
        );

        let output_bytes = transform_openai_stream(raw_input);
        let output = String::from_utf8_lossy(&output_bytes);
        print_output(&output);

        // Lifecycle events at the start of the stream
        assert!(!output_bytes.is_empty(), "Should have output");
        assert!(output.contains("event: message_start"), "Should have message_start");
        assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");

        let deltas = output.matches("event: content_block_delta").count();
        assert_eq!(deltas, 2, "Should have exactly 2 content_block_delta events");

        // Both pieces of content must survive the transformation
        assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'");
        assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'");

        // Completion events
        assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
        assert!(output.contains("event: message_delta"), "Should have message_delta");
        assert!(output.contains("event: message_stop"), "Should have message_stop");

        print_summary_heading();
        println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
        println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop");
        println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", deltas);
        println!("✓ Complete stream with message_stop");
        println!("✓ Proper Anthropic protocol sequencing\n");
    }

    #[test]
    fn test_openai_to_anthropic_partial_transformation() {
        // Partial OpenAI ChatCompletions stream - no [DONE]
        let raw_input = r#"data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"The weather"},"finish_reason":null}]}

data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" in San Francisco"},"finish_reason":null}]}

data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" is"},"finish_reason":null}]}"#;

        print_input(
            "TEST 2: OpenAI → Anthropic Partial Transformation (NO [DONE])",
            "\nRAW INPUT (OpenAI ChatCompletions - NO [DONE]):",
            raw_input,
        );

        let output_bytes = transform_openai_stream(raw_input);
        let output = String::from_utf8_lossy(&output_bytes);
        print_output(&output);

        // Lifecycle events at the start of the stream
        assert!(!output_bytes.is_empty(), "Should have output");
        assert!(output.contains("event: message_start"), "Should have message_start");
        assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");

        let deltas = output.matches("event: content_block_delta").count();
        assert_eq!(deltas, 3, "Should have exactly 3 content_block_delta events");

        // All three pieces of content must be present
        assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta");
        assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta");
        assert!(output.contains("\"text\":\" is\""), "Should have third content delta");

        // A partial stream (no finish_reason, no [DONE]) may still continue, so the
        // buffer must not inject content_block_stop: lifecycle events are only
        // injected on explicit upstream signals (finish_reason, [DONE], etc.).
        assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream");

        // No completion events either
        assert!(!output.contains("event: message_delta"), "Should NOT have message_delta");
        assert!(!output.contains("event: message_stop"), "Should NOT have message_stop");

        print_summary_heading();
        println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)");
        println!("✓ Injected: message_start, content_block_start at beginning");
        println!("✓ Incremental deltas: {} events (ALL content preserved!)", deltas);
        println!("✓ NO completion events (partial stream, no [DONE])");
        println!("✓ Buffer maintains Anthropic protocol for active streams\n");
    }

    #[test]
    fn test_openai_tool_calling_to_anthropic_transformation() {
        // OpenAI ChatCompletions tool calling stream
        let raw_input = r#"data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_2Uzw0AEZQeOex2CP2TKjcLKc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"uSpCcO"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"24WSqt08jtf"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"6CleV8twTxkKYg"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"San"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Francisco"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"1XLz89l3v"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":","}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"sh"}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" CA"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-Cgx6pZPBgfLcMqfT0ILIH2mID2zWQ","object":"chat.completion.chunk","created":1764353027,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"obfuscation":"I"}

data: [DONE]"#;

        print_input(
            "TEST 3: OpenAI Tool Calling → Anthropic Messages API Transformation",
            "\nRAW INPUT (OpenAI ChatCompletions with Tool Calls):",
            raw_input,
        );

        let output_bytes = transform_openai_stream(raw_input);
        let output = String::from_utf8_lossy(&output_bytes);
        print_output(&output);

        assert!(!output_bytes.is_empty(), "Should have output");

        // Lifecycle events (injected by buffer)
        assert!(output.contains("event: message_start"), "Should have message_start (injected)");
        assert!(output.contains("event: content_block_start"), "Should have content_block_start");
        assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
        assert!(output.contains("event: message_delta"), "Should have message_delta");
        assert!(output.contains("event: message_stop"), "Should have message_stop");

        // tool_use content block
        assert!(output.contains("\"type\":\"tool_use\""), "Should have tool_use type");
        assert!(output.contains("\"name\":\"get_weather\""), "Should have correct function name");
        assert!(output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), "Should have correct tool call ID");

        // One input_json_delta per upstream argument chunk
        let deltas = output.matches("event: content_block_delta").count();
        assert!(deltas >= 8, "Should have at least 8 input_json_delta events");

        // Argument deltas are present
        assert!(output.contains("\"type\":\"input_json_delta\""), "Should have input_json_delta type");
        assert!(output.contains("\"partial_json\":"), "Should have partial_json field");

        // The streamed arguments contain the location
        assert!(output.contains("San"), "Arguments should contain 'San'");
        assert!(output.contains("Francisco"), "Arguments should contain 'Francisco'");
        assert!(output.contains("CA"), "Arguments should contain 'CA'");

        // Stop reason is tool_use
        assert!(output.contains("\"stop_reason\":\"tool_use\""), "Should have stop_reason as tool_use");

        print_summary_heading();
        println!("✓ Complete tool calling transformation: OpenAI → Anthropic Messages API");
        println!("✓ Injected lifecycle: message_start, content_block_stop");
        println!("✓ Tool metadata: name='get_weather', id='call_2Uzw0AEZQeOex2CP2TKjcLKc'");
        println!("✓ Argument deltas: {} events", deltas);
        println!("✓ Complete JSON arguments: '{{\"location\":\"San Francisco, CA\"}}'");
        println!("✓ Stop reason: tool_use");
        println!("✓ Proper Anthropic tool_use protocol\n");
    }
}
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
|
||||
/// OpenAI Chat Completions SSE Stream Buffer for when client and upstream APIs match.
|
||||
pub struct OpenAIChatCompletionsStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl OpenAIChatCompletionsStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// For OpenAI Chat Completions, events are already properly transformed
|
||||
// Just accumulate them for later wire transmission
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for OpenAI Chat Completions
|
||||
// The [DONE] marker is already handled by the transformation layer
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
7
crates/hermesllm/src/apis/streaming_shapes/mod.rs
Normal file
7
crates/hermesllm/src/apis/streaming_shapes/mod.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
pub mod sse;
|
||||
pub mod sse_chunk_processor;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic_streaming_buffer;
|
||||
pub mod chat_completions_streaming_buffer;
|
||||
pub mod passthrough_streaming_buffer;
|
||||
pub mod responses_api_streaming_buffer;
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
|
||||
/// Passthrough SSE Stream Buffer for when client and upstream APIs match.
|
||||
pub struct PassthroughStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl PassthroughStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for PassthroughStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Skip events with empty transformed lines (e.g., suppressed event-only lines)
|
||||
if event.sse_transformed_lines.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Just accumulate events as-is
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for passthrough - just convert accumulated events to bytes
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::PassthroughStreamBuffer;
    use crate::apis::streaming_shapes::sse::{SseStreamBufferTrait, SseStreamIter};

    /// Prints an 80-column rule made of `fill`.
    fn rule(fill: &str) {
        println!("{}", fill.repeat(80));
    }

    #[test]
    fn test_chat_completions_passthrough_buffer() {
        // ChatCompletions tool-calling stream, terminated by [DONE]
        let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}

data: [DONE]"#;

        println!("\n{}", "=".repeat(80));
        println!("TEST 1: ChatCompletions Passthrough Buffer");
        rule("=");
        println!("\nRAW INPUT (ChatCompletions):");
        rule("-");
        println!("{}", raw_input);

        // Feed every parsed event through the passthrough buffer
        let parsed_events = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
        let mut passthrough = PassthroughStreamBuffer::new();
        for parsed in parsed_events {
            passthrough.add_transformed_event(parsed);
        }

        let output_bytes = passthrough.into_bytes();
        let output = String::from_utf8_lossy(&output_bytes);

        println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):");
        rule("-");
        println!("{}", output);

        // The buffer must hand back exactly what went in
        assert!(!output_bytes.is_empty());
        assert!(output.contains("chatcmpl-123"));
        assert!(output.contains("[DONE]"));
        assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input");

        println!("\nVALIDATION SUMMARY:");
        rule("-");
        println!("✓ Passthrough buffer: input = output (no transformation)");
        println!("✓ All events preserved including [DONE]");
        println!("✓ Function calling events preserved\n");
    }
}
|
||||
|
|
@ -0,0 +1,649 @@
|
|||
use std::collections::HashMap;
|
||||
use log::debug;
|
||||
use crate::apis::openai_responses::{
|
||||
ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus,
|
||||
ResponseStatus, TextConfig, TextFormat, Reasoning,
|
||||
};
|
||||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
|
||||
/// Helper to convert ResponseAPIStreamEvent to SseEvent
|
||||
fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
|
||||
let event_type = match &event {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
|
||||
ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
|
||||
ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed",
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemAdded { .. } => "response.output_item.added",
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done",
|
||||
unknown => {
|
||||
debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown);
|
||||
"unknown"
|
||||
}
|
||||
};
|
||||
|
||||
let json_data = match serde_json::to_string(&event) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
debug!("Error serializing ResponsesAPIStreamEvent to JSON: {}", e);
|
||||
String::new()
|
||||
}
|
||||
};
|
||||
let wire_format: String = event.into();
|
||||
|
||||
SseEvent {
|
||||
data: Some(json_data),
|
||||
event: Some(event_type.to_string()),
|
||||
raw_line: wire_format.clone(),
|
||||
sse_transformed_lines: wire_format,
|
||||
provider_stream_response: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// SSE Stream Buffer for ResponsesAPIStreamEvent with full lifecycle management.
|
||||
///
|
||||
/// This buffer manages the wire format for v1/responses streaming, handling
|
||||
/// delta events and emitting complete lifecycle events.
|
||||
///
|
||||
pub struct ResponsesAPIStreamBuffer {
|
||||
/// Sequence number for events
|
||||
sequence_number: i32,
|
||||
|
||||
/// Track item IDs by output index
|
||||
item_ids: HashMap<i32, String>,
|
||||
|
||||
/// Response metadata
|
||||
response_id: Option<String>,
|
||||
model: Option<String>,
|
||||
created_at: Option<i64>,
|
||||
|
||||
/// Full response metadata from upstream (tools, temperature, etc.)
|
||||
/// This is extracted from the first upstream event and used to build
|
||||
/// complete response.created and response.in_progress events
|
||||
upstream_response_metadata: Option<ResponsesAPIResponse>,
|
||||
|
||||
/// Lifecycle state flags
|
||||
created_emitted: bool,
|
||||
in_progress_emitted: bool,
|
||||
|
||||
/// Track which output items we've added
|
||||
output_items_added: HashMap<i32, String>, // output_index -> item_id
|
||||
|
||||
/// Accumulated content by item_id
|
||||
text_content: HashMap<String, String>,
|
||||
function_arguments: HashMap<String, String>,
|
||||
|
||||
/// Tool call metadata by output_index
|
||||
tool_call_metadata: HashMap<i32, (String, String)>, // output_index -> (call_id, name)
|
||||
|
||||
/// Final completed response (for logging/tracing/persistence)
|
||||
completed_response: Option<ResponsesAPIResponse>,
|
||||
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl ResponsesAPIStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sequence_number: 0,
|
||||
item_ids: HashMap::new(),
|
||||
response_id: None,
|
||||
model: None,
|
||||
created_at: None,
|
||||
upstream_response_metadata: None,
|
||||
created_emitted: false,
|
||||
in_progress_emitted: false,
|
||||
output_items_added: HashMap::new(),
|
||||
text_content: HashMap::new(),
|
||||
function_arguments: HashMap::new(),
|
||||
tool_call_metadata: HashMap::new(),
|
||||
completed_response: None,
|
||||
buffered_events: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_sequence_number(&mut self) -> i32 {
|
||||
let seq = self.sequence_number;
|
||||
self.sequence_number += 1;
|
||||
seq
|
||||
}
|
||||
|
||||
fn generate_item_id(prefix: &str) -> String {
|
||||
format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", ""))
|
||||
}
|
||||
|
||||
fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String {
|
||||
if let Some(id) = self.item_ids.get(&output_index) {
|
||||
return id.clone();
|
||||
}
|
||||
let id = ResponsesAPIStreamBuffer::generate_item_id(prefix);
|
||||
self.item_ids.insert(output_index, id.clone());
|
||||
id
|
||||
}
|
||||
|
||||
/// Create response.created event
|
||||
fn create_response_created_event(&mut self) -> SseEvent {
|
||||
let response = self.build_response(ResponseStatus::InProgress);
|
||||
let event = ResponsesAPIStreamEvent::ResponseCreated {
|
||||
response,
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Create response.in_progress event
|
||||
fn create_response_in_progress_event(&mut self) -> SseEvent {
|
||||
let response = self.build_response(ResponseStatus::InProgress);
|
||||
let event = ResponsesAPIStreamEvent::ResponseInProgress {
|
||||
response,
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Create output_item.added event for text
|
||||
fn create_output_item_added_event(&mut self, output_index: i32, item_id: &str) -> SseEvent {
|
||||
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
|
||||
output_index,
|
||||
item: OutputItem::Message {
|
||||
id: item_id.to_string(),
|
||||
status: OutputItemStatus::InProgress,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
},
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Create output_item.added event for tool call
|
||||
fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent {
|
||||
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
|
||||
output_index,
|
||||
item: OutputItem::FunctionCall {
|
||||
id: item_id.to_string(),
|
||||
status: OutputItemStatus::InProgress,
|
||||
call_id: call_id.to_string(),
|
||||
name: Some(name.to_string()),
|
||||
arguments: Some(String::new()),
|
||||
},
|
||||
sequence_number: self.next_sequence_number(),
|
||||
};
|
||||
event_to_sse(event)
|
||||
}
|
||||
|
||||
/// Build the base response object with current state
|
||||
fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse {
|
||||
// If we have upstream metadata, use it as a base and update status/output
|
||||
if let Some(upstream) = &self.upstream_response_metadata {
|
||||
let mut response = upstream.clone();
|
||||
response.status = status;
|
||||
// Don't update output here - will be set in finalize()
|
||||
return response;
|
||||
}
|
||||
|
||||
// Fallback: build a minimal response from local state
|
||||
ResponsesAPIResponse {
|
||||
id: self.response_id.clone().unwrap_or_default(),
|
||||
object: "response".to_string(),
|
||||
created_at: self.created_at.unwrap_or(0),
|
||||
status,
|
||||
error: None,
|
||||
incomplete_details: None,
|
||||
instructions: None,
|
||||
model: self.model.clone().unwrap_or_else(|| "unknown".to_string()),
|
||||
output: vec![],
|
||||
usage: None,
|
||||
parallel_tool_calls: true,
|
||||
conversation: None,
|
||||
previous_response_id: None,
|
||||
tools: vec![],
|
||||
tool_choice: "auto".to_string(),
|
||||
temperature: 1.0,
|
||||
top_p: 1.0,
|
||||
metadata: HashMap::new(),
|
||||
truncation: Some("disabled".to_string()),
|
||||
max_output_tokens: None,
|
||||
reasoning: Some(Reasoning {
|
||||
effort: None,
|
||||
summary: None,
|
||||
}),
|
||||
store: Some(true),
|
||||
text: Some(TextConfig {
|
||||
format: TextFormat::Text,
|
||||
}),
|
||||
audio: None,
|
||||
modalities: None,
|
||||
service_tier: Some("auto".to_string()),
|
||||
background: Some(false),
|
||||
top_logprobs: Some(0),
|
||||
max_tool_calls: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the completed response after finalization (for logging/tracing/persistence)
|
||||
pub fn get_completed_response(&self) -> Option<&ResponsesAPIResponse> {
|
||||
self.completed_response.as_ref()
|
||||
}
|
||||
|
||||
/// Finalize the response by emitting all *.done events and response.completed.
|
||||
/// Call this when the stream is complete (after seeing [DONE] or end_of_stream).
|
||||
pub fn finalize(&mut self) {
|
||||
let mut events = Vec::new();
|
||||
|
||||
// Emit done events for all accumulated content
|
||||
|
||||
// Text content done events
|
||||
let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect();
|
||||
for (item_id, content) in text_items {
|
||||
let output_index = self.output_items_added.iter()
|
||||
.find(|(_, id)| **id == item_id)
|
||||
.map(|(idx, _)| *idx)
|
||||
.unwrap_or(0);
|
||||
|
||||
let seq1 = self.next_sequence_number();
|
||||
let text_done_event = ResponsesAPIStreamEvent::ResponseOutputTextDone {
|
||||
item_id: item_id.clone(),
|
||||
output_index,
|
||||
content_index: 0,
|
||||
text: content.clone(),
|
||||
logprobs: vec![],
|
||||
sequence_number: seq1,
|
||||
};
|
||||
events.push(event_to_sse(text_done_event));
|
||||
|
||||
let seq2 = self.next_sequence_number();
|
||||
let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
|
||||
output_index,
|
||||
item: OutputItem::Message {
|
||||
id: item_id.clone(),
|
||||
status: OutputItemStatus::Completed,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
},
|
||||
sequence_number: seq2,
|
||||
};
|
||||
events.push(event_to_sse(item_done_event));
|
||||
}
|
||||
|
||||
// Function call done events
|
||||
let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect();
|
||||
for (item_id, arguments) in func_items {
|
||||
let output_index = self.output_items_added.iter()
|
||||
.find(|(_, id)| **id == item_id)
|
||||
.map(|(idx, _)| *idx)
|
||||
.unwrap_or(0);
|
||||
|
||||
let seq1 = self.next_sequence_number();
|
||||
let args_done_event = ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone {
|
||||
output_index,
|
||||
item_id: item_id.clone(),
|
||||
arguments: arguments.clone(),
|
||||
sequence_number: seq1,
|
||||
};
|
||||
events.push(event_to_sse(args_done_event));
|
||||
|
||||
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
|
||||
let seq2 = self.next_sequence_number();
|
||||
let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
|
||||
output_index,
|
||||
item: OutputItem::FunctionCall {
|
||||
id: item_id.clone(),
|
||||
status: OutputItemStatus::Completed,
|
||||
call_id,
|
||||
name: Some(name),
|
||||
arguments: Some(arguments.clone()),
|
||||
},
|
||||
sequence_number: seq2,
|
||||
};
|
||||
events.push(event_to_sse(item_done_event));
|
||||
}
|
||||
|
||||
// Build final response
|
||||
let mut output_items = Vec::new();
|
||||
|
||||
// Build complete output array by iterating through all output indices in order
|
||||
let max_output_index = self.output_items_added.keys().max().copied().unwrap_or(-1);
|
||||
|
||||
for output_index in 0..=max_output_index {
|
||||
if let Some(item_id) = self.output_items_added.get(&output_index) {
|
||||
// Check if this is a function call
|
||||
if let Some(arguments) = self.function_arguments.get(item_id) {
|
||||
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
|
||||
output_items.push(OutputItem::FunctionCall {
|
||||
id: item_id.clone(),
|
||||
status: OutputItemStatus::Completed,
|
||||
call_id,
|
||||
name: Some(name),
|
||||
arguments: Some(arguments.clone()),
|
||||
});
|
||||
}
|
||||
// Check if this is a text message
|
||||
else if let Some(text) = self.text_content.get(item_id) {
|
||||
use crate::apis::openai_responses::OutputContent;
|
||||
output_items.push(OutputItem::Message {
|
||||
id: item_id.clone(),
|
||||
status: OutputItemStatus::Completed,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![OutputContent::OutputText {
|
||||
text: text.clone(),
|
||||
annotations: vec![],
|
||||
logprobs: None,
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut final_response = self.build_response(ResponseStatus::Completed);
|
||||
final_response.output = output_items;
|
||||
|
||||
// Store completed response
|
||||
self.completed_response = Some(final_response.clone());
|
||||
|
||||
// Emit response.completed
|
||||
let seq_final = self.next_sequence_number();
|
||||
let completed_event = ResponsesAPIStreamEvent::ResponseCompleted {
|
||||
response: final_response,
|
||||
sequence_number: seq_final,
|
||||
};
|
||||
events.push(event_to_sse(completed_event));
|
||||
|
||||
// Add all finalization events to the buffer
|
||||
self.buffered_events.extend(events);
|
||||
}
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
// Skip ping messages
|
||||
if event.should_skip() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle [DONE] marker - trigger finalization
|
||||
if event.is_done() {
|
||||
self.finalize();
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract the ResponseAPIStreamEvent from the SseEvent's provider_stream_response
|
||||
let provider_response = match event.provider_stream_response.as_ref() {
|
||||
Some(response) => response,
|
||||
None => {
|
||||
eprintln!("Warning: Event missing provider_stream_response");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Extract ResponseAPIStreamEvent from the enum
|
||||
let stream_event = match provider_response {
|
||||
crate::providers::streaming_response::ProviderStreamResponseType::ResponseAPIStreamEvent(evt) => evt,
|
||||
_ => {
|
||||
eprintln!("Warning: Expected ResponseAPIStreamEvent in provider_stream_response");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut events = Vec::new();
|
||||
|
||||
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present
|
||||
match stream_event {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { response, .. } |
|
||||
ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
|
||||
if self.upstream_response_metadata.is_none() {
|
||||
// Store the full upstream response as our metadata template
|
||||
self.upstream_response_metadata = Some(response.clone());
|
||||
// Also extract basic fields
|
||||
self.response_id = Some(response.id.clone());
|
||||
self.model = Some(response.model.clone());
|
||||
self.created_at = Some(response.created_at);
|
||||
}
|
||||
// Don't emit these - we'll generate our own lifecycle events
|
||||
return;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Emit lifecycle events if not yet emitted
|
||||
if !self.created_emitted {
|
||||
// Initialize metadata from first event if needed
|
||||
if self.response_id.is_none() {
|
||||
self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")));
|
||||
self.created_at = Some(std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64);
|
||||
self.model = Some("unknown".to_string()); // Will be set by caller if available
|
||||
}
|
||||
|
||||
events.push(self.create_response_created_event());
|
||||
self.created_emitted = true;
|
||||
}
|
||||
|
||||
if !self.in_progress_emitted {
|
||||
events.push(self.create_response_in_progress_event());
|
||||
self.in_progress_emitted = true;
|
||||
}
|
||||
|
||||
// Process the delta event
|
||||
match stream_event {
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "msg");
|
||||
|
||||
// Emit output_item.added if this is the first time we see this output index
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
events.push(self.create_output_item_added_event(*output_index, &item_id));
|
||||
}
|
||||
|
||||
// Accumulate text content
|
||||
self.text_content.entry(item_id.clone())
|
||||
.and_modify(|content| content.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit text delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
events.push(event_to_sse(delta_event));
|
||||
}
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "fc");
|
||||
|
||||
// Store metadata if provided (from initial tool call event)
|
||||
if let (Some(cid), Some(n)) = (call_id, name) {
|
||||
self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone()));
|
||||
}
|
||||
|
||||
// Emit output_item.added if this is the first time we see this tool call
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
|
||||
// For tool calls, we need call_id and name from metadata
|
||||
// These should now be populated from the event itself
|
||||
let (call_id, name) = self.tool_call_metadata.get(output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
|
||||
events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name));
|
||||
}
|
||||
|
||||
// Accumulate function arguments
|
||||
self.function_arguments.entry(item_id.clone())
|
||||
.and_modify(|args| args.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit function call arguments delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
events.push(event_to_sse(delta_event));
|
||||
}
|
||||
_ => {
|
||||
// For other event types, just pass through with sequence number
|
||||
let other_event = stream_event.clone();
|
||||
// TODO: Add sequence number to other event types if needed
|
||||
events.push(event_to_sse(other_event));
|
||||
}
|
||||
}
|
||||
|
||||
// Store all generated events in the buffer
|
||||
self.buffered_events.extend(events);
|
||||
}
|
||||
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
// For Responses API, we need special handling:
|
||||
// - Most events are already in buffered_events from add_transformed_event
|
||||
// - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream
|
||||
// - Just flush the accumulated events and clear the buffer
|
||||
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
let event_bytes: Vec<u8> = event.into();
|
||||
buffer.extend_from_slice(&event_bytes);
|
||||
}
|
||||
buffer
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::streaming_shapes::sse::SseStreamIter;
|
||||
|
||||
#[test]
fn test_chat_completions_to_responses_api_transformation() {
    // A complete ChatCompletions stream (two text deltas, a stop chunk, then [DONE])
    // that gets translated into Responses API events.
    let raw_input = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}

data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: [DONE]"#;

    let heavy_rule = "=".repeat(80);
    let light_rule = "-".repeat(80);

    println!("\n{}", heavy_rule);
    println!("TEST 2: ChatCompletions → ResponsesAPI Transformation (with [DONE])");
    println!("{}", heavy_rule);
    println!("\nRAW INPUT (ChatCompletions):");
    println!("{}", light_rule);
    println!("{}", raw_input);

    // Client speaks the Responses API, upstream speaks ChatCompletions
    let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

    // Parse, transform and feed each event into the buffer in stream order
    let parsed_events = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
    let mut buffer = ResponsesAPIStreamBuffer::new();
    for upstream_event in parsed_events {
        let translated = SseEvent::try_from((upstream_event, &client_api, &upstream_api)).unwrap();
        buffer.add_transformed_event(translated);
    }

    let output_bytes = buffer.into_bytes();
    let output = String::from_utf8_lossy(&output_bytes);

    println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
    println!("{}", light_rule);
    println!("{}", output);

    // The flushed stream must contain the full lifecycle, in any position
    assert!(!output_bytes.is_empty(), "Should have output");
    for &(marker, failure) in &[
        ("response.created", "Should have response.created"),
        ("response.in_progress", "Should have response.in_progress"),
        ("response.output_item.added", "Should have output_item.added"),
        ("response.output_text.delta", "Should have text deltas"),
        ("response.output_text.done", "Should have text.done"),
        ("response.output_item.done", "Should have output_item.done"),
        ("response.completed", "Should have response.completed"),
    ] {
        assert!(output.contains(marker), "{}", failure);
    }

    println!("\nVALIDATION SUMMARY:");
    println!("{}", light_rule);
    println!("✓ Lifecycle events: response.created, response.in_progress, response.completed");
    println!("✓ Output item lifecycle: output_item.added, output_item.done");
    println!("✓ Text streaming: output_text.delta (2 deltas), output_text.done");
    println!("✓ Complete transformation with finalization ([DONE] processed)\n");
}
|
||||
|
||||
#[test]
fn test_partial_streaming_incremental_output() {
    // A tool-call stream that stops mid-arguments: there is no finish chunk and no
    // [DONE], so the buffer must not emit any completion events.
    let raw_input = r#"data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_mD5ggLKk3SMKGPFqFdcpKg6q","type":"function","function":{"name":"get_weather","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"obfuscation":"PCFrpy"}

data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":""}

data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"location"}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"TC58A3QEIx8"}

data: {"id":"chatcmpl-CfpqklihniLRuuQfP7inMb2ghtGmT","object":"chat.completion.chunk","created":1764086794,"model":"gpt-4o-2024-08-06","service_tier":"default","system_fingerprint":"fp_7eeb46f068","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"obfuscation":"PK4oFzlVlGTUP5"}"#;

    let heavy_rule = "=".repeat(80);
    let light_rule = "-".repeat(80);

    println!("\n{}", heavy_rule);
    println!("TEST 3: Partial Streaming - Function Calling (NO [DONE])");
    println!("{}", heavy_rule);
    println!("\nRAW INPUT (ChatCompletions - NO [DONE]):");
    println!("{}", light_rule);
    println!("{}", raw_input);

    // Client speaks the Responses API, upstream speaks ChatCompletions
    let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(OpenAIApi::Responses);
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

    // Feed every upstream event through the transformation into the buffer
    let parsed_events = SseStreamIter::try_from(raw_input.as_bytes()).unwrap();
    let mut buffer = ResponsesAPIStreamBuffer::new();
    for upstream_event in parsed_events {
        let translated = SseEvent::try_from((upstream_event, &client_api, &upstream_api)).unwrap();
        buffer.add_transformed_event(translated);
    }

    let output_bytes = buffer.into_bytes();
    let output = String::from_utf8_lossy(&output_bytes);

    println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
    println!("{}", light_rule);
    println!("{}", output);

    // Everything that must be present in a partial tool-call stream
    for &(marker, failure) in &[
        ("response.created", "Should have response.created"),
        ("response.in_progress", "Should have response.in_progress"),
        ("response.output_item.added", "Should have output_item.added"),
        ("\"type\":\"function_call\"", "Should be function_call type"),
        ("\"name\":\"get_weather\"", "Should have function name"),
        (
            "\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\"",
            "Should have correct call_id",
        ),
    ] {
        assert!(output.contains(marker), "{}", failure);
    }

    let delta_count = output
        .matches("event: response.function_call_arguments.delta")
        .count();
    assert_eq!(delta_count, 4, "Should have 4 delta events");

    // Nothing that only finalization produces may appear
    for &(marker, failure) in &[
        (
            "response.function_call_arguments.done",
            "Should NOT have arguments.done",
        ),
        ("response.output_item.done", "Should NOT have output_item.done"),
        ("response.completed", "Should NOT have response.completed"),
    ] {
        assert!(!output.contains(marker), "{}", failure);
    }

    println!("\nVALIDATION SUMMARY:");
    println!("{}", light_rule);
    println!("✓ Lifecycle events: response.created, response.in_progress");
    println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'");
    println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)");
    println!("✓ NO completion events (partial stream, no [DONE])");
    println!("✓ Arguments accumulated: '{{\"location\":\"'\n");
}
|
||||
}
|
||||
|
|
@ -1,10 +1,73 @@
|
|||
use crate::providers::response::ProviderStreamResponse;
|
||||
use crate::providers::response::ProviderStreamResponseType;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
|
||||
use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer;
|
||||
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
|
||||
use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// Trait defining the interface for SSE stream buffers.
|
||||
///
|
||||
/// This trait is implemented by both the enum `SseStreamBuffer` (for zero-cost dispatch)
|
||||
/// and individual buffer implementations (for direct use).
|
||||
///
|
||||
pub trait SseStreamBufferTrait: Send + Sync {
|
||||
/// Add a transformed SSE event to the buffer.
|
||||
///
|
||||
/// The buffer may inject additional events as needed based on internal state.
|
||||
/// For example, Anthropic buffers inject ContentBlockStart before the first ContentBlockDelta.
|
||||
///
|
||||
/// All events (original + injected) are accumulated internally for the next `into_bytes()` call.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `event` - A transformed SSE event to accumulate
|
||||
fn add_transformed_event(&mut self, event: SseEvent);
|
||||
|
||||
/// Get bytes for all accumulated events since the last call.
|
||||
///
|
||||
/// This method:
|
||||
/// - Converts all buffered events to wire format bytes
|
||||
/// - Clears the internal event buffer
|
||||
/// - Preserves state for subsequent `add_transformed_event()` calls
|
||||
///
|
||||
/// Call this after processing each chunk of upstream events to get bytes for immediate transmission.
|
||||
///
|
||||
/// # Returns
|
||||
/// Bytes ready for wire transmission (may be empty if no events were accumulated)
|
||||
fn into_bytes(&mut self) -> Vec<u8>;
|
||||
}
|
||||
|
||||
/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction
|
||||
pub enum SseStreamBuffer {
|
||||
Passthrough(PassthroughStreamBuffer),
|
||||
OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer),
|
||||
AnthropicMessages(AnthropicMessagesStreamBuffer),
|
||||
OpenAIResponses(ResponsesAPIStreamBuffer),
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for SseStreamBuffer {
|
||||
fn add_transformed_event(&mut self, event: SseEvent) {
|
||||
match self {
|
||||
Self::Passthrough(buffer) => buffer.add_transformed_event(event),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.add_transformed_event(event),
|
||||
Self::AnthropicMessages(buffer) => buffer.add_transformed_event(event),
|
||||
Self::OpenAIResponses(buffer) => buffer.add_transformed_event(event),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
match self {
|
||||
Self::Passthrough(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(),
|
||||
Self::AnthropicMessages(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIResponses(buffer) => buffer.into_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE EVENT CONTAINER
|
||||
// ============================================================================
|
||||
|
|
@ -22,16 +85,31 @@ pub struct SseEvent {
|
|||
pub raw_line: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub sse_transform_buffer: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
pub sse_transformed_lines: String, // The complete line as received including "data: " prefix and "\n\n"
|
||||
|
||||
#[serde(skip_serializing, skip_deserializing)]
|
||||
pub provider_stream_response: Option<ProviderStreamResponseType>, // Parsed provider stream response object
|
||||
}
|
||||
|
||||
impl SseEvent {
|
||||
/// Create an SseEvent from a ProviderStreamResponseType
|
||||
/// This is useful for binary frame formats (like Bedrock) that need to be converted to SSE
|
||||
pub fn from_provider_response(response: ProviderStreamResponseType) -> Self {
|
||||
// Convert the provider response to SSE format string
|
||||
let sse_string: String = response.clone().into();
|
||||
|
||||
SseEvent {
|
||||
data: None, // Data is embedded in sse_transformed_lines
|
||||
event: None, // Event type is embedded in sse_transformed_lines
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: Some(response),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this event represents the end of the stream
|
||||
pub fn is_done(&self) -> bool {
|
||||
self.data == Some("[DONE]".into())
|
||||
self.data == Some("[DONE]".into()) || self.event == Some("message_stop".into())
|
||||
}
|
||||
|
||||
/// Check if this event should be skipped during processing
|
||||
|
|
@ -61,23 +139,35 @@ impl FromStr for SseEvent {
|
|||
type Err = SseParseError;
|
||||
|
||||
fn from_str(line: &str) -> Result<Self, Self::Err> {
|
||||
if line.starts_with("data: ") {
|
||||
let data: String = line[6..].to_string(); // Remove "data: " prefix
|
||||
if data.is_empty() {
|
||||
// Trim leading/trailing whitespace for parsing
|
||||
let trimmed_line = line.trim();
|
||||
|
||||
// Skip empty or whitespace-only lines (SSE event separators)
|
||||
if trimmed_line.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty line (SSE event separator)".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if trimmed_line.starts_with("data: ") {
|
||||
let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix
|
||||
// Allow empty data content after "data: " prefix
|
||||
// This handles cases like "data: " followed by newline
|
||||
if data.trim().is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty data field is not a valid SSE event".to_string(),
|
||||
message: "Empty data field after 'data: ' prefix".to_string(),
|
||||
});
|
||||
}
|
||||
Ok(SseEvent {
|
||||
data: Some(data),
|
||||
event: None,
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
// Preserve original line format for passthrough, use trimmed for transformations
|
||||
sse_transformed_lines: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else if line.starts_with("event: ") {
|
||||
//used by Anthropic
|
||||
let event_type = line[7..].to_string();
|
||||
} else if trimmed_line.starts_with("event: ") {
|
||||
let event_type = trimmed_line[7..].to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty event field is not a valid SSE event".to_string(),
|
||||
|
|
@ -87,12 +177,13 @@ impl FromStr for SseEvent {
|
|||
data: None,
|
||||
event: Some(event_type),
|
||||
raw_line: line.to_string(),
|
||||
sse_transform_buffer: line.to_string(),
|
||||
// Preserve original line format for passthrough, use trimmed for transformations
|
||||
sse_transformed_lines: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else {
|
||||
Err(SseParseError {
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", line),
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -100,14 +191,22 @@ impl FromStr for SseEvent {
|
|||
|
||||
impl fmt::Display for SseEvent {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.sse_transform_buffer)
|
||||
write!(f, "{}", self.sse_transformed_lines)
|
||||
}
|
||||
}
|
||||
|
||||
// Into implementation to convert SseEvent to bytes for response buffer
|
||||
impl Into<Vec<u8>> for SseEvent {
|
||||
fn into(self) -> Vec<u8> {
|
||||
format!("{}\n\n", self.sse_transform_buffer).into_bytes()
|
||||
// For generated events (like ResponsesAPI), sse_transformed_lines already includes trailing \n\n
|
||||
// For parsed events (like passthrough), we need to add the \n\n separator
|
||||
if self.sse_transformed_lines.ends_with("\n\n") {
|
||||
// Already properly formatted with trailing newlines
|
||||
self.sse_transformed_lines.into_bytes()
|
||||
} else {
|
||||
// Add SSE event separator
|
||||
format!("{}\n\n", self.sse_transformed_lines).into_bytes()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,241 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamIter};
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
||||
/// Stateful processor for handling SSE chunks that may contain incomplete events.
|
||||
///
|
||||
/// This processor buffers incomplete SSE event bytes when transformation fails
|
||||
/// (e.g., due to incomplete JSON) and prepends them to the next chunk for retry.
|
||||
pub struct SseChunkProcessor {
|
||||
/// Buffered bytes from incomplete SSE events across chunks
|
||||
incomplete_event_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl SseChunkProcessor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
incomplete_event_buffer: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a chunk of SSE data, handling incomplete events across chunk boundaries.
|
||||
///
|
||||
/// Returns successfully transformed events. Incomplete events are buffered internally
|
||||
/// and will be retried when more data arrives in the next chunk.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `chunk` - Raw bytes from upstream SSE stream
|
||||
/// * `client_api` - The API format the client expects
|
||||
/// * `upstream_api` - The API format from the upstream provider
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(Vec<SseEvent>)` - Successfully transformed events ready for client
|
||||
/// * `Err(String)` - Fatal error that cannot be recovered by buffering
|
||||
pub fn process_chunk(
|
||||
&mut self,
|
||||
chunk: &[u8],
|
||||
client_api: &SupportedAPIsFromClient,
|
||||
upstream_api: &SupportedUpstreamAPIs,
|
||||
) -> Result<Vec<SseEvent>, String> {
|
||||
// Combine buffered incomplete event with new chunk
|
||||
let mut combined_data = std::mem::take(&mut self.incomplete_event_buffer);
|
||||
combined_data.extend_from_slice(chunk);
|
||||
|
||||
// Parse using SseStreamIter
|
||||
let sse_iter = match SseStreamIter::try_from(combined_data.as_slice()) {
|
||||
Ok(iter) => iter,
|
||||
Err(e) => return Err(format!("Failed to create SSE iterator: {}", e)),
|
||||
};
|
||||
|
||||
let mut transformed_events = Vec::new();
|
||||
|
||||
// Process each parsed SSE event
|
||||
for sse_event in sse_iter {
|
||||
// Try to transform the event (this is where incomplete JSON fails)
|
||||
match SseEvent::try_from((sse_event.clone(), client_api, upstream_api)) {
|
||||
Ok(transformed) => {
|
||||
// Successfully transformed - add to results
|
||||
transformed_events.push(transformed);
|
||||
}
|
||||
Err(e) => {
|
||||
// Check if this is incomplete JSON (EOF while parsing) vs other errors
|
||||
let error_str = e.to_string().to_lowercase();
|
||||
let is_incomplete_json = error_str.contains("eof while parsing")
|
||||
|| error_str.contains("unexpected end of json")
|
||||
|| error_str.contains("unexpected eof");
|
||||
|
||||
if is_incomplete_json {
|
||||
// Incomplete JSON - buffer for retry with next chunk
|
||||
self.incomplete_event_buffer = sse_event.raw_line.as_bytes().to_vec();
|
||||
break;
|
||||
} else {
|
||||
// Other error (unsupported event type, validation error, etc.)
|
||||
// Skip this event and continue processing others
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(transformed_events)
|
||||
}
|
||||
|
||||
/// Check if there are buffered incomplete bytes
|
||||
pub fn has_buffered_data(&self) -> bool {
|
||||
!self.incomplete_event_buffer.is_empty()
|
||||
}
|
||||
|
||||
/// Get the size of buffered incomplete data (for debugging/logging)
|
||||
pub fn buffered_size(&self) -> usize {
|
||||
self.incomplete_event_buffer.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
|
||||
#[test]
fn test_complete_events_process_immediately() {
    // A chunk holding exactly one complete event yields that event and leaves
    // nothing behind in the carry-over buffer.
    let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    let mut processor = SseChunkProcessor::new();

    let whole_event = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";

    let produced = processor
        .process_chunk(whole_event, &client_api, &upstream_api)
        .unwrap();

    assert_eq!(produced.len(), 1);
    assert!(!processor.has_buffered_data());
}
|
||||
|
||||
#[test]
fn test_incomplete_json_buffered_and_completed() {
    // One event split across two chunks: the first half is held back, the second
    // half completes it.
    let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    let mut processor = SseChunkProcessor::new();

    // Head of the event: JSON is cut off mid-string
    let head = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chu";
    let after_head = processor
        .process_chunk(head, &client_api, &upstream_api)
        .unwrap();

    assert_eq!(after_head.len(), 0, "Incomplete event should not be processed");
    assert!(processor.has_buffered_data(), "Incomplete data should be buffered");

    // Tail of the event: completes the JSON and terminates the event
    let tail = b"nk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
    let after_tail = processor
        .process_chunk(tail, &client_api, &upstream_api)
        .unwrap();

    assert_eq!(after_tail.len(), 1, "Complete event should be processed");
    assert!(!processor.has_buffered_data(), "Buffer should be cleared after completion");
}
|
||||
|
||||
#[test]
fn test_multiple_events_with_one_incomplete() {
    // Two complete events followed by a truncated third: the first two are returned,
    // the third is held back for the next chunk.
    let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
    let mut processor = SseChunkProcessor::new();

    let mixed_chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"A\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"B\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-125\",\"object\":\"chat.completion.chu";

    let produced = processor
        .process_chunk(mixed_chunk, &client_api, &upstream_api)
        .unwrap();

    assert_eq!(produced.len(), 2, "Two complete events should be processed");
    assert!(processor.has_buffered_data(), "Incomplete third event should be buffered");
}
|
||||
|
||||
#[test]
fn test_anthropic_signature_delta_from_production_logs() {
    use crate::apis::anthropic::AnthropicApi;

    // Anthropic on both sides: the chunk is processed without changing API shape
    let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
    let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
    let mut processor = SseChunkProcessor::new();

    // Exact chunk from production logs - signature_delta event followed by content_block_stop
    let chunk = br#"event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":"ErECCkYIChgCKkC7lAf/BOatd0I4NnANYNEDKl5/WSsjNK44AETnLoy3i5FfdYMAb0m4qMLJD6A04QnM4Hf3VpGqq/snA/9vvNxCEgw3CYcHcj0aTdqOisQaDOhlVBtAUKkoh3WopSIwAbJp4jG/41vVWBj63eaR7KFJ37OdY1byjlPkaGDUJRcWc/YfUWIDSAToomq2fB4VKpgBk+swVYxLZ709gQvyTCT+3vO/I+yexZpkx6eBl/+YCgQXTeviZ+hTxSoPVayf5vEQoc19ZA4MEkZ7yBInRgk8vUxAJITSf+vOvDIBsElpgkLfSjARCasjh78wONg39AkAoIbKzU+Q2l1htUwXcqQ2b+b5DrY9+Oxae4pBVGQlWU36XAHsa/KG+ejfdwhWJM7FNL3uphwAf0oYAQ=="}}

event: content_block_stop
data: {"type":"content_block_stop","index":0}

"#;

    // A hard error here would mean the signature_delta payload is not understood
    let events = processor
        .process_chunk(chunk, &client_api, &upstream_api)
        .unwrap_or_else(|e| {
            panic!("Failed to process signature_delta chunk - this means SignatureDelta is not properly handled: {}", e)
        });

    println!("Successfully processed {} events", events.len());
    for (i, event) in events.iter().enumerate() {
        println!("Event {}: event={:?}, has_data={}", i, event.event, event.data.is_some());
    }

    // Should successfully process both events (signature_delta + content_block_stop)
    assert!(events.len() >= 2, "Should process at least 2 complete events (signature_delta + stop), got {}", events.len());
    assert!(!processor.has_buffered_data(), "Complete events should not be buffered");
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_event_does_not_block_subsequent_events() {
|
||||
let mut processor = SseChunkProcessor::new();
|
||||
let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Chunk with an unsupported/invalid event followed by a valid event
|
||||
// First event has invalid JSON structure that will fail validation (not incomplete)
|
||||
// Second event is valid and should be processed
|
||||
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"unsupported_field_causing_validation_error\":true},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
|
||||
|
||||
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
|
||||
|
||||
// Should skip the invalid event and process the valid one
|
||||
// (If we were buffering all errors, we'd get 0 events and have buffered data)
|
||||
assert!(events.len() >= 1, "Should process at least the valid event, got {} events", events.len());
|
||||
assert!(!processor.has_buffered_data(), "Invalid (non-incomplete) events should not be buffered");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_delta_type_skipped_others_processed() {
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
|
||||
let mut processor = SseChunkProcessor::new();
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
|
||||
// Chunk with valid event, unsupported delta type, then another valid event
|
||||
// This simulates a future API change where Anthropic adds a new delta type we don't support yet
|
||||
let chunk = br#"event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"future_unsupported_delta","future_field":"some_value"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" World"}}
|
||||
|
||||
"#;
|
||||
|
||||
let result = processor.process_chunk(chunk, &client_api, &upstream_api);
|
||||
|
||||
match result {
|
||||
Ok(events) => {
|
||||
println!("Processed {} events (unsupported event should be skipped)", events.len());
|
||||
// Should process the 2 valid text_delta events and skip the unsupported one
|
||||
// We expect at least 2 events (the valid ones), unsupported should be skipped
|
||||
assert!(events.len() >= 2, "Should process at least 2 valid events, got {}", events.len());
|
||||
assert!(!processor.has_buffered_data(), "Unsupported events should be skipped, not buffered");
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Should not fail on unsupported delta type, should skip it: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -4,9 +4,10 @@ use std::fmt;
|
|||
|
||||
/// Unified enum representing all supported API endpoints across providers
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SupportedAPIs {
|
||||
pub enum SupportedAPIsFromClient {
|
||||
OpenAIChatCompletions(OpenAIApi),
|
||||
AnthropicMessagesAPI(AnthropicApi),
|
||||
OpenAIResponsesAPI(OpenAIApi),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
|
@ -15,17 +16,21 @@ pub enum SupportedUpstreamAPIs {
|
|||
AnthropicMessagesAPI(AnthropicApi),
|
||||
AmazonBedrockConverse(AmazonBedrockApi),
|
||||
AmazonBedrockConverseStream(AmazonBedrockApi),
|
||||
OpenAIResponsesAPI(OpenAIApi),
|
||||
}
|
||||
|
||||
impl fmt::Display for SupportedAPIs {
|
||||
impl fmt::Display for SupportedAPIsFromClient {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => {
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(api) => {
|
||||
write!(f, "OpenAI ({})", api.endpoint())
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => {
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => {
|
||||
write!(f, "Anthropic AI ({})", api.endpoint())
|
||||
}
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => {
|
||||
write!(f, "OpenAI Responses ({})", api.endpoint())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -45,19 +50,27 @@ impl fmt::Display for SupportedUpstreamAPIs {
|
|||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(api) => {
|
||||
write!(f, "Amazon Bedrock ({})", api.endpoint())
|
||||
}
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(api) => {
|
||||
write!(f, "OpenAI Responses ({})", api.endpoint())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SupportedAPIs {
|
||||
impl SupportedAPIsFromClient {
|
||||
/// Create a SupportedApi from an endpoint path
|
||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||
return Some(SupportedAPIs::OpenAIChatCompletions(openai_api));
|
||||
// Check if this is the Responses API endpoint
|
||||
if openai_api == OpenAIApi::Responses {
|
||||
return Some(SupportedAPIsFromClient::OpenAIResponsesAPI(openai_api));
|
||||
}
|
||||
// Otherwise it's ChatCompletions
|
||||
return Some(SupportedAPIsFromClient::OpenAIChatCompletions(openai_api));
|
||||
}
|
||||
|
||||
if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) {
|
||||
return Some(SupportedAPIs::AnthropicMessagesAPI(anthropic_api));
|
||||
return Some(SupportedAPIsFromClient::AnthropicMessagesAPI(anthropic_api));
|
||||
}
|
||||
|
||||
None
|
||||
|
|
@ -66,8 +79,9 @@ impl SupportedAPIs {
|
|||
/// Get the endpoint path for this API
|
||||
pub fn endpoint(&self) -> &'static str {
|
||||
match self {
|
||||
SupportedAPIs::OpenAIChatCompletions(api) => api.endpoint(),
|
||||
SupportedAPIs::AnthropicMessagesAPI(api) => api.endpoint(),
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(),
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(),
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -94,8 +108,62 @@ impl SupportedAPIs {
|
|||
}
|
||||
};
|
||||
|
||||
// Helper function to route based on provider with a specific endpoint suffix
|
||||
let route_by_provider = |endpoint_suffix: &str| -> String {
|
||||
match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/openai", request_path)
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::Zhipu => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/api/paas/v4", endpoint_suffix)
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::Qwen => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/compatible-mode/v1", endpoint_suffix)
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
let suffix = endpoint_suffix.trim_start_matches('/');
|
||||
build_endpoint("/openai/deployments", &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix))
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/v1beta/openai", endpoint_suffix)
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
if !is_streaming {
|
||||
build_endpoint("", &format!("/model/{}/converse", model_id))
|
||||
} else {
|
||||
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
|
||||
}
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
}
|
||||
_ => build_endpoint("/v1", endpoint_suffix),
|
||||
}
|
||||
};
|
||||
|
||||
match self {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
|
||||
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") && !is_streaming {
|
||||
|
|
@ -108,59 +176,57 @@ impl SupportedAPIs {
|
|||
}
|
||||
_ => build_endpoint("/v1", "/chat/completions"),
|
||||
},
|
||||
_ => match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/openai", request_path)
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
|
||||
// For Responses API, check if provider supports it, otherwise translate to chat/completions
|
||||
match provider_id {
|
||||
// OpenAI and compatible providers that support /v1/responses
|
||||
ProviderId::OpenAI => route_by_provider("/responses"),
|
||||
// All other providers: translate to /chat/completions
|
||||
_ => route_by_provider("/chat/completions"),
|
||||
}
|
||||
ProviderId::Zhipu => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/api/paas/v4", "/chat/completions")
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
ProviderId::Qwen => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/compatible-mode/v1", "/chat/completions")
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/openai/deployments", &format!("/{}/chat/completions?api-version=2025-01-01-preview", model_id))
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
build_endpoint("/v1beta/openai", "/chat/completions")
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
if !is_streaming {
|
||||
build_endpoint("", &format!("/model/{}/converse", model_id))
|
||||
} else {
|
||||
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
|
||||
}
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
_ => build_endpoint("/v1", "/chat/completions"),
|
||||
},
|
||||
}
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
|
||||
// For Chat Completions API, use the standard chat/completions path
|
||||
route_by_provider("/chat/completions")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl SupportedUpstreamAPIs {
|
||||
/// Create a SupportedUpstreamApi from an endpoint path
|
||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||
// Check if this is the Responses API endpoint
|
||||
if openai_api == OpenAIApi::Responses {
|
||||
return Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(openai_api));
|
||||
}
|
||||
// Otherwise it's ChatCompletions
|
||||
return Some(SupportedUpstreamAPIs::OpenAIChatCompletions(openai_api));
|
||||
}
|
||||
|
||||
if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) {
|
||||
return Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(anthropic_api));
|
||||
}
|
||||
|
||||
if let Some(bedrock_api) = AmazonBedrockApi::from_endpoint(endpoint) {
|
||||
match bedrock_api {
|
||||
AmazonBedrockApi::Converse => {
|
||||
return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api))
|
||||
}
|
||||
AmazonBedrockApi::ConverseStream => {
|
||||
return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(bedrock_api))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
/// Get all supported endpoint paths
|
||||
pub fn supported_endpoints() -> Vec<&'static str> {
|
||||
let mut endpoints = Vec::new();
|
||||
|
|
@ -198,22 +264,23 @@ mod tests {
|
|||
#[test]
|
||||
fn test_is_supported_endpoint() {
|
||||
// OpenAI endpoints
|
||||
assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
|
||||
assert!(SupportedAPIsFromClient::from_endpoint("/v1/chat/completions").is_some());
|
||||
// Anthropic endpoints
|
||||
assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
|
||||
assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some());
|
||||
|
||||
// Unsupported endpoints
|
||||
assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
|
||||
assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some());
|
||||
assert!(!SupportedAPIs::from_endpoint("").is_some());
|
||||
assert!(!SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_some());
|
||||
assert!(!SupportedAPIsFromClient::from_endpoint("/v2/chat").is_some());
|
||||
assert!(!SupportedAPIsFromClient::from_endpoint("").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supported_endpoints() {
|
||||
let endpoints = supported_endpoints();
|
||||
assert_eq!(endpoints.len(), 2); // We have 2 APIs defined
|
||||
assert_eq!(endpoints.len(), 3); // We have 3 APIs defined
|
||||
assert!(endpoints.contains(&"/v1/chat/completions"));
|
||||
assert!(endpoints.contains(&"/v1/messages"));
|
||||
assert!(endpoints.contains(&"/v1/responses"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -263,7 +330,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_target_endpoint_without_base_url_prefix() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test default OpenAI provider
|
||||
assert_eq!(
|
||||
|
|
@ -340,7 +407,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_target_endpoint_with_base_url_prefix() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test Zhipu with custom base_url_path_prefix
|
||||
assert_eq!(
|
||||
|
|
@ -405,7 +472,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_target_endpoint_with_empty_base_url_prefix() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test with just slashes - trims to empty, uses provider default
|
||||
assert_eq!(
|
||||
|
|
@ -434,7 +501,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_amazon_bedrock_endpoints() {
|
||||
let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
|
||||
// Test Bedrock non-streaming without prefix
|
||||
assert_eq!(
|
||||
|
|
@ -487,7 +554,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_anthropic_messages_endpoint() {
|
||||
let api = SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
|
||||
// Test Anthropic without prefix
|
||||
assert_eq!(
|
||||
|
|
@ -516,7 +583,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_non_v1_request_paths() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test Groq with non-v1 path (should use default)
|
||||
assert_eq!(
|
||||
|
|
@ -557,7 +624,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_azure_openai_with_query_params() {
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
// Test Azure without prefix - should include query params
|
||||
assert_eq!(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
pub mod endpoints;
|
||||
pub mod lib;
|
||||
pub mod transformer;
|
||||
|
||||
// Re-export the main items for easier access
|
||||
pub use endpoints::{identify_provider, SupportedAPIs};
|
||||
pub use endpoints::*;
|
||||
pub use lib::*;
|
||||
|
||||
// Note: transformer module contains TryFrom trait implementations that are automatically available
|
||||
|
|
|
|||
|
|
@ -1,694 +0,0 @@
|
|||
// Re-export new transformation modules for backward compatibility
|
||||
|
||||
//KEEPING THE TESTS TO MAKE SURE ALL THE REFACTORING DIDN'T BREAK ANYTHING
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::apis::anthropic::*;
|
||||
use crate::apis::openai::*;
|
||||
use crate::transforms::*;
|
||||
use serde_json::json;
|
||||
type AnthropicMessagesRequest = MessagesRequest;
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_to_openai_basic_request() {
|
||||
let anthropic_req = AnthropicMessagesRequest {
|
||||
model: "claude-3-sonnet-20240229".to_string(),
|
||||
system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())),
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Hello, world!".to_string()),
|
||||
}],
|
||||
max_tokens: 1024,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(0.9),
|
||||
top_k: Some(50),
|
||||
stream: Some(false),
|
||||
stop_sequences: Some(vec!["STOP".to_string()]),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let openai_req: ChatCompletionsRequest = anthropic_req.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_req.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(openai_req.messages.len(), 2); // system + user message
|
||||
assert_eq!(openai_req.max_completion_tokens, Some(1024));
|
||||
assert_eq!(openai_req.temperature, Some(0.7));
|
||||
assert_eq!(openai_req.top_p, Some(0.9));
|
||||
assert_eq!(openai_req.stream, Some(false));
|
||||
assert_eq!(openai_req.stop, Some(vec!["STOP".to_string()]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_consistency() {
|
||||
// Test that converting back and forth maintains consistency
|
||||
let original_anthropic = AnthropicMessagesRequest {
|
||||
model: "claude-3-sonnet".to_string(),
|
||||
system: Some(MessagesSystemPrompt::Single("System prompt".to_string())),
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("User message".to_string()),
|
||||
}],
|
||||
max_tokens: 1000,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: Some(0.5),
|
||||
top_p: Some(1.0),
|
||||
top_k: None,
|
||||
stream: Some(false),
|
||||
stop_sequences: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
// Convert to OpenAI and back
|
||||
let openai_req: ChatCompletionsRequest = original_anthropic.clone().try_into().unwrap();
|
||||
let roundtrip_anthropic: AnthropicMessagesRequest = openai_req.try_into().unwrap();
|
||||
|
||||
// Check key fields are preserved
|
||||
assert_eq!(original_anthropic.model, roundtrip_anthropic.model);
|
||||
assert_eq!(
|
||||
original_anthropic.max_tokens,
|
||||
roundtrip_anthropic.max_tokens
|
||||
);
|
||||
assert_eq!(
|
||||
original_anthropic.temperature,
|
||||
roundtrip_anthropic.temperature
|
||||
);
|
||||
assert_eq!(original_anthropic.top_p, roundtrip_anthropic.top_p);
|
||||
assert_eq!(original_anthropic.stream, roundtrip_anthropic.stream);
|
||||
assert_eq!(
|
||||
original_anthropic.messages.len(),
|
||||
roundtrip_anthropic.messages.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_choice_auto() {
|
||||
let anthropic_req = AnthropicMessagesRequest {
|
||||
model: "claude-3".to_string(),
|
||||
system: None,
|
||||
messages: vec![],
|
||||
max_tokens: 100,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
stream: None,
|
||||
stop_sequences: None,
|
||||
tools: Some(vec![MessagesTool {
|
||||
name: "test_tool".to_string(),
|
||||
description: Some("A test tool".to_string()),
|
||||
input_schema: json!({"type": "object"}),
|
||||
}]),
|
||||
tool_choice: Some(MessagesToolChoice {
|
||||
kind: MessagesToolChoiceType::Auto,
|
||||
name: None,
|
||||
disable_parallel_tool_use: Some(true),
|
||||
}),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let openai_req: ChatCompletionsRequest = anthropic_req.try_into().unwrap();
|
||||
|
||||
assert!(openai_req.tools.is_some());
|
||||
assert_eq!(openai_req.tools.as_ref().unwrap().len(), 1);
|
||||
|
||||
if let Some(ToolChoice::Type(choice)) = openai_req.tool_choice {
|
||||
assert_eq!(choice, ToolChoiceType::Auto);
|
||||
} else {
|
||||
panic!("Expected auto tool choice");
|
||||
}
|
||||
|
||||
assert_eq!(openai_req.parallel_tool_calls, Some(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_max_tokens_used_when_openai_has_none() {
|
||||
// Test that DEFAULT_MAX_TOKENS is used when OpenAI request has no max_tokens
|
||||
let openai_req = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
max_tokens: None, // No max_tokens specified
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let anthropic_req: AnthropicMessagesRequest = openai_req.try_into().unwrap();
|
||||
|
||||
assert_eq!(anthropic_req.max_tokens, DEFAULT_MAX_TOKENS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_message_start_streaming() {
|
||||
let event = MessagesStreamEvent::MessageStart {
|
||||
message: MessagesStreamMessage {
|
||||
id: "msg_stream_123".to_string(),
|
||||
obj_type: "message".to_string(),
|
||||
role: MessagesRole::Assistant,
|
||||
content: vec![],
|
||||
model: "claude-3".to_string(),
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 5,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_resp.id, "msg_stream_123");
|
||||
assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk"));
|
||||
assert_eq!(openai_resp.model, "claude-3");
|
||||
assert_eq!(openai_resp.choices.len(), 1);
|
||||
|
||||
let choice = &openai_resp.choices[0];
|
||||
assert_eq!(choice.index, 0);
|
||||
assert_eq!(choice.delta.role, Some(Role::Assistant));
|
||||
assert_eq!(choice.delta.content, None);
|
||||
assert_eq!(choice.finish_reason, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_content_block_delta_streaming() {
|
||||
let event = MessagesStreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: MessagesContentDelta::TextDelta {
|
||||
text: "Hello, world!".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk"));
|
||||
assert_eq!(openai_resp.choices.len(), 1);
|
||||
|
||||
let choice = &openai_resp.choices[0];
|
||||
assert_eq!(choice.index, 0);
|
||||
assert_eq!(choice.delta.content, Some("Hello, world!".to_string()));
|
||||
assert_eq!(choice.delta.role, None);
|
||||
assert_eq!(choice.finish_reason, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_tool_use_streaming() {
|
||||
// Test tool use start
|
||||
let tool_start = MessagesStreamEvent::ContentBlockStart {
|
||||
index: 0,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: "call_123".to_string(),
|
||||
name: "get_weather".to_string(),
|
||||
input: json!({}),
|
||||
cache_control: None,
|
||||
},
|
||||
};
|
||||
|
||||
let openai_resp: ChatCompletionsStreamResponse = tool_start.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_resp.choices.len(), 1);
|
||||
let choice = &openai_resp.choices[0];
|
||||
assert!(choice.delta.tool_calls.is_some());
|
||||
|
||||
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id, Some("call_123".to_string()));
|
||||
assert_eq!(
|
||||
tool_calls[0].function.as_ref().unwrap().name,
|
||||
Some("get_weather".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_tool_input_delta_streaming() {
|
||||
let event = MessagesStreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: r#"{"location": "San Francisco"#.to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_resp.choices.len(), 1);
|
||||
let choice = &openai_resp.choices[0];
|
||||
assert!(choice.delta.tool_calls.is_some());
|
||||
|
||||
let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(
|
||||
tool_calls[0].function.as_ref().unwrap().arguments,
|
||||
Some(r#"{"location": "San Francisco"#.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_message_delta_with_usage() {
|
||||
let event = MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: MessagesStopReason::EndTurn,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 10,
|
||||
output_tokens: 25,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
};
|
||||
|
||||
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_resp.choices.len(), 1);
|
||||
let choice = &openai_resp.choices[0];
|
||||
assert_eq!(choice.finish_reason, Some(FinishReason::Stop));
|
||||
|
||||
assert!(openai_resp.usage.is_some());
|
||||
let usage = openai_resp.usage.unwrap();
|
||||
assert_eq!(usage.prompt_tokens, 10);
|
||||
assert_eq!(usage.completion_tokens, 25);
|
||||
assert_eq!(usage.total_tokens, 35);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_message_stop_streaming() {
|
||||
let event = MessagesStreamEvent::MessageStop;
|
||||
|
||||
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_resp.choices.len(), 1);
|
||||
let choice = &openai_resp.choices[0];
|
||||
assert_eq!(choice.finish_reason, Some(FinishReason::Stop));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_anthropic_ping_streaming() {
|
||||
let event = MessagesStreamEvent::Ping;
|
||||
|
||||
let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
|
||||
|
||||
assert_eq!(openai_resp.object.as_deref(), Some("chat.completion.chunk"));
|
||||
assert_eq!(openai_resp.choices.len(), 0); // Ping has no choices
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_streaming_role_start() {
|
||||
let openai_resp = ChatCompletionsStreamResponse {
|
||||
id: "chatcmpl-123".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: 1234567890,
|
||||
model: "gpt-4".to_string(),
|
||||
choices: vec![StreamChoice {
|
||||
index: 0,
|
||||
delta: MessageDelta {
|
||||
role: Some(Role::Assistant),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
}],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
};
|
||||
|
||||
let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
|
||||
|
||||
match anthropic_event {
|
||||
MessagesStreamEvent::MessageStart { message } => {
|
||||
assert_eq!(message.id, "chatcmpl-123");
|
||||
assert_eq!(message.role, MessagesRole::Assistant);
|
||||
assert_eq!(message.model, "gpt-4");
|
||||
}
|
||||
_ => panic!("Expected MessageStart event"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_streaming_content_delta() {
|
||||
let openai_resp = ChatCompletionsStreamResponse {
|
||||
id: "chatcmpl-123".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: 1234567890,
|
||||
model: "gpt-4".to_string(),
|
||||
choices: vec![StreamChoice {
|
||||
index: 0,
|
||||
delta: MessageDelta {
|
||||
role: None,
|
||||
content: Some("Hello there!".to_string()),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
}],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
};
|
||||
|
||||
let anthropic_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
|
||||
|
||||
match anthropic_event {
|
||||
MessagesStreamEvent::ContentBlockDelta { index, delta } => {
|
||||
assert_eq!(index, 0);
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => {
|
||||
assert_eq!(text, "Hello there!");
|
||||
}
|
||||
_ => panic!("Expected TextDelta"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected ContentBlockDelta event"),
|
||||
}
|
||||
}
|
||||
|
||||
/// A chunk whose delta opens a tool call (id + function name, empty arguments)
/// is expected to convert into a `ContentBlockStart` at index 0 holding a
/// `ToolUse` block with the same id and name.
#[test]
fn test_openai_to_anthropic_streaming_tool_calls() {
    let opening_tool_call = ToolCallDelta {
        index: 0,
        id: Some("call_abc123".to_string()),
        call_type: Some("function".to_string()),
        function: Some(FunctionCallDelta {
            name: Some("get_current_weather".to_string()),
            arguments: Some("".to_string()),
        }),
    };
    let tool_call_delta = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: Some(vec![opening_tool_call]),
    };
    let chunk = ChatCompletionsStreamResponse {
        id: "chatcmpl-123".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "gpt-4".to_string(),
        choices: vec![StreamChoice {
            index: 0,
            delta: tool_call_delta,
            finish_reason: None,
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let converted: MessagesStreamEvent = chunk.try_into().unwrap();

    if let MessagesStreamEvent::ContentBlockStart {
        index,
        content_block,
    } = converted
    {
        assert_eq!(index, 0);
        if let MessagesContentBlock::ToolUse { id, name, .. } = content_block {
            assert_eq!(id, "call_abc123");
            assert_eq!(name, "get_current_weather");
        } else {
            panic!("Expected ToolUse content block");
        }
    } else {
        panic!("Expected ContentBlockStart event");
    }
}
|
||||
|
||||
/// A final chunk (`finish_reason = Stop`) that also carries usage is expected
/// to convert into a `MessageDelta` with stop reason `EndTurn` and the
/// prompt/completion token counts mapped to input/output tokens.
#[test]
fn test_openai_to_anthropic_streaming_final_usage() {
    let final_usage = Usage {
        prompt_tokens: 15,
        completion_tokens: 30,
        total_tokens: 45,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    };
    let finishing_choice = StreamChoice {
        index: 0,
        delta: MessageDelta {
            role: None,
            content: None,
            refusal: None,
            function_call: None,
            tool_calls: None,
        },
        finish_reason: Some(FinishReason::Stop),
        logprobs: None,
    };
    let chunk = ChatCompletionsStreamResponse {
        id: "chatcmpl-123".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "gpt-4".to_string(),
        choices: vec![finishing_choice],
        usage: Some(final_usage),
        system_fingerprint: None,
        service_tier: None,
    };

    let converted: MessagesStreamEvent = chunk.try_into().unwrap();

    if let MessagesStreamEvent::MessageDelta { delta, usage } = converted {
        assert_eq!(delta.stop_reason, MessagesStopReason::EndTurn);
        assert_eq!(usage.input_tokens, 15);
        assert_eq!(usage.output_tokens, 30);
    } else {
        panic!("Expected MessageDelta event");
    }
}
|
||||
|
||||
/// A chunk with no choices at all is expected to convert into a `Ping` event.
#[test]
fn test_openai_empty_choices_to_anthropic_ping() {
    let chunk_without_choices = ChatCompletionsStreamResponse {
        id: "chatcmpl-123".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "gpt-4".to_string(),
        // Deliberately empty: this is the condition under test.
        choices: Vec::new(),
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    let converted: MessagesStreamEvent = chunk_without_choices.try_into().unwrap();

    // Anything other than Ping is a failure.
    if !matches!(converted, MessagesStreamEvent::Ping) {
        panic!("Expected Ping event for empty choices");
    }
}
|
||||
|
||||
/// An Anthropic text delta converted to an OpenAI chunk and back is expected
/// to come out as a `ContentBlockDelta` at index 0 with the same text.
#[test]
fn test_streaming_roundtrip_consistency() {
    let text_delta = MessagesContentDelta::TextDelta {
        text: "Test message".to_string(),
    };
    let source_event = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: text_delta,
    };

    // Anthropic -> OpenAI -> Anthropic
    let as_openai: ChatCompletionsStreamResponse = source_event.try_into().unwrap();
    let back_to_anthropic: MessagesStreamEvent = as_openai.try_into().unwrap();

    // The essential information (event kind, index, text) must survive.
    if let MessagesStreamEvent::ContentBlockDelta { index, delta } = back_to_anthropic {
        assert_eq!(index, 0);
        if let MessagesContentDelta::TextDelta { text } = delta {
            assert_eq!(text, "Test message");
        } else {
            panic!("Expected TextDelta after roundtrip");
        }
    } else {
        panic!("Expected ContentBlockDelta after roundtrip");
    }
}
|
||||
|
||||
/// Streams a tool call as a start event followed by two `InputJsonDelta`
/// fragments and checks that:
///   * the start event converts to an OpenAI tool call with the same id/name,
///   * each fragment is carried over unchanged as the `arguments` delta,
///   * the fragments, concatenated in order, form the complete JSON arguments.
///
/// The original fixture's first fragment was `{"location": ` (no opening quote
/// for the value), so the two fragments could never accumulate into valid
/// JSON; the fragment now ends with the opening quote.
#[test]
fn test_streaming_tool_argument_accumulation() {
    // Two fragments that only make sense once concatenated.
    const FIRST_FRAGMENT: &str = r#"{"location": ""#;
    const SECOND_FRAGMENT: &str = r#"San Francisco", "unit": "fahrenheit"}"#;

    let tool_start = MessagesStreamEvent::ContentBlockStart {
        index: 0,
        content_block: MessagesContentBlock::ToolUse {
            id: "call_weather".to_string(),
            name: "get_weather".to_string(),
            input: json!({}),
            cache_control: None,
        },
    };

    let arg_delta1 = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::InputJsonDelta {
            partial_json: FIRST_FRAGMENT.to_string(),
        },
    };

    let arg_delta2 = MessagesStreamEvent::ContentBlockDelta {
        index: 0,
        delta: MessagesContentDelta::InputJsonDelta {
            partial_json: SECOND_FRAGMENT.to_string(),
        },
    };

    // Each event must convert to the OpenAI streaming shape on its own.
    let openai_start: ChatCompletionsStreamResponse = tool_start.try_into().unwrap();
    let openai_delta1: ChatCompletionsStreamResponse = arg_delta1.try_into().unwrap();
    let openai_delta2: ChatCompletionsStreamResponse = arg_delta2.try_into().unwrap();

    // Verify tool start
    let tool_calls = openai_start.choices[0].delta.tool_calls.as_ref().unwrap();
    assert_eq!(tool_calls[0].id, Some("call_weather".to_string()));
    assert_eq!(
        tool_calls[0].function.as_ref().unwrap().name,
        Some("get_weather".to_string())
    );

    // Verify argument deltas are passed through verbatim
    let args1 = &openai_delta1.choices[0].delta.tool_calls.as_ref().unwrap()[0]
        .function
        .as_ref()
        .unwrap()
        .arguments;
    assert_eq!(args1, &Some(FIRST_FRAGMENT.to_string()));

    let args2 = &openai_delta2.choices[0].delta.tool_calls.as_ref().unwrap()[0]
        .function
        .as_ref()
        .unwrap()
        .arguments;
    assert_eq!(args2, &Some(SECOND_FRAGMENT.to_string()));

    // Verify accumulation: the fragments joined in order must be valid JSON
    // with the expected arguments.
    let accumulated = format!(
        "{}{}",
        args1.as_deref().unwrap(),
        args2.as_deref().unwrap()
    );
    let parsed: serde_json::Value = serde_json::from_str(&accumulated)
        .expect("accumulated tool arguments should be valid JSON");
    assert_eq!(
        parsed,
        json!({"location": "San Francisco", "unit": "fahrenheit"})
    );
}
|
||||
|
||||
/// For each Anthropic stop reason, a `MessageDelta` event is expected to
/// convert to an OpenAI chunk with the paired finish reason, and converting
/// that chunk back must again yield a `MessageDelta` with a known stop reason.
#[test]
fn test_streaming_multiple_finish_reasons() {
    // (Anthropic stop reason, expected OpenAI finish reason)
    let test_cases = vec![
        (MessagesStopReason::EndTurn, FinishReason::Stop),
        (MessagesStopReason::MaxTokens, FinishReason::Length),
        (MessagesStopReason::ToolUse, FinishReason::ToolCalls),
        (MessagesStopReason::StopSequence, FinishReason::Stop),
    ];

    for (anthropic_reason, expected_openai_reason) in test_cases {
        let stop_delta = MessagesMessageDelta {
            stop_reason: anthropic_reason.clone(),
            stop_sequence: None,
        };
        let token_usage = MessagesUsage {
            input_tokens: 10,
            output_tokens: 20,
            cache_creation_input_tokens: None,
            cache_read_input_tokens: None,
        };
        let event = MessagesStreamEvent::MessageDelta {
            delta: stop_delta,
            usage: token_usage,
        };

        // Forward conversion: Anthropic -> OpenAI
        let openai_resp: ChatCompletionsStreamResponse = event.try_into().unwrap();
        assert_eq!(
            openai_resp.choices[0].finish_reason,
            Some(expected_openai_reason)
        );

        // Reverse conversion: OpenAI -> Anthropic
        let roundtrip_event: MessagesStreamEvent = openai_resp.try_into().unwrap();
        if let MessagesStreamEvent::MessageDelta { delta, .. } = roundtrip_event {
            // The mapping is not one-to-one (two stop reasons share `Stop`),
            // so only membership in the known set is checked here.
            assert!(matches!(
                delta.stop_reason,
                MessagesStopReason::EndTurn
                    | MessagesStopReason::MaxTokens
                    | MessagesStopReason::ToolUse
                    | MessagesStopReason::StopSequence
            ));
        } else {
            panic!("Expected MessageDelta after roundtrip");
        }
    }
}
|
||||
|
||||
/// A chunk whose only choice has a completely empty delta and no finish
/// reason is expected to convert into a `Ping` rather than fail.
#[test]
fn test_streaming_error_handling() {
    let empty_delta = MessageDelta {
        role: None,
        content: None,
        refusal: None,
        function_call: None,
        tool_calls: None,
    };
    let contentless_choice = StreamChoice {
        index: 0,
        delta: empty_delta,
        finish_reason: None,
        logprobs: None,
    };
    let openai_resp_with_missing_data = ChatCompletionsStreamResponse {
        id: "test".to_string(),
        object: Some("chat.completion.chunk".to_string()),
        created: 1234567890,
        model: "test".to_string(),
        choices: vec![contentless_choice],
        usage: None,
        system_fingerprint: None,
        service_tier: None,
    };

    // No meaningful content -> Ping
    let anthropic_event: MessagesStreamEvent =
        openai_resp_with_missing_data.try_into().unwrap();
    assert!(matches!(anthropic_event, MessagesStreamEvent::Ping));
}
|
||||
|
||||
/// `ContentBlockStop` is expected to convert into a single-choice
/// `chat.completion.chunk` whose delta and finish reason are all empty.
#[test]
fn test_streaming_content_block_stop() {
    let stop_event = MessagesStreamEvent::ContentBlockStop { index: 0 };

    let chunk: ChatCompletionsStreamResponse = stop_event.try_into().unwrap();

    // Envelope: correct object tag and exactly one choice.
    assert_eq!(chunk.object.as_deref(), Some("chat.completion.chunk"));
    assert_eq!(chunk.choices.len(), 1);

    // Payload: nothing set on the only choice.
    let only_choice = &chunk.choices[0];
    let delta = &only_choice.delta;
    assert_eq!(delta.role, None);
    assert_eq!(delta.content, None);
    assert_eq!(delta.tool_calls, None);
    assert_eq!(only_choice.finish_reason, None);
}
|
||||
}
|
||||
|
|
@ -6,18 +6,21 @@ pub mod clients;
|
|||
pub mod providers;
|
||||
pub mod transforms;
|
||||
// Re-export important types and traits
|
||||
pub use apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
pub use apis::sse::{SseEvent, SseStreamIter};
|
||||
pub use apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
pub use apis::streaming_shapes::sse::{SseEvent, SseStreamIter};
|
||||
pub use aws_smithy_eventstream::frame::DecodedFrame;
|
||||
pub use providers::id::ProviderId;
|
||||
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use providers::response::{
|
||||
ProviderResponse, ProviderResponseError, ProviderResponseType, ProviderStreamResponse,
|
||||
ProviderStreamResponseType, TokenUsage,
|
||||
ProviderResponse, ProviderResponseType, TokenUsage, ProviderResponseError
|
||||
};
|
||||
pub use providers::streaming_response::{
|
||||
ProviderStreamResponse, ProviderStreamResponseType
|
||||
};
|
||||
|
||||
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses";
|
||||
pub const MESSAGES_PATH: &str = "/v1/messages";
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -42,9 +45,9 @@ mod tests {
|
|||
data: [DONE]
|
||||
"#;
|
||||
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
let client_api =
|
||||
SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
let upstream_api =
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
|
||||
|
||||
|
|
@ -79,9 +82,16 @@ mod tests {
|
|||
assert_eq!(stream_response.content_delta(), Some("Hello"));
|
||||
assert!(!stream_response.is_final());
|
||||
|
||||
// Test that stream ends properly with [DONE] (SseStreamIter should stop before [DONE])
|
||||
// Test that stream ends properly with [DONE]
|
||||
// The iterator should return the [DONE] event, then None
|
||||
let done_event = streaming_iter.next();
|
||||
assert!(done_event.is_some(), "Should get [DONE] event");
|
||||
let done_event = done_event.unwrap();
|
||||
assert!(done_event.is_done(), "[DONE] event should be marked as done");
|
||||
|
||||
// After [DONE], iterator should return None
|
||||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none()); // Should be None because iterator stops at [DONE]
|
||||
assert!(final_event.is_none(), "Iterator should return None after [DONE]");
|
||||
}
|
||||
|
||||
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::{SupportedAPIs, SupportedUpstreamAPIs};
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
|
|
@ -51,19 +51,24 @@ impl ProviderId {
|
|||
/// Given a client API, return the compatible upstream API for this provider
|
||||
pub fn compatible_api_for_client(
|
||||
&self,
|
||||
client_api: &SupportedAPIs,
|
||||
client_api: &SupportedAPIsFromClient,
|
||||
is_streaming: bool,
|
||||
) -> SupportedUpstreamAPIs {
|
||||
match (self, client_api) {
|
||||
// Claude/Anthropic providers natively support Anthropic APIs
|
||||
(ProviderId::Anthropic, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
}
|
||||
(
|
||||
ProviderId::Anthropic,
|
||||
SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// Anthropic doesn't support Responses API, fall back to chat completions
|
||||
(ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
|
||||
}
|
||||
|
||||
// OpenAI-compatible providers only support OpenAI chat completions
|
||||
(
|
||||
ProviderId::OpenAI
|
||||
|
|
@ -80,7 +85,7 @@ impl ProviderId {
|
|||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
(
|
||||
|
|
@ -98,11 +103,16 @@ impl ProviderId {
|
|||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen,
|
||||
SupportedAPIs::OpenAIChatCompletions(_),
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
// OpenAI Responses API - only OpenAI supports this
|
||||
(ProviderId::OpenAI, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses)
|
||||
}
|
||||
|
||||
// Amazon Bedrock natively supports Bedrock APIs
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
|
|
@ -111,7 +121,7 @@ impl ProviderId {
|
|||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
(ProviderId::AmazonBedrock, SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
|
|
@ -120,6 +130,20 @@ impl ProviderId {
|
|||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
(ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
if is_streaming {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
AmazonBedrockApi::ConverseStream,
|
||||
)
|
||||
} else {
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse)
|
||||
}
|
||||
}
|
||||
|
||||
// Non-OpenAI providers: if client requested the Responses API, fall back to Chat Completions
|
||||
(_, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@
|
|||
pub mod id;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod streaming_response;
|
||||
|
||||
pub use id::ProviderId;
|
||||
pub use request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use response::{ProviderResponse, ProviderResponseType, ProviderStreamResponse, TokenUsage};
|
||||
pub use response::{ProviderResponse, ProviderResponseType, TokenUsage};
|
||||
pub use streaming_response::{ProviderStreamResponse, ProviderStreamResponseType};
|
||||
|
|
|
|||
|
|
@ -2,19 +2,21 @@ use crate::apis::anthropic::MessagesRequest;
|
|||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
|
||||
use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest};
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::apis::openai_responses::ResponsesAPIRequest;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
MessagesRequest(MessagesRequest),
|
||||
BedrockConverse(ConverseRequest),
|
||||
BedrockConverseStream(ConverseStreamRequest),
|
||||
ResponsesAPIRequest(ResponsesAPIRequest),
|
||||
//add more request types here
|
||||
}
|
||||
pub trait ProviderRequest: Send + Sync {
|
||||
|
|
@ -33,6 +35,9 @@ pub trait ProviderRequest: Send + Sync {
|
|||
/// Extract the user message for tracing/logging purposes
|
||||
fn get_recent_user_message(&self) -> Option<String>;
|
||||
|
||||
/// Get tool names if tools are defined in the request
|
||||
fn get_tool_names(&self) -> Option<Vec<String>>;
|
||||
|
||||
/// Convert the request to bytes for transmission
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
||||
|
||||
|
|
@ -40,6 +45,30 @@ pub trait ProviderRequest: Send + Sync {
|
|||
|
||||
/// Remove a metadata key from the request and return true if the key was present
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool;
|
||||
|
||||
fn get_temperature(&self) -> Option<f32>;
|
||||
|
||||
/// Get message history as OpenAI Message format
|
||||
/// This is useful for processing chat history across different provider formats
|
||||
fn get_messages(&self) -> Vec<crate::apis::openai::Message>;
|
||||
|
||||
/// Set message history from OpenAI Message format
|
||||
/// This converts OpenAI messages to the appropriate format for each provider type
|
||||
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]);
|
||||
}
|
||||
|
||||
impl ProviderRequestType {
|
||||
/// Set message history from OpenAI Message format
|
||||
/// This converts OpenAI messages to the appropriate format for each provider type
|
||||
pub fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.set_messages(messages),
|
||||
Self::MessagesRequest(r) => r.set_messages(messages),
|
||||
Self::BedrockConverse(r) => r.set_messages(messages),
|
||||
Self::BedrockConverseStream(r) => r.set_messages(messages),
|
||||
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderRequest for ProviderRequestType {
|
||||
|
|
@ -49,6 +78,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.model(),
|
||||
Self::BedrockConverse(r) => r.model(),
|
||||
Self::BedrockConverseStream(r) => r.model(),
|
||||
Self::ResponsesAPIRequest(r) => r.model(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -58,6 +88,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.set_model(model),
|
||||
Self::BedrockConverse(r) => r.set_model(model),
|
||||
Self::BedrockConverseStream(r) => r.set_model(model),
|
||||
Self::ResponsesAPIRequest(r) => r.set_model(model),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -67,6 +98,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.is_streaming(),
|
||||
Self::BedrockConverse(_) => false,
|
||||
Self::BedrockConverseStream(_) => true,
|
||||
Self::ResponsesAPIRequest(r) => r.is_streaming(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -76,6 +108,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.extract_messages_text(),
|
||||
Self::BedrockConverse(r) => r.extract_messages_text(),
|
||||
Self::BedrockConverseStream(r) => r.extract_messages_text(),
|
||||
Self::ResponsesAPIRequest(r) => r.extract_messages_text(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -85,6 +118,17 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.get_recent_user_message(),
|
||||
Self::BedrockConverse(r) => r.get_recent_user_message(),
|
||||
Self::BedrockConverseStream(r) => r.get_recent_user_message(),
|
||||
Self::ResponsesAPIRequest(r) => r.get_recent_user_message(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tool_names(&self) -> Option<Vec<String>> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.get_tool_names(),
|
||||
Self::MessagesRequest(r) => r.get_tool_names(),
|
||||
Self::BedrockConverse(r) => r.get_tool_names(),
|
||||
Self::BedrockConverseStream(r) => r.get_tool_names(),
|
||||
Self::ResponsesAPIRequest(r) => r.get_tool_names(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -94,6 +138,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.to_bytes(),
|
||||
Self::BedrockConverse(r) => r.to_bytes(),
|
||||
Self::BedrockConverseStream(r) => r.to_bytes(),
|
||||
Self::ResponsesAPIRequest(r) => r.to_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -103,6 +148,7 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.metadata(),
|
||||
Self::BedrockConverse(r) => r.metadata(),
|
||||
Self::BedrockConverseStream(r) => r.metadata(),
|
||||
Self::ResponsesAPIRequest(r) => r.metadata(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -112,18 +158,49 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.remove_metadata_key(key),
|
||||
Self::BedrockConverse(r) => r.remove_metadata_key(key),
|
||||
Self::BedrockConverseStream(r) => r.remove_metadata_key(key),
|
||||
Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_temperature(&self) -> Option<f32> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.get_temperature(),
|
||||
Self::MessagesRequest(r) => r.get_temperature(),
|
||||
Self::BedrockConverse(r) => r.get_temperature(),
|
||||
Self::BedrockConverseStream(r) => r.get_temperature(),
|
||||
Self::ResponsesAPIRequest(r) => r.get_temperature(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.get_messages(),
|
||||
Self::MessagesRequest(r) => r.get_messages(),
|
||||
Self::BedrockConverse(r) => r.get_messages(),
|
||||
Self::BedrockConverseStream(r) => r.get_messages(),
|
||||
Self::ResponsesAPIRequest(r) => r.get_messages(),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.set_messages(messages),
|
||||
Self::MessagesRequest(r) => r.set_messages(messages),
|
||||
Self::BedrockConverse(r) => r.set_messages(messages),
|
||||
Self::BedrockConverseStream(r) => r.set_messages(messages),
|
||||
Self::ResponsesAPIRequest(r) => r.set_messages(messages),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the client API from a byte slice.
|
||||
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
||||
impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, client_api): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
|
||||
fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClient)) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
match client_api {
|
||||
SupportedAPIs::OpenAIChatCompletions(_) => {
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
|
||||
let chat_completion_request: ChatCompletionsRequest =
|
||||
ChatCompletionsRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
|
@ -131,11 +208,20 @@ impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
|
|||
chat_completion_request,
|
||||
))
|
||||
}
|
||||
SupportedAPIs::AnthropicMessagesAPI(_) => {
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_) => {
|
||||
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_request))
|
||||
}
|
||||
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
|
||||
let responses_apirequest: ResponsesAPIRequest =
|
||||
ResponsesAPIRequest::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderRequestType::ResponsesAPIRequest(
|
||||
responses_apirequest,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -148,17 +234,13 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
(client_request, upstream_api): (ProviderRequestType, &SupportedUpstreamAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
match (client_request, upstream_api) {
|
||||
// Same API - no conversion needed, just clone the reference
|
||||
// ============================================================================
|
||||
// ChatCompletionsRequest conversions
|
||||
// ============================================================================
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)),
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
|
||||
|
||||
// Cross-API conversion - cloning is necessary for transformation
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
|
|
@ -173,7 +255,45 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock Stream request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(_),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
Err(ProviderRequestError {
|
||||
message: "Conversion from ChatCompletionsRequest to ResponsesAPIRequest is not supported. ResponsesAPI can only be used as a client API, not as an upstream API.".to_string(),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MessagesRequest conversions
|
||||
// ============================================================================
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => Ok(ProviderRequestType::MessagesRequest(messages_req)),
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
|
|
@ -189,31 +309,6 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
|
||||
// Cross-API conversions: OpenAI/Anthropic to Amazon Bedrock
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
(
|
||||
ProviderRequestType::ChatCompletionsRequest(chat_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req)
|
||||
.map_err(|e| ProviderRequestError {
|
||||
message: format!("Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}", e),
|
||||
source: Some(Box::new(e))
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(messages_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
|
|
@ -235,7 +330,97 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
let bedrock_req = ConverseStreamRequest::try_from(messages_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock request: {}",
|
||||
"Failed to convert MessagesRequest to Amazon Bedrock Stream request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||
}
|
||||
(
|
||||
ProviderRequestType::MessagesRequest(_),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
Err(ProviderRequestError {
|
||||
message: "Conversion from MessagesRequest to ResponsesAPIRequest is not supported. ResponsesAPI can only be used as a client API, not as an upstream API.".to_string(),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ResponsesAPIRequest conversions (only converts TO other formats)
|
||||
// ============================================================================
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => Ok(ProviderRequestType::ResponsesAPIRequest(responses_req)),
|
||||
|
||||
// ResponsesAPI -> ChatCompletions (direct conversion)
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => {
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(chat_req))
|
||||
}
|
||||
|
||||
// ResponsesAPI -> Anthropic Messages (via ChatCompletions)
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => {
|
||||
// Chain: ResponsesAPI -> ChatCompletions -> MessagesRequest
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
|
||||
let messages_req = MessagesRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to MessagesRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::MessagesRequest(messages_req))
|
||||
}
|
||||
|
||||
// ResponsesAPI -> Bedrock Converse (via ChatCompletions)
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
) => {
|
||||
// Chain: ResponsesAPI -> ChatCompletions -> ConverseRequest
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
|
||||
let bedrock_req = ConverseRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to Amazon Bedrock request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
|
|
@ -244,13 +429,50 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT
|
|||
Ok(ProviderRequestType::BedrockConverse(bedrock_req))
|
||||
}
|
||||
|
||||
// Amazon Bedrock to other APIs conversions
|
||||
// ResponsesAPI -> Bedrock Converse Stream (via ChatCompletions)
|
||||
(
|
||||
ProviderRequestType::ResponsesAPIRequest(responses_req),
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
) => {
|
||||
// Chain: ResponsesAPI -> ChatCompletions -> ConverseStreamRequest
|
||||
let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
|
||||
let bedrock_req = ConverseStreamRequest::try_from(chat_req).map_err(|e| {
|
||||
ProviderRequestError {
|
||||
message: format!(
|
||||
"Failed to convert ChatCompletionsRequest to Amazon Bedrock Stream request: {}",
|
||||
e
|
||||
),
|
||||
source: Some(Box::new(e)),
|
||||
}
|
||||
})?;
|
||||
Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Amazon Bedrock conversions (not supported as client API)
|
||||
// ============================================================================
|
||||
|
||||
(ProviderRequestType::BedrockConverse(_), _) => {
|
||||
todo!("Amazon Bedrock to ChatCompletionsRequest conversion not implemented yet")
|
||||
Err(ProviderRequestError {
|
||||
message: "Amazon Bedrock Converse is not supported as a client API. Only OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses APIs are supported as client APIs.".to_string(),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
|
||||
(ProviderRequestType::BedrockConverseStream(_), _) => {
|
||||
todo!("Amazon Bedrock Stream to ChatCompletionsRequest conversion not implemented yet")
|
||||
Err(ProviderRequestError {
|
||||
message: "Amazon Bedrock Converse Stream is not supported as a client API. Only OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses APIs are supported as client APIs.".to_string(),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -284,7 +506,7 @@ mod tests {
|
|||
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
|
||||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::openai::OpenAIApi::ChatCompletions;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use serde_json::json;
|
||||
|
||||
|
|
@ -298,7 +520,7 @@ mod tests {
|
|||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let api = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let api = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &api));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
|
|
@ -321,7 +543,7 @@ mod tests {
|
|||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let endpoint = SupportedAPIs::AnthropicMessagesAPI(Messages);
|
||||
let endpoint = SupportedAPIsFromClient::AnthropicMessagesAPI(Messages);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
|
|
@ -343,7 +565,7 @@ mod tests {
|
|||
]
|
||||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
|
|
@ -366,7 +588,7 @@ mod tests {
|
|||
});
|
||||
let bytes = serde_json::to_vec(&req).unwrap();
|
||||
// Intentionally use OpenAI endpoint for Anthropic payload
|
||||
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
|
||||
let endpoint = SupportedAPIsFromClient::OpenAIChatCompletions(ChatCompletions);
|
||||
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
|
||||
// Should parse as ChatCompletionsRequest, not error
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -486,4 +708,399 @@ mod tests {
|
|||
let roundtrip_max_tokens = openai_req2.max_completion_tokens.or(openai_req2.max_tokens);
|
||||
assert_eq!(original_max_tokens, roundtrip_max_tokens);
|
||||
}
|
||||
|
||||
/// A minimal Responses API JSON payload, parsed for the Responses client API,
/// must come back as the `ResponsesAPIRequest` variant with its model intact.
#[test]
fn test_responses_api_request_from_bytes() {
    use crate::apis::openai::OpenAIApi::Responses;

    // Serialize the smallest payload the test needs: a model and a text input.
    let payload = serde_json::to_vec(&json!({
        "model": "gpt-4o",
        "input": "Hello, how are you?"
    }))
    .unwrap();
    let client_api = SupportedAPIsFromClient::OpenAIResponsesAPI(Responses);

    let parsed = ProviderRequestType::try_from((payload.as_slice(), &client_api));
    assert!(parsed.is_ok());

    if let ProviderRequestType::ResponsesAPIRequest(parsed_req) = parsed.unwrap() {
        assert_eq!(parsed_req.model, "gpt-4o");
    } else {
        panic!("Expected ResponsesAPIRequest variant");
    }
}
|
||||
|
||||
/// Converting a text-only Responses API request for an OpenAI ChatCompletions
/// upstream must yield a `ChatCompletionsRequest` that carries over the model,
/// sampling parameters and token limit, with exactly one message.
#[test]
fn test_responses_api_to_chat_completions_conversion() {
    use crate::apis::openai::OpenAIApi::ChatCompletions;
    use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};

    let source_req = ResponsesAPIRequest {
        // Fields whose values are checked after the conversion.
        model: "gpt-4o".to_string(),
        input: InputParam::Text("Hello, world!".to_string()),
        temperature: Some(0.7),
        top_p: Some(0.9),
        max_output_tokens: Some(100),
        stream: Some(false),
        // Everything else stays unset.
        audio: None,
        background: None,
        conversation: None,
        include: None,
        instructions: None,
        max_tool_calls: None,
        metadata: None,
        modalities: None,
        parallel_tool_calls: None,
        previous_response_id: None,
        reasoning_effort: None,
        service_tier: None,
        store: None,
        stream_options: None,
        text: None,
        tool_choice: None,
        tools: None,
        top_logprobs: None,
        truncation: None,
        user: None,
    };

    let target_api = SupportedUpstreamAPIs::OpenAIChatCompletions(ChatCompletions);
    let converted = ProviderRequestType::try_from((
        ProviderRequestType::ResponsesAPIRequest(source_req),
        &target_api,
    ));
    assert!(converted.is_ok());

    let chat_req = match converted.unwrap() {
        ProviderRequestType::ChatCompletionsRequest(inner) => inner,
        _ => panic!("Expected ChatCompletionsRequest variant"),
    };
    assert_eq!(chat_req.model, "gpt-4o");
    assert_eq!(chat_req.temperature, Some(0.7));
    assert_eq!(chat_req.top_p, Some(0.9));
    assert_eq!(chat_req.max_completion_tokens, Some(100));
    assert_eq!(chat_req.messages.len(), 1);
}
|
||||
|
||||
/// Converting a Responses API request for an Anthropic Messages upstream must
/// yield a `MessagesRequest` with the model, temperature and token limit
/// carried over and a single (user) message.
#[test]
fn test_responses_api_to_anthropic_messages_conversion() {
    use crate::apis::anthropic::AnthropicApi::Messages;
    use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};

    let source_req = ResponsesAPIRequest {
        // Fields that matter for this conversion.
        model: "gpt-4o".to_string(),
        input: InputParam::Text("Hello, Claude!".to_string()),
        instructions: Some("You are a helpful assistant".to_string()),
        temperature: Some(0.8),
        max_output_tokens: Some(150),
        stream: Some(false),
        // Everything else stays unset.
        audio: None,
        background: None,
        conversation: None,
        include: None,
        max_tool_calls: None,
        metadata: None,
        modalities: None,
        parallel_tool_calls: None,
        previous_response_id: None,
        reasoning_effort: None,
        service_tier: None,
        store: None,
        stream_options: None,
        text: None,
        tool_choice: None,
        tools: None,
        top_logprobs: None,
        top_p: None,
        truncation: None,
        user: None,
    };

    let target_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(Messages);
    let converted = ProviderRequestType::try_from((
        ProviderRequestType::ResponsesAPIRequest(source_req),
        &target_api,
    ));
    assert!(converted.is_ok());

    let messages_req = match converted.unwrap() {
        ProviderRequestType::MessagesRequest(inner) => inner,
        _ => panic!("Expected MessagesRequest variant"),
    };
    assert_eq!(messages_req.model, "gpt-4o");
    assert_eq!(messages_req.temperature, Some(0.8));
    assert_eq!(messages_req.max_tokens, 150);
    // The intended chain is ResponsesAPI -> ChatCompletions (system message)
    // -> Anthropic (system prompt). Only the message count is verified here;
    // whether the system prompt was actually populated is not asserted.
    assert_eq!(messages_req.messages.len(), 1);
}
|
||||
|
||||
/// Converting a Responses API request for an Amazon Bedrock Converse upstream
/// must yield a `BedrockConverse` request with the model id carried over and
/// a populated message list.
#[test]
fn test_responses_api_to_bedrock_conversion() {
    use crate::apis::amazon_bedrock::AmazonBedrockApi::Converse;
    use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};

    let source_req = ResponsesAPIRequest {
        // Fields that matter for this conversion.
        model: "gpt-4o".to_string(),
        input: InputParam::Text("Hello, Bedrock!".to_string()),
        temperature: Some(0.5),
        max_output_tokens: Some(200),
        stream: Some(false),
        // Everything else stays unset.
        audio: None,
        background: None,
        conversation: None,
        include: None,
        instructions: None,
        max_tool_calls: None,
        metadata: None,
        modalities: None,
        parallel_tool_calls: None,
        previous_response_id: None,
        reasoning_effort: None,
        service_tier: None,
        store: None,
        stream_options: None,
        text: None,
        tool_choice: None,
        tools: None,
        top_logprobs: None,
        top_p: None,
        truncation: None,
        user: None,
    };

    let target_api = SupportedUpstreamAPIs::AmazonBedrockConverse(Converse);
    let converted = ProviderRequestType::try_from((
        ProviderRequestType::ResponsesAPIRequest(source_req),
        &target_api,
    ));
    assert!(converted.is_ok());

    let bedrock_req = match converted.unwrap() {
        ProviderRequestType::BedrockConverse(inner) => inner,
        _ => panic!("Expected BedrockConverse variant"),
    };
    assert_eq!(bedrock_req.model_id, "gpt-4o");
    // The request reaches Bedrock through the ChatCompletions conversion, so
    // the message list must have been filled in.
    assert!(bedrock_req.messages.is_some());
}
|
||||
|
||||
/// Targeting the Responses API as an upstream from a ChatCompletions request
/// must fail with an error explaining that Responses is client-only.
#[test]
fn test_chat_completions_to_responses_api_not_supported() {
    use crate::apis::openai::OpenAIApi::Responses;
    use crate::apis::openai::{Message, MessageContent, Role};

    let user_message = Message {
        role: Role::User,
        content: MessageContent::Text("Hello!".to_string()),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    };
    let source_req = ChatCompletionsRequest {
        model: "gpt-4".to_string(),
        messages: vec![user_message],
        ..Default::default()
    };

    let target_api = SupportedUpstreamAPIs::OpenAIResponsesAPI(Responses);
    let outcome = ProviderRequestType::try_from((
        ProviderRequestType::ChatCompletionsRequest(source_req),
        &target_api,
    ));

    assert!(outcome.is_err());
    let failure = outcome.unwrap_err();
    assert!(failure
        .message
        .contains("ResponsesAPI can only be used as a client API"));
}
|
||||
|
||||
/// Targeting the Responses API as an upstream from an Anthropic Messages
/// request must fail with an error explaining that Responses is client-only.
#[test]
fn test_anthropic_messages_to_responses_api_not_supported() {
    use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
    use crate::apis::anthropic::{MessagesMessage, MessagesMessageContent, MessagesRole};
    use crate::apis::openai::OpenAIApi::Responses;

    let user_message = MessagesMessage {
        role: MessagesRole::User,
        content: MessagesMessageContent::Single("Hello!".to_string()),
    };
    let source_req = AnthropicMessagesRequest {
        // Fields that are actually populated.
        model: "claude-3-sonnet".to_string(),
        messages: vec![user_message],
        max_tokens: 100,
        // Everything else stays unset.
        container: None,
        mcp_servers: None,
        metadata: None,
        service_tier: None,
        stop_sequences: None,
        stream: None,
        system: None,
        temperature: None,
        thinking: None,
        tool_choice: None,
        tools: None,
        top_k: None,
        top_p: None,
    };

    let target_api = SupportedUpstreamAPIs::OpenAIResponsesAPI(Responses);
    let outcome = ProviderRequestType::try_from((
        ProviderRequestType::MessagesRequest(source_req),
        &target_api,
    ));

    assert!(outcome.is_err());
    let failure = outcome.unwrap_err();
    assert!(failure
        .message
        .contains("ResponsesAPI can only be used as a client API"));
}
|
||||
|
||||
/// A Bedrock Converse request used as the *source* of a conversion must be
/// rejected with an error that lists the supported client APIs.
#[test]
fn test_bedrock_as_client_api_not_supported() {
    use crate::apis::openai::OpenAIApi::ChatCompletions;

    // A default-constructed Bedrock request is enough: the conversion is
    // expected to be refused regardless of the request contents.
    let source_req = ConverseRequest::default();
    let target_api = SupportedUpstreamAPIs::OpenAIChatCompletions(ChatCompletions);

    let outcome = ProviderRequestType::try_from((
        ProviderRequestType::BedrockConverse(source_req),
        &target_api,
    ));
    assert!(outcome.is_err());

    let failure = outcome.unwrap_err();
    let expected_fragments = [
        "not supported as a client API",
        "OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses",
    ];
    for fragment in expected_fragments.iter() {
        assert!(failure.message.contains(*fragment));
    }
}
|
||||
|
||||
/// `get_messages` on a ChatCompletions request returns its messages in order:
/// the system message first, then the user message.
#[test]
fn test_get_message_history_chat_completions() {
    use crate::apis::openai::{Message, MessageContent, Role};

    // Small builder for a plain text message with no tool metadata.
    let text_message = |role: Role, body: &str| Message {
        role,
        content: MessageContent::Text(body.to_string()),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    };

    let source_req = ChatCompletionsRequest {
        model: "gpt-4".to_string(),
        messages: vec![
            text_message(Role::System, "You are helpful"),
            text_message(Role::User, "Hello!"),
        ],
        ..Default::default()
    };

    let wrapped = ProviderRequestType::ChatCompletionsRequest(source_req);
    let history = wrapped.get_messages();

    assert_eq!(history.len(), 2);
    assert_eq!(history[0].role, Role::System);
    assert_eq!(history[1].role, Role::User);
}
|
||||
|
||||
/// `get_messages` on an Anthropic Messages request returns the system prompt
/// as a leading system message, followed by the user message.
#[test]
fn test_get_message_history_anthropic_messages() {
    use crate::apis::anthropic::{
        MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
        MessagesSystemPrompt,
    };
    use crate::apis::openai::Role;

    let user_message = MessagesMessage {
        role: MessagesRole::User,
        content: MessagesMessageContent::Single("Hello!".to_string()),
    };
    let source_req = MessagesRequest {
        // Fields that are actually populated.
        model: "claude-3-sonnet".to_string(),
        messages: vec![user_message],
        system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())),
        max_tokens: 100,
        // Everything else stays unset.
        container: None,
        mcp_servers: None,
        metadata: None,
        service_tier: None,
        stop_sequences: None,
        stream: None,
        temperature: None,
        thinking: None,
        tool_choice: None,
        tools: None,
        top_k: None,
        top_p: None,
    };

    let wrapped = ProviderRequestType::MessagesRequest(source_req);
    let history = wrapped.get_messages();

    // Expect the system prompt plus the single user message.
    assert_eq!(history.len(), 2);
    assert_eq!(history[0].role, Role::System);
    assert_eq!(history[1].role, Role::User);
}
|
||||
|
||||
/// `get_messages` on a Responses API request returns the instructions as a
/// leading system message, followed by the text input as a user message.
#[test]
fn test_get_message_history_responses_api() {
    use crate::apis::openai::Role;
    use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};

    let source_req = ResponsesAPIRequest {
        // Fields that are actually populated.
        model: "gpt-4o".to_string(),
        input: InputParam::Text("Hello, world!".to_string()),
        instructions: Some("Be helpful".to_string()),
        // Everything else stays unset.
        audio: None,
        background: None,
        conversation: None,
        include: None,
        max_output_tokens: None,
        max_tool_calls: None,
        metadata: None,
        modalities: None,
        parallel_tool_calls: None,
        previous_response_id: None,
        reasoning_effort: None,
        service_tier: None,
        store: None,
        stream: None,
        stream_options: None,
        temperature: None,
        text: None,
        tool_choice: None,
        tools: None,
        top_logprobs: None,
        top_p: None,
        truncation: None,
        user: None,
    };

    let wrapped = ProviderRequestType::ResponsesAPIRequest(source_req);
    let history = wrapped.get_messages();

    // Expect the instructions (system) plus the input (user).
    assert_eq!(history.len(), 2);
    assert_eq!(history[0].role, Role::System);
    assert_eq!(history[1].role, Role::User);
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
1348
crates/hermesllm/src/providers/streaming_response.rs
Normal file
1348
crates/hermesllm/src/providers/streaming_response.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -11,11 +11,13 @@
|
|||
pub mod lib;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod response_streaming;
|
||||
|
||||
// Re-export commonly used items for convenience
|
||||
pub use lib::*;
|
||||
pub use request::*;
|
||||
pub use response::*;
|
||||
pub use response_streaming::*;
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@ use crate::apis::anthropic::{
|
|||
use crate::apis::openai::{
|
||||
ChatCompletionsRequest, Message, MessageContent, Role, Tool, ToolChoice, ToolChoiceType,
|
||||
};
|
||||
|
||||
use crate::apis::openai_responses::{
|
||||
ResponsesAPIRequest, InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
use crate::transforms::lib::*;
|
||||
|
|
@ -244,6 +248,212 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: ResponsesAPIRequest) -> Result<Self, Self::Error> {
|
||||
|
||||
// Convert input to messages
|
||||
let messages = match req.input {
|
||||
InputParam::Text(text) => {
|
||||
// Simple text input becomes a user message
|
||||
vec![Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text(text),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}]
|
||||
}
|
||||
InputParam::Items(items) => {
|
||||
// Convert input items to messages
|
||||
let mut converted_messages = Vec::new();
|
||||
|
||||
// Add instructions as system message if present
|
||||
if let Some(instructions) = &req.instructions {
|
||||
converted_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(instructions.clone()),
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Convert each input item
|
||||
for item in items {
|
||||
match item {
|
||||
InputItem::Message(input_msg) => {
|
||||
let role = match input_msg.role {
|
||||
MessageRole::User => Role::User,
|
||||
MessageRole::Assistant => Role::Assistant,
|
||||
MessageRole::System => Role::System,
|
||||
MessageRole::Developer => Role::System, // Map developer to system
|
||||
};
|
||||
|
||||
// Convert content based on MessageContent type
|
||||
let content = match &input_msg.content {
|
||||
crate::apis::openai_responses::MessageContent::Text(text) => {
|
||||
// Simple text content
|
||||
MessageContent::Text(text.clone())
|
||||
}
|
||||
crate::apis::openai_responses::MessageContent::Items(content_items) => {
|
||||
// Check if it's a single text item (can use simple text format)
|
||||
if content_items.len() == 1 {
|
||||
if let InputContent::InputText { text } = &content_items[0] {
|
||||
MessageContent::Text(text.clone())
|
||||
} else {
|
||||
// Single non-text item - use parts format
|
||||
MessageContent::Parts(
|
||||
content_items.iter()
|
||||
.filter_map(|c| match c {
|
||||
InputContent::InputText { text } => {
|
||||
Some(crate::apis::openai::ContentPart::Text { text: text.clone() })
|
||||
}
|
||||
InputContent::InputImage { image_url, .. } => {
|
||||
Some(crate::apis::openai::ContentPart::ImageUrl {
|
||||
image_url: crate::apis::openai::ImageUrl {
|
||||
url: image_url.clone(),
|
||||
detail: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
InputContent::InputFile { .. } => None, // Skip files for now
|
||||
InputContent::InputAudio { .. } => None, // Skip audio for now
|
||||
})
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Multiple content items - convert to parts
|
||||
MessageContent::Parts(
|
||||
content_items.iter()
|
||||
.filter_map(|c| match c {
|
||||
InputContent::InputText { text } => {
|
||||
Some(crate::apis::openai::ContentPart::Text { text: text.clone() })
|
||||
}
|
||||
InputContent::InputImage { image_url, .. } => {
|
||||
Some(crate::apis::openai::ContentPart::ImageUrl {
|
||||
image_url: crate::apis::openai::ImageUrl {
|
||||
url: image_url.clone(),
|
||||
detail: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
InputContent::InputFile { .. } => None, // Skip files for now
|
||||
InputContent::InputAudio { .. } => None, // Skip audio for now
|
||||
})
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
converted_messages.push(Message {
|
||||
role,
|
||||
content,
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
// Skip non-message items (references, outputs) for now
|
||||
// These would need special handling in chat completions format
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
converted_messages
|
||||
}
|
||||
};
|
||||
|
||||
// Build the ChatCompletionsRequest
|
||||
Ok(ChatCompletionsRequest {
|
||||
model: req.model,
|
||||
messages,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
max_completion_tokens: req.max_output_tokens.map(|t| t as u32),
|
||||
stream: req.stream,
|
||||
metadata: req.metadata,
|
||||
user: req.user,
|
||||
store: req.store,
|
||||
service_tier: req.service_tier,
|
||||
top_logprobs: req.top_logprobs.map(|t| t as u32),
|
||||
modalities: req.modalities.map(|mods| {
|
||||
mods.into_iter().map(|m| {
|
||||
match m {
|
||||
Modality::Text => "text".to_string(),
|
||||
Modality::Audio => "audio".to_string(),
|
||||
}
|
||||
}).collect()
|
||||
}),
|
||||
stream_options: req.stream_options.map(|opts| {
|
||||
crate::apis::openai::StreamOptions {
|
||||
include_usage: opts.include_usage,
|
||||
}
|
||||
}),
|
||||
reasoning_effort: req.reasoning_effort.map(|effort| {
|
||||
match effort {
|
||||
ReasoningEffort::Low => "low".to_string(),
|
||||
ReasoningEffort::Medium => "medium".to_string(),
|
||||
ReasoningEffort::High => "high".to_string(),
|
||||
}
|
||||
}),
|
||||
tools: req.tools.map(|tools| {
|
||||
tools.into_iter().map(|tool| {
|
||||
|
||||
// Only convert Function tools - other types are not supported in ChatCompletions
|
||||
match tool {
|
||||
ResponsesTool::Function { name, description, parameters, strict } => Ok(Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: crate::apis::openai::Function {
|
||||
name,
|
||||
description,
|
||||
parameters: parameters.unwrap_or_else(|| serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
})),
|
||||
strict,
|
||||
}
|
||||
}),
|
||||
ResponsesTool::FileSearch { .. } => Err(TransformError::UnsupportedConversion(
|
||||
"FileSearch tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
ResponsesTool::WebSearchPreview { .. } => Err(TransformError::UnsupportedConversion(
|
||||
"WebSearchPreview tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
ResponsesTool::CodeInterpreter => Err(TransformError::UnsupportedConversion(
|
||||
"CodeInterpreter tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
ResponsesTool::Computer { .. } => Err(TransformError::UnsupportedConversion(
|
||||
"Computer tool is not supported in ChatCompletions API. Only function tools are supported.".to_string()
|
||||
)),
|
||||
}
|
||||
}).collect::<Result<Vec<_>, _>>()
|
||||
}).transpose()?,
|
||||
tool_choice: req.tool_choice.map(|choice| {
|
||||
match choice {
|
||||
ResponsesToolChoice::String(s) => {
|
||||
match s.as_str() {
|
||||
"auto" => ToolChoice::Type(ToolChoiceType::Auto),
|
||||
"required" => ToolChoice::Type(ToolChoiceType::Required),
|
||||
"none" => ToolChoice::Type(ToolChoiceType::None),
|
||||
_ => ToolChoice::Type(ToolChoiceType::Auto), // Default to auto for unknown strings
|
||||
}
|
||||
}
|
||||
ResponsesToolChoice::Named { function, .. } => ToolChoice::Function {
|
||||
choice_type: "function".to_string(),
|
||||
function: crate::apis::openai::FunctionChoice { name: function.name }
|
||||
}
|
||||
}
|
||||
}),
|
||||
parallel_tool_calls: req.parallel_tool_calls,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
|
||||
type Error = TransformError;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
//! Response transformation modules
|
||||
pub mod output_to_input;
|
||||
pub mod to_anthropic;
|
||||
pub mod to_openai;
|
||||
|
|
|
|||
178
crates/hermesllm/src/transforms/response/output_to_input.rs
Normal file
178
crates/hermesllm/src/transforms/response/output_to_input.rs
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
//! Conversions from response outputs to request inputs for conversation continuation
|
||||
//!
|
||||
//! This module provides utilities for converting OutputItem types from API responses
|
||||
//! into InputItem types that can be used in subsequent requests. This is primarily used
|
||||
//! for maintaining conversation history in the v1/responses API.
|
||||
|
||||
use crate::apis::openai_responses::{
|
||||
InputContent, InputItem, InputMessage, MessageContent, MessageRole, OutputContent, OutputItem,
|
||||
};
|
||||
|
||||
/// Converts an OutputItem from a response into an InputItem for the next request
|
||||
/// This is used to build conversation history from previous responses
|
||||
pub fn convert_responses_output_to_input_items(output: &OutputItem) -> Option<InputItem> {
|
||||
match output {
|
||||
// Convert output messages to input messages
|
||||
OutputItem::Message {
|
||||
role, content, ..
|
||||
} => {
|
||||
let input_content: Vec<InputContent> = content
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
OutputContent::OutputText { text, .. } => Some(InputContent::InputText {
|
||||
text: text.clone(),
|
||||
}),
|
||||
OutputContent::OutputAudio {
|
||||
data, ..
|
||||
} => Some(InputContent::InputAudio {
|
||||
data: data.clone(),
|
||||
format: None, // Format not preserved in output
|
||||
}),
|
||||
OutputContent::Refusal { .. } => None, // Skip refusals
|
||||
})
|
||||
.collect();
|
||||
|
||||
if input_content.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Map role string to MessageRole enum
|
||||
let message_role = match role.as_str() {
|
||||
"user" => MessageRole::User,
|
||||
"assistant" => MessageRole::Assistant,
|
||||
"system" => MessageRole::System,
|
||||
"developer" => MessageRole::Developer,
|
||||
_ => MessageRole::Assistant, // Default to assistant
|
||||
};
|
||||
|
||||
Some(InputItem::Message(InputMessage {
|
||||
role: message_role,
|
||||
content: MessageContent::Items(input_content),
|
||||
}))
|
||||
}
|
||||
// For function calls, we'll create an assistant message with the tool call info
|
||||
// This matches how conversation history is typically built
|
||||
OutputItem::FunctionCall {
|
||||
name, arguments, ..
|
||||
} => {
|
||||
let tool_call_text = if let (Some(n), Some(args)) = (name, arguments) {
|
||||
format!("Called function: {} with arguments: {}", n, args)
|
||||
} else {
|
||||
"Called a function".to_string()
|
||||
};
|
||||
|
||||
Some(InputItem::Message(InputMessage {
|
||||
role: MessageRole::Assistant,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: tool_call_text,
|
||||
}]),
|
||||
}))
|
||||
}
|
||||
// Skip other output types (tool outputs, etc.) as they don't convert to input
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a Vec of OutputItems into InputItems for conversation continuation
|
||||
pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec<InputItem> {
|
||||
outputs
|
||||
.iter()
|
||||
.filter_map(convert_responses_output_to_input_items)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::apis::openai_responses::OutputItemStatus;

    /// Builds a completed assistant message holding a single text part.
    fn assistant_text(id: &str, text: &str) -> OutputItem {
        OutputItem::Message {
            id: id.to_string(),
            status: OutputItemStatus::Completed,
            role: "assistant".to_string(),
            content: vec![OutputContent::OutputText {
                text: text.to_string(),
                annotations: vec![],
                logprobs: None,
            }],
        }
    }

    /// Builds a completed function call with a name and raw JSON arguments.
    fn function_call(id: &str, call_id: &str, name: &str, arguments: &str) -> OutputItem {
        OutputItem::FunctionCall {
            id: id.to_string(),
            status: OutputItemStatus::Completed,
            call_id: call_id.to_string(),
            name: Some(name.to_string()),
            arguments: Some(arguments.to_string()),
        }
    }

    /// Asserts the item is an assistant message with itemised content and
    /// returns those content items.
    fn assistant_items(input: &InputItem) -> &[InputContent] {
        match input {
            InputItem::Message(msg) => {
                assert!(matches!(msg.role, MessageRole::Assistant));
                match &msg.content {
                    MessageContent::Items(items) => items.as_slice(),
                    _ => panic!("Expected MessageContent::Items"),
                }
            }
            _ => panic!("Expected Message variant"),
        }
    }

    #[test]
    fn test_output_message_to_input() {
        let output = assistant_text("msg_123", "Hello!");

        let input = convert_responses_output_to_input_items(&output).unwrap();

        let items = assistant_items(&input);
        assert_eq!(items.len(), 1);
        match &items[0] {
            InputContent::InputText { text } => assert_eq!(text, "Hello!"),
            _ => panic!("Expected InputText"),
        }
    }

    #[test]
    fn test_function_call_to_input() {
        let output = function_call("fc_123", "call_123", "get_weather", r#"{"location":"SF"}"#);

        let input = convert_responses_output_to_input_items(&output).unwrap();

        // The call is rendered as text, so the function name must appear in it.
        let items = assistant_items(&input);
        match &items[0] {
            InputContent::InputText { text } => {
                assert!(text.contains("get_weather"));
            }
            _ => panic!("Expected InputText"),
        }
    }

    #[test]
    fn test_outputs_to_inputs() {
        let outputs = vec![
            assistant_text("msg_1", "Hello"),
            function_call("fc_1", "call_1", "test", "{}"),
        ];

        // Both the message and the function call convert, so nothing is dropped.
        let inputs = outputs_to_inputs(&outputs);
        assert_eq!(inputs.len(), 2);
    }
}
|
||||
|
|
@ -1,16 +1,11 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlockDelta, ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
|
||||
};
|
||||
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesResponse,
|
||||
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsResponse, ChatCompletionsStreamResponse, Role, ToolCallDelta,
|
||||
MessagesContentBlock, MessagesResponse,
|
||||
MessagesRole, MessagesStopReason, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
use serde_json::Value;
|
||||
|
||||
// ============================================================================
|
||||
// STANDARD RUST TRAIT IMPLEMENTATIONS - Using Into/TryFrom for convenience
|
||||
|
|
@ -120,289 +115,6 @@ impl TryFrom<ConverseResponse> for MessagesResponse {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: ChatCompletionsStreamResponse) -> Result<Self, Self::Error> {
|
||||
if resp.choices.is_empty() {
|
||||
return Ok(MessagesStreamEvent::Ping);
|
||||
}
|
||||
|
||||
let choice = &resp.choices[0];
|
||||
|
||||
// Handle final chunk with usage
|
||||
let has_usage = resp.usage.is_some();
|
||||
if let Some(usage) = resp.usage {
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
|
||||
return Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: usage.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle role start
|
||||
if let Some(Role::Assistant) = choice.delta.role {
|
||||
return Ok(MessagesStreamEvent::MessageStart {
|
||||
message: MessagesStreamMessage {
|
||||
id: resp.id,
|
||||
obj_type: "message".to_string(),
|
||||
role: MessagesRole::Assistant,
|
||||
content: vec![],
|
||||
model: resp.model,
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Handle content delta
|
||||
if let Some(content) = &choice.delta.content {
|
||||
if !content.is_empty() {
|
||||
return Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: MessagesContentDelta::TextDelta {
|
||||
text: content.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if let Some(tool_calls) = &choice.delta.tool_calls {
|
||||
return convert_tool_call_deltas(tool_calls.clone());
|
||||
}
|
||||
|
||||
// Handle finish reason - generate MessageDelta only (MessageStop comes later)
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
// If we have usage data, it was already handled above
|
||||
// If not, we need to generate MessageDelta with default usage
|
||||
if !has_usage {
|
||||
let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
|
||||
return Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
});
|
||||
}
|
||||
// If usage was already handled above, we don't need to do anything more here
|
||||
// MessageStop will be handled when [DONE] is encountered
|
||||
}
|
||||
|
||||
// Default to ping for unhandled cases
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<String> for MessagesStreamEvent {
|
||||
fn into(self) -> String {
|
||||
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let event_type = match &self {
|
||||
MessagesStreamEvent::MessageStart { .. } => "message_start",
|
||||
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
|
||||
MessagesStreamEvent::MessageDelta { .. } => "message_delta",
|
||||
MessagesStreamEvent::MessageStop => "message_stop",
|
||||
MessagesStreamEvent::Ping => "ping",
|
||||
};
|
||||
|
||||
let event = format!("event: {}\n", event_type);
|
||||
let data = format!("data: {}\n\n", transformed_json);
|
||||
event + &data
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
// MessageStart - convert to Anthropic MessageStart
|
||||
ConverseStreamEvent::MessageStart(start_event) => {
|
||||
let role = match start_event.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => {
|
||||
MessagesRole::Assistant
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::MessageStart {
|
||||
message: MessagesStreamMessage {
|
||||
id: format!(
|
||||
"bedrock-stream-{}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos()
|
||||
),
|
||||
obj_type: "message".to_string(),
|
||||
role,
|
||||
content: vec![],
|
||||
model: "bedrock-model".to_string(),
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ContentBlockStart - convert to Anthropic ContentBlockStart
|
||||
ConverseStreamEvent::ContentBlockStart(start_event) => {
|
||||
// Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas
|
||||
// Anthropic expects the same pattern, so we initialize with an empty input object
|
||||
match start_event.start {
|
||||
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => {
|
||||
Ok(MessagesStreamEvent::ContentBlockStart {
|
||||
index: start_event.content_block_index as u32,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: tool_use.tool_use_id,
|
||||
name: tool_use.name,
|
||||
input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas
|
||||
cache_control: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ContentBlockDelta - convert to Anthropic ContentBlockDelta
|
||||
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
|
||||
let delta = match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text },
|
||||
ContentBlockDelta::ToolUse { tool_use } => {
|
||||
MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: tool_use.input,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
delta,
|
||||
})
|
||||
}
|
||||
|
||||
// ContentBlockStop - convert to Anthropic ContentBlockStop
|
||||
ConverseStreamEvent::ContentBlockStop(stop_event) => {
|
||||
Ok(MessagesStreamEvent::ContentBlockStop {
|
||||
index: stop_event.content_block_index as u32,
|
||||
})
|
||||
}
|
||||
|
||||
// MessageStop - convert to Anthropic MessageDelta with stop reason + MessageStop
|
||||
ConverseStreamEvent::MessageStop(stop_event) => {
|
||||
let anthropic_stop_reason = match stop_event.stop_reason {
|
||||
StopReason::EndTurn => MessagesStopReason::EndTurn,
|
||||
StopReason::ToolUse => MessagesStopReason::ToolUse,
|
||||
StopReason::MaxTokens => MessagesStopReason::MaxTokens,
|
||||
StopReason::StopSequence => MessagesStopReason::EndTurn,
|
||||
StopReason::GuardrailIntervened => MessagesStopReason::Refusal,
|
||||
StopReason::ContentFiltered => MessagesStopReason::Refusal,
|
||||
};
|
||||
|
||||
// Return MessageDelta (MessageStop will be sent separately by the streaming handler)
|
||||
Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Metadata - convert usage information to MessageDelta
|
||||
ConverseStreamEvent::Metadata(metadata_event) => {
|
||||
Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: MessagesStopReason::EndTurn,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: metadata_event.usage.input_tokens,
|
||||
output_tokens: metadata_event.usage.output_tokens,
|
||||
cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens,
|
||||
cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Exception events - convert to Ping (could be enhanced to return error events)
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => {
|
||||
// TODO: Consider adding proper error handling/events
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert tool call deltas to Anthropic stream events
|
||||
fn convert_tool_call_deltas(
|
||||
tool_calls: Vec<ToolCallDelta>,
|
||||
) -> Result<MessagesStreamEvent, TransformError> {
|
||||
for tool_call in tool_calls {
|
||||
if let Some(id) = &tool_call.id {
|
||||
// Tool call start
|
||||
if let Some(function) = &tool_call.function {
|
||||
if let Some(name) = &function.name {
|
||||
return Ok(MessagesStreamEvent::ContentBlockStart {
|
||||
index: tool_call.index,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input: Value::Object(serde_json::Map::new()),
|
||||
cache_control: None,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if let Some(function) = &tool_call.function {
|
||||
if let Some(arguments) = &function.arguments {
|
||||
// Tool arguments delta
|
||||
return Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: tool_call.index,
|
||||
delta: MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: arguments.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to ping if no valid tool call found
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
|
||||
/// Convert Bedrock Message to Anthropic content blocks
|
||||
///
|
||||
|
|
|
|||
|
|
@ -1,15 +1,13 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
ConverseOutput, ConverseResponse, ConverseStreamEvent, StopReason,
|
||||
ConverseOutput, ConverseResponse, StopReason,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesResponse, MessagesStopReason,
|
||||
MessagesStreamEvent, MessagesUsage,
|
||||
MessagesContentBlock, MessagesResponse, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason,
|
||||
FunctionCallDelta, MessageContent, MessageDelta, ResponseMessage, Role, StreamChoice,
|
||||
ToolCallDelta, Usage,
|
||||
ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage,
|
||||
};
|
||||
use crate::apis::openai_responses::ResponsesAPIResponse;
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
|
||||
|
|
@ -30,6 +28,182 @@ impl Into<Usage> for MessagesUsage {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
|
||||
use crate::apis::openai_responses::{
|
||||
IncompleteDetails, IncompleteReason, OutputContent, OutputItem, OutputItemStatus,
|
||||
ResponseStatus, ResponseUsage, ResponsesAPIResponse,
|
||||
};
|
||||
|
||||
// Convert the first choice's message to output items
|
||||
let output = if let Some(choice) = resp.choices.first() {
|
||||
let mut items = Vec::new();
|
||||
|
||||
// Create a message output item from the response message
|
||||
let mut content = Vec::new();
|
||||
|
||||
// Add text content if present
|
||||
if let Some(text) = &choice.message.content {
|
||||
content.push(OutputContent::OutputText {
|
||||
text: text.clone(),
|
||||
annotations: vec![],
|
||||
logprobs: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Add audio content if present (audio is a Value, need to handle it carefully)
|
||||
if let Some(audio) = &choice.message.audio {
|
||||
// Audio is serde_json::Value, try to extract data and transcript
|
||||
if let Some(audio_obj) = audio.as_object() {
|
||||
let data = audio_obj
|
||||
.get("data")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let transcript = audio_obj
|
||||
.get("transcript")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
content.push(OutputContent::OutputAudio { data, transcript });
|
||||
}
|
||||
}
|
||||
|
||||
// Add refusal content if present
|
||||
if let Some(refusal) = &choice.message.refusal {
|
||||
content.push(OutputContent::Refusal {
|
||||
refusal: refusal.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Only add the message item if there's actual content (text, audio, or refusal)
|
||||
// Don't add empty message items when there are only tool calls
|
||||
if !content.is_empty() {
|
||||
// Generate message ID: strip common prefixes to avoid double-prefixing
|
||||
let message_id = if resp.id.starts_with("msg_") {
|
||||
resp.id.clone()
|
||||
} else if resp.id.starts_with("resp_") {
|
||||
format!("msg_{}", &resp.id[5..]) // Strip "resp_" prefix
|
||||
} else if resp.id.starts_with("chatcmpl-") {
|
||||
format!("msg_{}", &resp.id[9..]) // Strip "chatcmpl-" prefix
|
||||
} else {
|
||||
format!("msg_{}", resp.id)
|
||||
};
|
||||
|
||||
items.push(OutputItem::Message {
|
||||
id: message_id,
|
||||
status: OutputItemStatus::Completed,
|
||||
role: match choice.message.role {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Assistant => "assistant".to_string(),
|
||||
Role::System => "system".to_string(),
|
||||
Role::Tool => "tool".to_string(),
|
||||
},
|
||||
content,
|
||||
});
|
||||
}
|
||||
|
||||
// Add tool calls as function call items if present
|
||||
if let Some(tool_calls) = &choice.message.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
items.push(OutputItem::FunctionCall {
|
||||
id: format!("func_{}", tool_call.id),
|
||||
status: OutputItemStatus::Completed,
|
||||
call_id: tool_call.id.clone(),
|
||||
name: Some(tool_call.function.name.clone()),
|
||||
arguments: Some(tool_call.function.arguments.clone()),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
items
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Convert finish_reason to status
|
||||
let status = if let Some(choice) = resp.choices.first() {
|
||||
match choice.finish_reason {
|
||||
Some(FinishReason::Stop) => ResponseStatus::Completed,
|
||||
Some(FinishReason::ToolCalls) => ResponseStatus::Completed,
|
||||
Some(FinishReason::Length) => ResponseStatus::Incomplete,
|
||||
Some(FinishReason::ContentFilter) => ResponseStatus::Failed,
|
||||
_ => ResponseStatus::Completed,
|
||||
}
|
||||
} else {
|
||||
ResponseStatus::Completed
|
||||
};
|
||||
|
||||
// Convert usage
|
||||
let usage = ResponseUsage {
|
||||
input_tokens: resp.usage.prompt_tokens as i32,
|
||||
output_tokens: resp.usage.completion_tokens as i32,
|
||||
total_tokens: resp.usage.total_tokens as i32,
|
||||
input_tokens_details: resp.usage.prompt_tokens_details.map(|details| {
|
||||
crate::apis::openai_responses::TokenDetails {
|
||||
cached_tokens: details.cached_tokens.unwrap_or(0) as i32,
|
||||
}
|
||||
}),
|
||||
output_tokens_details: resp.usage.completion_tokens_details.map(|details| {
|
||||
crate::apis::openai_responses::OutputTokenDetails {
|
||||
reasoning_tokens: details.reasoning_tokens.unwrap_or(0) as i32,
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
// Set incomplete_details if status is incomplete
|
||||
let incomplete_details = if matches!(status, ResponseStatus::Incomplete) {
|
||||
Some(IncompleteDetails {
|
||||
reason: IncompleteReason::MaxOutputTokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ResponsesAPIResponse {
|
||||
// Generate proper resp_ prefixed ID if not already present
|
||||
id: if resp.id.starts_with("resp_") {
|
||||
resp.id
|
||||
} else {
|
||||
format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", ""))
|
||||
},
|
||||
object: "response".to_string(),
|
||||
created_at: resp.created as i64,
|
||||
status,
|
||||
background: Some(false),
|
||||
error: None,
|
||||
incomplete_details,
|
||||
instructions: None,
|
||||
max_output_tokens: None,
|
||||
max_tool_calls: None,
|
||||
model: resp.model,
|
||||
output,
|
||||
usage: Some(usage),
|
||||
parallel_tool_calls: true,
|
||||
conversation: None,
|
||||
previous_response_id: None,
|
||||
tools: vec![],
|
||||
tool_choice: "auto".to_string(),
|
||||
temperature: 1.0,
|
||||
top_p: 1.0,
|
||||
metadata: resp.metadata.unwrap_or_default(),
|
||||
truncation: None,
|
||||
reasoning: Some(crate::apis::openai_responses::Reasoning {
|
||||
effort: None,
|
||||
summary: None,
|
||||
}),
|
||||
store: None,
|
||||
text: None,
|
||||
audio: None,
|
||||
modalities: None,
|
||||
service_tier: resp.service_tier,
|
||||
top_logprobs: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
|
|
@ -83,8 +257,7 @@ impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
|
|||
model: resp.model,
|
||||
choices: vec![choice],
|
||||
usage,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -169,422 +342,11 @@ impl TryFrom<ConverseResponse> for ChatCompletionsResponse {
|
|||
model,
|
||||
choices: vec![choice],
|
||||
usage,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// STREAMING TRANSFORMATIONS
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
|
||||
&message.id,
|
||||
&message.model,
|
||||
MessageDelta {
|
||||
role: Some(Role::Assistant),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
|
||||
convert_content_block_start(content_block)
|
||||
}
|
||||
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
|
||||
|
||||
MessagesStreamEvent::MessageDelta { delta, usage } => {
|
||||
let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
|
||||
let openai_usage: Option<Usage> = Some(usage.into());
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason,
|
||||
openai_usage,
|
||||
))
|
||||
}
|
||||
|
||||
MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(FinishReason::Stop),
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
|
||||
id: "stream".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: "unknown".to_string(),
|
||||
choices: vec![],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
ConverseStreamEvent::MessageStart(start_event) => {
|
||||
let role = match start_event.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => Role::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: Some(role),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockStart(start_event) => {
|
||||
use crate::apis::amazon_bedrock::ContentBlockStart;
|
||||
|
||||
match start_event.start {
|
||||
ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: start_event.content_block_index as u32,
|
||||
id: Some(tool_use.tool_use_id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(tool_use.name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
|
||||
use crate::apis::amazon_bedrock::ContentBlockDelta;
|
||||
|
||||
match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(tool_use.input),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()),
|
||||
|
||||
ConverseStreamEvent::MessageStop(stop_event) => {
|
||||
let finish_reason = match stop_event.stop_reason {
|
||||
StopReason::EndTurn => FinishReason::Stop,
|
||||
StopReason::ToolUse => FinishReason::ToolCalls,
|
||||
StopReason::MaxTokens => FinishReason::Length,
|
||||
StopReason::StopSequence => FinishReason::Stop,
|
||||
StopReason::GuardrailIntervened => FinishReason::ContentFilter,
|
||||
StopReason::ContentFiltered => FinishReason::ContentFilter,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(finish_reason),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
ConverseStreamEvent::Metadata(metadata_event) => {
|
||||
let usage = Usage {
|
||||
prompt_tokens: metadata_event.usage.input_tokens,
|
||||
completion_tokens: metadata_event.usage.output_tokens,
|
||||
total_tokens: metadata_event.usage.total_tokens,
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
Some(usage),
|
||||
))
|
||||
}
|
||||
|
||||
// Error events - convert to empty chunks (errors should be handled elsewhere)
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content block start to OpenAI chunk
|
||||
fn convert_content_block_start(
|
||||
content_block: MessagesContentBlock,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match content_block {
|
||||
MessagesContentBlock::Text { .. } => {
|
||||
// No immediate output for text block start
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
MessagesContentBlock::ToolUse { id, name, .. }
|
||||
| MessagesContentBlock::ServerToolUse { id, name, .. }
|
||||
| MessagesContentBlock::McpToolUse { id, name, .. } => {
|
||||
// Tool use start → OpenAI chunk with tool_calls
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: Some(id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
_ => Err(TransformError::UnsupportedContent(
|
||||
"Unsupported content block type in stream start".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content delta to OpenAI chunk
|
||||
fn convert_content_delta(
|
||||
delta: MessagesContentDelta,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(format!("thinking: {}", thinking)),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(partial_json),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create OpenAI streaming chunk
|
||||
fn create_openai_chunk(
|
||||
id: &str,
|
||||
model: &str,
|
||||
delta: MessageDelta,
|
||||
finish_reason: Option<FinishReason>,
|
||||
usage: Option<Usage>,
|
||||
) -> ChatCompletionsStreamResponse {
|
||||
ChatCompletionsStreamResponse {
|
||||
id: id.to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: model.to_string(),
|
||||
choices: vec![StreamChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
finish_reason,
|
||||
logprobs: None,
|
||||
}],
|
||||
usage,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create empty OpenAI streaming chunk
|
||||
fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
|
||||
create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
/// Convert Anthropic content blocks to OpenAI message content
|
||||
fn convert_anthropic_content_to_openai(
|
||||
content: &[MessagesContentBlock],
|
||||
) -> Result<MessageContent, TransformError> {
|
||||
let mut text_parts = Vec::new();
|
||||
|
||||
for block in content {
|
||||
match block {
|
||||
MessagesContentBlock::Text { text, .. } => {
|
||||
text_parts.push(text.clone());
|
||||
}
|
||||
MessagesContentBlock::Thinking { thinking, .. } => {
|
||||
text_parts.push(format!("thinking: {}", thinking));
|
||||
}
|
||||
_ => {
|
||||
// Skip other content types for basic text conversion
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(MessageContent::Text(text_parts.join("\n")))
|
||||
}
|
||||
|
||||
// Stop Reason Conversions
|
||||
impl Into<FinishReason> for MessagesStopReason {
|
||||
fn into(self) -> FinishReason {
|
||||
match self {
|
||||
MessagesStopReason::EndTurn => FinishReason::Stop,
|
||||
MessagesStopReason::MaxTokens => FinishReason::Length,
|
||||
MessagesStopReason::StopSequence => FinishReason::Stop,
|
||||
MessagesStopReason::ToolUse => FinishReason::ToolCalls,
|
||||
MessagesStopReason::PauseTurn => FinishReason::Stop,
|
||||
MessagesStopReason::Refusal => FinishReason::ContentFilter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Bedrock Message to OpenAI content and tool calls
|
||||
/// This function extracts text content and tool calls from a Bedrock message
|
||||
fn convert_bedrock_message_to_openai(
|
||||
|
|
@ -629,6 +391,31 @@ fn convert_bedrock_message_to_openai(
|
|||
Ok((content, tool_calls))
|
||||
}
|
||||
|
||||
/// Convert Anthropic content blocks to OpenAI message content
|
||||
fn convert_anthropic_content_to_openai(
|
||||
content: &[MessagesContentBlock],
|
||||
) -> Result<MessageContent, TransformError> {
|
||||
let mut text_parts = Vec::new();
|
||||
|
||||
for block in content {
|
||||
match block {
|
||||
MessagesContentBlock::Text { text, .. } => {
|
||||
text_parts.push(text.clone());
|
||||
}
|
||||
MessagesContentBlock::Thinking { thinking, .. } => {
|
||||
text_parts.push(format!("thinking: {}", thinking));
|
||||
}
|
||||
_ => {
|
||||
// Skip other content types for basic text conversion
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(MessageContent::Text(text_parts.join("\n")))
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -1168,4 +955,214 @@ mod tests {
|
|||
assert!(content.contains("Here's the analysis:"));
|
||||
// Note: Image blocks are not converted to text in the current implementation
|
||||
}
|
||||
|
||||
#[test]
fn test_chat_completions_to_responses_api_basic() {
    use crate::apis::openai_responses::{OutputContent, OutputItem, ResponsesAPIResponse};

    // A plain assistant reply: text only, no tool calls.
    let assistant_message = crate::apis::openai::ResponseMessage {
        role: Role::Assistant,
        content: Some("Hello! How can I help you?".to_string()),
        refusal: None,
        annotations: None,
        audio: None,
        function_call: None,
        tool_calls: None,
    };
    let token_usage = Usage {
        prompt_tokens: 10,
        completion_tokens: 20,
        total_tokens: 30,
        prompt_tokens_details: None,
        completion_tokens_details: None,
    };
    let chat_response = ChatCompletionsResponse {
        id: "resp_6de5512800cf4375a329a473a4f02879".to_string(),
        object: Some("chat.completion".to_string()),
        created: 1677652288,
        model: "gpt-4".to_string(),
        choices: vec![Choice {
            index: 0,
            message: assistant_message,
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        usage: token_usage,
        system_fingerprint: None,
        service_tier: Some("default".to_string()),
        metadata: None,
    };

    let converted: ResponsesAPIResponse = chat_response.try_into().unwrap();

    // Identity fields: the id must look like `resp_` followed by 32 characters.
    assert!(converted.id.starts_with("resp_"), "Response ID should start with 'resp_'");
    assert_eq!(converted.id.len(), 37, "Response ID should be resp_ + 32 char UUID");
    assert_eq!(converted.object, "response");
    assert_eq!(converted.model, "gpt-4");

    // Token accounting is renamed from prompt/completion to input/output.
    let usage = converted.usage.unwrap();
    assert_eq!(usage.input_tokens, 10);
    assert_eq!(usage.output_tokens, 20);
    assert_eq!(usage.total_tokens, 30);

    // Exactly one output item: the assistant message with a single text part.
    assert_eq!(converted.output.len(), 1);
    if let OutputItem::Message { role, content, .. } = &converted.output[0] {
        assert_eq!(role, "assistant");
        assert_eq!(content.len(), 1);
        if let OutputContent::OutputText { text, .. } = &content[0] {
            assert_eq!(text, "Hello! How can I help you?");
        } else {
            panic!("Expected OutputText content");
        }
    } else {
        panic!("Expected Message output item");
    }
}
|
||||
|
||||
#[test]
fn test_chat_completions_to_responses_api_with_tool_calls() {
    use crate::apis::openai::{FunctionCall, ToolCall};
    use crate::apis::openai_responses::{OutputItem, ResponsesAPIResponse};

    // Assistant reply that carries both text and one function tool call.
    let weather_call = ToolCall {
        id: "call_abc123".to_string(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: "get_weather".to_string(),
            arguments: r#"{"location":"San Francisco"}"#.to_string(),
        },
    };
    let assistant_message = crate::apis::openai::ResponseMessage {
        role: Role::Assistant,
        content: Some("Let me check the weather.".to_string()),
        refusal: None,
        annotations: None,
        audio: None,
        function_call: None,
        tool_calls: Some(vec![weather_call]),
    };
    let chat_response = ChatCompletionsResponse {
        id: "chatcmpl-456".to_string(),
        object: Some("chat.completion".to_string()),
        created: 1677652300,
        model: "gpt-4".to_string(),
        choices: vec![Choice {
            index: 0,
            message: assistant_message,
            finish_reason: Some(FinishReason::ToolCalls),
            logprobs: None,
        }],
        usage: Usage {
            prompt_tokens: 15,
            completion_tokens: 25,
            total_tokens: 40,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        },
        system_fingerprint: None,
        service_tier: None,
        metadata: None,
    };

    let converted: ResponsesAPIResponse = chat_response.try_into().unwrap();

    // Text and tool call are emitted as two separate output items, in that order.
    assert_eq!(converted.output.len(), 2);

    // First item: the message with its single text part.
    if let OutputItem::Message { content, .. } = &converted.output[0] {
        assert_eq!(content.len(), 1);
    } else {
        panic!("Expected Message output item");
    }

    // Second item: the function call, with id, name and arguments carried over.
    if let OutputItem::FunctionCall {
        call_id,
        name,
        arguments,
        ..
    } = &converted.output[1]
    {
        assert_eq!(call_id, "call_abc123");
        assert_eq!(name.as_ref().unwrap(), "get_weather");
        assert!(arguments.as_ref().unwrap().contains("San Francisco"));
    } else {
        panic!("Expected FunctionCall output item");
    }
}
|
||||
|
||||
#[test]
fn test_chat_completions_to_responses_api_tool_calls_only() {
    use crate::apis::openai::{FunctionCall, ToolCall};
    use crate::apis::openai_responses::{OutputItem, ResponseStatus, ResponsesAPIResponse};

    // Real-world shape: `content` is null and the reply consists solely of a tool call.
    let weather_call = ToolCall {
        id: "call_oJBtqTJmRfBGlFS55QhMfUUV".to_string(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: "get_weather".to_string(),
            arguments: r#"{"location":"San Francisco, CA"}"#.to_string(),
        },
    };
    let assistant_message = crate::apis::openai::ResponseMessage {
        role: Role::Assistant,
        content: None, // No text content, only tool calls
        refusal: None,
        annotations: None,
        audio: None,
        function_call: None,
        tool_calls: Some(vec![weather_call]),
    };
    let chat_response = ChatCompletionsResponse {
        id: "chatcmpl-789".to_string(),
        object: Some("chat.completion".to_string()),
        created: 1764023939,
        model: "gpt-4o-2024-08-06".to_string(),
        choices: vec![Choice {
            index: 0,
            message: assistant_message,
            finish_reason: Some(FinishReason::ToolCalls),
            logprobs: None,
        }],
        usage: Usage {
            prompt_tokens: 84,
            completion_tokens: 17,
            total_tokens: 101,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        },
        system_fingerprint: Some("fp_7eeb46f068".to_string()),
        service_tier: Some("default".to_string()),
        metadata: None,
    };

    let converted: ResponsesAPIResponse = chat_response.try_into().unwrap();

    // No empty message item may be emitted: the function call is the only output.
    assert_eq!(converted.output.len(), 1);

    if let OutputItem::FunctionCall {
        call_id,
        name,
        arguments,
        ..
    } = &converted.output[0]
    {
        assert_eq!(call_id, "call_oJBtqTJmRfBGlFS55QhMfUUV");
        assert_eq!(name.as_ref().unwrap(), "get_weather");
        assert!(arguments.as_ref().unwrap().contains("San Francisco, CA"));
    } else {
        panic!("Expected FunctionCall output item as first item");
    }

    // A `tool_calls` finish reason still yields a Completed response status.
    assert!(matches!(converted.status, ResponseStatus::Completed));
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
pub mod to_anthropic_streaming;
|
||||
pub mod to_openai_streaming;
|
||||
|
|
@ -0,0 +1,281 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlockDelta, ConverseStreamEvent,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta,
|
||||
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::{ ChatCompletionsStreamResponse, ToolCallDelta,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use serde_json::Value;
|
||||
|
||||
impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(resp: ChatCompletionsStreamResponse) -> Result<Self, Self::Error> {
|
||||
if resp.choices.is_empty() {
|
||||
return Ok(MessagesStreamEvent::Ping);
|
||||
}
|
||||
|
||||
let choice = &resp.choices[0];
|
||||
|
||||
// Handle final chunk with usage
|
||||
let has_usage = resp.usage.is_some();
|
||||
if let Some(usage) = resp.usage {
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
|
||||
return Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: usage.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: We do NOT emit MessageStart here anymore!
|
||||
// The AnthropicMessagesStreamBuffer will inject message_start and content_block_start
|
||||
// when it sees the first content_block_delta. This solves the problem where OpenAI
|
||||
// sends both role and content in the same chunk - we can only return one event here,
|
||||
// so we prioritize the content and let the buffer handle lifecycle events.
|
||||
|
||||
// Handle content delta (even if role is present in the same chunk)
|
||||
if let Some(content) = &choice.delta.content {
|
||||
if !content.is_empty() {
|
||||
return Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: 0,
|
||||
delta: MessagesContentDelta::TextDelta {
|
||||
text: content.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if let Some(tool_calls) = &choice.delta.tool_calls {
|
||||
return convert_tool_call_deltas(tool_calls.clone());
|
||||
}
|
||||
|
||||
// Handle finish reason - generate MessageDelta only (MessageStop comes later)
|
||||
if let Some(finish_reason) = &choice.finish_reason {
|
||||
// If we have usage data, it was already handled above
|
||||
// If not, we need to generate MessageDelta with default usage
|
||||
if !has_usage {
|
||||
let anthropic_stop_reason: MessagesStopReason = finish_reason.clone().into();
|
||||
return Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
});
|
||||
}
|
||||
// If usage was already handled above, we don't need to do anything more here
|
||||
// MessageStop will be handled when [DONE] is encountered
|
||||
}
|
||||
|
||||
// Default to ping for unhandled cases
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<String> for MessagesStreamEvent {
|
||||
fn into(self) -> String {
|
||||
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let event_type = match &self {
|
||||
MessagesStreamEvent::MessageStart { .. } => "message_start",
|
||||
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
|
||||
MessagesStreamEvent::MessageDelta { .. } => "message_delta",
|
||||
MessagesStreamEvent::MessageStop => "message_stop",
|
||||
MessagesStreamEvent::Ping => "ping",
|
||||
};
|
||||
|
||||
let event = format!("event: {}\n", event_type);
|
||||
let data = format!("data: {}\n\n", transformed_json);
|
||||
event + &data
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
// MessageStart - convert to Anthropic MessageStart
|
||||
ConverseStreamEvent::MessageStart(start_event) => {
|
||||
let role = match start_event.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => MessagesRole::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => {
|
||||
MessagesRole::Assistant
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::MessageStart {
|
||||
message: MessagesStreamMessage {
|
||||
id: format!(
|
||||
"bedrock-stream-{}",
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos()
|
||||
),
|
||||
obj_type: "message".to_string(),
|
||||
role,
|
||||
content: vec![],
|
||||
model: "bedrock-model".to_string(),
|
||||
stop_reason: None,
|
||||
stop_sequence: None,
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ContentBlockStart - convert to Anthropic ContentBlockStart
|
||||
ConverseStreamEvent::ContentBlockStart(start_event) => {
|
||||
// Note: Bedrock sends tool_use_id and name at start, with input coming in subsequent deltas
|
||||
// Anthropic expects the same pattern, so we initialize with an empty input object
|
||||
match start_event.start {
|
||||
crate::apis::amazon_bedrock::ContentBlockStart::ToolUse { tool_use } => {
|
||||
Ok(MessagesStreamEvent::ContentBlockStart {
|
||||
index: start_event.content_block_index as u32,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: tool_use.tool_use_id,
|
||||
name: tool_use.name,
|
||||
input: Value::Object(serde_json::Map::new()), // Empty - will be filled by deltas
|
||||
cache_control: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ContentBlockDelta - convert to Anthropic ContentBlockDelta
|
||||
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
|
||||
let delta = match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => MessagesContentDelta::TextDelta { text },
|
||||
ContentBlockDelta::ToolUse { tool_use } => {
|
||||
MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: tool_use.input,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
delta,
|
||||
})
|
||||
}
|
||||
|
||||
// ContentBlockStop - convert to Anthropic ContentBlockStop
|
||||
ConverseStreamEvent::ContentBlockStop(stop_event) => {
|
||||
Ok(MessagesStreamEvent::ContentBlockStop {
|
||||
index: stop_event.content_block_index as u32,
|
||||
})
|
||||
}
|
||||
|
||||
// MessageStop - convert to Anthropic MessageDelta with stop reason
|
||||
// Note: Bedrock sends Metadata separately with usage info, creating a second MessageDelta
|
||||
// The client should merge these or use the final one with complete usage
|
||||
ConverseStreamEvent::MessageStop(stop_event) => {
|
||||
let anthropic_stop_reason = match stop_event.stop_reason {
|
||||
crate::apis::amazon_bedrock::StopReason::EndTurn => MessagesStopReason::EndTurn,
|
||||
crate::apis::amazon_bedrock::StopReason::ToolUse => MessagesStopReason::ToolUse,
|
||||
crate::apis::amazon_bedrock::StopReason::MaxTokens => MessagesStopReason::MaxTokens,
|
||||
crate::apis::amazon_bedrock::StopReason::StopSequence => MessagesStopReason::EndTurn,
|
||||
crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => MessagesStopReason::Refusal,
|
||||
crate::apis::amazon_bedrock::StopReason::ContentFiltered => MessagesStopReason::Refusal,
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: anthropic_stop_reason,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Metadata - convert usage information to MessageDelta
|
||||
ConverseStreamEvent::Metadata(metadata_event) => {
|
||||
Ok(MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: MessagesStopReason::EndTurn,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: metadata_event.usage.input_tokens,
|
||||
output_tokens: metadata_event.usage.output_tokens,
|
||||
cache_creation_input_tokens: metadata_event.usage.cache_write_input_tokens,
|
||||
cache_read_input_tokens: metadata_event.usage.cache_read_input_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Exception events - convert to Ping (could be enhanced to return error events)
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => {
|
||||
// TODO: Consider adding proper error handling/events
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert tool call deltas to Anthropic stream events
|
||||
fn convert_tool_call_deltas(
|
||||
tool_calls: Vec<ToolCallDelta>,
|
||||
) -> Result<MessagesStreamEvent, TransformError> {
|
||||
for tool_call in tool_calls {
|
||||
if let Some(id) = &tool_call.id {
|
||||
// Tool call start
|
||||
if let Some(function) = &tool_call.function {
|
||||
if let Some(name) = &function.name {
|
||||
return Ok(MessagesStreamEvent::ContentBlockStart {
|
||||
index: tool_call.index,
|
||||
content_block: MessagesContentBlock::ToolUse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input: Value::Object(serde_json::Map::new()),
|
||||
cache_control: None,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if let Some(function) = &tool_call.function {
|
||||
if let Some(arguments) = &function.arguments {
|
||||
// Tool arguments delta
|
||||
return Ok(MessagesStreamEvent::ContentBlockDelta {
|
||||
index: tool_call.index,
|
||||
delta: MessagesContentDelta::InputJsonDelta {
|
||||
partial_json: arguments.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to ping if no valid tool call found
|
||||
Ok(MessagesStreamEvent::Ping)
|
||||
}
|
||||
|
|
@ -0,0 +1,546 @@
|
|||
use crate::apis::amazon_bedrock::{ ConverseStreamEvent, StopReason};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent};
|
||||
use crate::apis::openai::{ ChatCompletionsStreamResponse,FinishReason,
|
||||
FunctionCallDelta, MessageDelta, Role, StreamChoice, ToolCallDelta, Usage,
|
||||
};
|
||||
use crate::apis::openai_responses::ResponsesAPIStreamEvent;
|
||||
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::*;
|
||||
|
||||
// ============================================================================
|
||||
// PROVIDER STREAMING TRANSFORMATIONS TO OPENAI FORMAT
|
||||
// ============================================================================
|
||||
//
|
||||
// This module handles business logic for converting streaming events from
|
||||
// various providers (Anthropic, Bedrock, etc.) into OpenAI's ChatCompletions format.
|
||||
//
|
||||
// # Architecture Separation
|
||||
//
|
||||
// **Provider Transformations** (this module):
|
||||
// - Business logic for converting between provider formats
|
||||
// - Uses Rust traits (TryFrom, Into) for type-safe conversions
|
||||
// - Stateless event-by-event transformation
|
||||
// - Example: MessagesStreamEvent → ChatCompletionsStreamResponse
|
||||
//
|
||||
// **Wire Format Buffering** (`apis/streaming_shapes/`):
|
||||
// - SSE protocol handling (data:, event: lines)
|
||||
// - State accumulation and lifecycle management
|
||||
// - Buffering for stateful APIs (v1/responses)
|
||||
// - Example: ChatCompletionsToResponsesTransformer
|
||||
//
|
||||
// # Flow
|
||||
//
|
||||
// ```text
|
||||
// Anthropic Event → [Provider Transform] → OpenAI Event → [Wire Buffer] → SSE Wire Format
|
||||
// (business) (this module) (protocol) (streaming_shapes) (network)
|
||||
// ```
|
||||
//
|
||||
// ============================================================================
|
||||
|
||||
impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: MessagesStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
MessagesStreamEvent::MessageStart { message } => Ok(create_openai_chunk(
|
||||
&message.id,
|
||||
&message.model,
|
||||
MessageDelta {
|
||||
role: Some(Role::Assistant),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStart { content_block, index } => {
|
||||
convert_content_block_start(content_block, index)
|
||||
}
|
||||
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, index } => convert_content_delta(delta, index),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
|
||||
|
||||
MessagesStreamEvent::MessageDelta { delta, usage } => {
|
||||
let finish_reason: Option<FinishReason> = Some(delta.stop_reason.into());
|
||||
let openai_usage: Option<Usage> = Some(usage.into());
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason,
|
||||
openai_usage,
|
||||
))
|
||||
}
|
||||
|
||||
MessagesStreamEvent::MessageStop => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(FinishReason::Stop),
|
||||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::Ping => Ok(ChatCompletionsStreamResponse {
|
||||
id: "stream".to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: "unknown".to_string(),
|
||||
choices: vec![],
|
||||
usage: None,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(event: ConverseStreamEvent) -> Result<Self, Self::Error> {
|
||||
match event {
|
||||
ConverseStreamEvent::MessageStart(start_event) => {
|
||||
let role = match start_event.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => Role::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: Some(role),
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockStart(start_event) => {
|
||||
use crate::apis::amazon_bedrock::ContentBlockStart;
|
||||
|
||||
match start_event.start {
|
||||
ContentBlockStart::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: start_event.content_block_index as u32,
|
||||
id: Some(tool_use.tool_use_id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(tool_use.name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockDelta(delta_event) => {
|
||||
use crate::apis::amazon_bedrock::ContentBlockDelta;
|
||||
|
||||
match delta_event.delta {
|
||||
ContentBlockDelta::Text { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
ContentBlockDelta::ToolUse { tool_use } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: delta_event.content_block_index as u32,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(tool_use.input),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamEvent::ContentBlockStop(_) => Ok(create_empty_openai_chunk()),
|
||||
|
||||
ConverseStreamEvent::MessageStop(stop_event) => {
|
||||
let finish_reason = match stop_event.stop_reason {
|
||||
StopReason::EndTurn => FinishReason::Stop,
|
||||
StopReason::ToolUse => FinishReason::ToolCalls,
|
||||
StopReason::MaxTokens => FinishReason::Length,
|
||||
StopReason::StopSequence => FinishReason::Stop,
|
||||
StopReason::GuardrailIntervened => FinishReason::ContentFilter,
|
||||
StopReason::ContentFiltered => FinishReason::ContentFilter,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some(finish_reason),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
ConverseStreamEvent::Metadata(metadata_event) => {
|
||||
let usage = Usage {
|
||||
prompt_tokens: metadata_event.usage.input_tokens,
|
||||
completion_tokens: metadata_event.usage.output_tokens,
|
||||
total_tokens: metadata_event.usage.total_tokens,
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
};
|
||||
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
Some(usage),
|
||||
))
|
||||
}
|
||||
|
||||
// Error events - convert to empty chunks (errors should be handled elsewhere)
|
||||
ConverseStreamEvent::InternalServerException(_)
|
||||
| ConverseStreamEvent::ModelStreamErrorException(_)
|
||||
| ConverseStreamEvent::ServiceUnavailableException(_)
|
||||
| ConverseStreamEvent::ThrottlingException(_)
|
||||
| ConverseStreamEvent::ValidationException(_) => Ok(create_empty_openai_chunk()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content block start to OpenAI chunk
|
||||
fn convert_content_block_start(
|
||||
content_block: MessagesContentBlock,
|
||||
index: u32,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match content_block {
|
||||
MessagesContentBlock::Text { .. } => {
|
||||
// No immediate output for text block start
|
||||
Ok(create_empty_openai_chunk())
|
||||
}
|
||||
MessagesContentBlock::ToolUse { id, name, .. }
|
||||
| MessagesContentBlock::ServerToolUse { id, name, .. }
|
||||
| MessagesContentBlock::McpToolUse { id, name, .. } => {
|
||||
// Tool use start → OpenAI chunk with tool_calls
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index,
|
||||
id: Some(id),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(name),
|
||||
arguments: Some("".to_string()),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
_ => Err(TransformError::UnsupportedContent(
|
||||
"Unsupported content block type in stream start".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert content delta to OpenAI chunk
|
||||
fn convert_content_delta(
|
||||
delta: MessagesContentDelta,
|
||||
index: u32,
|
||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||
match delta {
|
||||
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::ThinkingDelta { thinking } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: Some(format!("thinking: {}", thinking)),
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::InputJsonDelta { partial_json } => Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: Some(FunctionCallDelta {
|
||||
name: None,
|
||||
arguments: Some(partial_json),
|
||||
}),
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
MessagesContentDelta::SignatureDelta { signature: _ } => {
|
||||
// Signature delta is cryptographic verification metadata, not content
|
||||
// Create an empty delta chunk to maintain stream continuity
|
||||
Ok(create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create OpenAI streaming chunk
|
||||
fn create_openai_chunk(
|
||||
id: &str,
|
||||
model: &str,
|
||||
delta: MessageDelta,
|
||||
finish_reason: Option<FinishReason>,
|
||||
usage: Option<Usage>,
|
||||
) -> ChatCompletionsStreamResponse {
|
||||
ChatCompletionsStreamResponse {
|
||||
id: id.to_string(),
|
||||
object: Some("chat.completion.chunk".to_string()),
|
||||
created: current_timestamp(),
|
||||
model: model.to_string(),
|
||||
choices: vec![StreamChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
finish_reason,
|
||||
logprobs: None,
|
||||
}],
|
||||
usage,
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create empty OpenAI streaming chunk
|
||||
fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
|
||||
create_openai_chunk(
|
||||
"stream",
|
||||
"unknown",
|
||||
MessageDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
refusal: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
// Stop Reason Conversions
|
||||
impl Into<FinishReason> for MessagesStopReason {
|
||||
fn into(self) -> FinishReason {
|
||||
match self {
|
||||
MessagesStopReason::EndTurn => FinishReason::Stop,
|
||||
MessagesStopReason::MaxTokens => FinishReason::Length,
|
||||
MessagesStopReason::StopSequence => FinishReason::Stop,
|
||||
MessagesStopReason::ToolUse => FinishReason::ToolCalls,
|
||||
MessagesStopReason::PauseTurn => FinishReason::Stop,
|
||||
MessagesStopReason::Refusal => FinishReason::ContentFilter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionsStreamResponse> for ResponsesAPIStreamEvent {
|
||||
type Error = TransformError;
|
||||
|
||||
fn try_from(chunk: ChatCompletionsStreamResponse) -> Result<Self, TransformError> {
|
||||
// Stateless conversion - just extract the delta information
|
||||
// The buffer will manage state, item IDs, and sequence numbers
|
||||
|
||||
// Extract first choice if available
|
||||
if let Some(choice) = chunk.choices.first() {
|
||||
let delta = &choice.delta;
|
||||
|
||||
// Tool call with function name and/or arguments
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
if let Some(tool_call) = tool_calls.first() {
|
||||
// Extract call_id and name if available (metadata from initial event)
|
||||
let call_id = tool_call.id.clone();
|
||||
let function_name = tool_call.function.as_ref()
|
||||
.and_then(|f| f.name.clone());
|
||||
|
||||
// Check if we have function metadata (name, id)
|
||||
if let Some(function) = &tool_call.function {
|
||||
// If we have arguments delta, return that
|
||||
if let Some(args) = &function.arguments {
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index: choice.index as i32,
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
delta: args.clone(),
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
call_id,
|
||||
name: function_name,
|
||||
});
|
||||
}
|
||||
|
||||
// If we have function name but no arguments yet (initial tool call event)
|
||||
// Return an empty arguments delta so the buffer knows to create the item
|
||||
if function.name.is_some() {
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index: choice.index as i32,
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
delta: "".to_string(), // Empty delta signals this is the initial event
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
call_id,
|
||||
name: function_name,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Text content delta
|
||||
if let Some(content) = &delta.content {
|
||||
if !content.is_empty() {
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
output_index: choice.index as i32,
|
||||
content_index: 0,
|
||||
delta: content.clone(),
|
||||
logprobs: vec![],
|
||||
obfuscation: None,
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle finish_reason - this is a completion signal
|
||||
// Return an empty delta that the buffer can use to detect completion
|
||||
if choice.finish_reason.is_some() {
|
||||
// Return a minimal text delta to signal completion
|
||||
// The buffer will handle the finish_reason and generate response.completed
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
output_index: choice.index as i32,
|
||||
content_index: 0,
|
||||
delta: "".to_string(), // Empty delta signals completion
|
||||
logprobs: vec![],
|
||||
obfuscation: None,
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
});
|
||||
}
|
||||
|
||||
// Empty delta with role only (common at stream start)
|
||||
if delta.role.is_some() {
|
||||
// This is typically the first chunk establishing the assistant role
|
||||
// Return an empty text delta that the buffer can use to initialize state
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
item_id: "".to_string(),
|
||||
output_index: choice.index as i32,
|
||||
content_index: 0,
|
||||
delta: "".to_string(),
|
||||
logprobs: vec![],
|
||||
obfuscation: None,
|
||||
sequence_number: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Empty chunk or no convertible content (e.g., keep-alive chunks with delta: {})
|
||||
// These are valid in OpenAI streaming and should be silently ignored
|
||||
// Return error so the caller can skip these chunks without warnings
|
||||
Err(TransformError::UnsupportedConversion(
|
||||
"Empty or keep-alive chunk with no convertible content".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
@ -2,26 +2,18 @@ use crate::metrics::Metrics;
|
|||
use crate::stream_context::StreamContext;
|
||||
use common::configuration::Configuration;
|
||||
use common::configuration::Overrides;
|
||||
use common::consts::OTEL_COLLECTOR_HTTP;
|
||||
use common::consts::OTEL_POST_PATH;
|
||||
use common::http::CallArgs;
|
||||
use common::http::Client;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use common::ratelimit;
|
||||
use common::stats::Gauge;
|
||||
use common::tracing::TraceData;
|
||||
use log::trace;
|
||||
use log::warn;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CallContext {}
|
||||
|
||||
|
|
@ -31,7 +23,6 @@ pub struct FilterContext {
|
|||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: RefCell<HashMap<u32, CallContext>>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
}
|
||||
|
||||
|
|
@ -41,7 +32,6 @@ impl FilterContext {
|
|||
callouts: RefCell::new(HashMap::new()),
|
||||
metrics: Rc::new(Metrics::new()),
|
||||
llm_providers: None,
|
||||
traces_queue: Arc::new(Mutex::new(VecDeque::new())),
|
||||
overrides: Rc::new(None),
|
||||
}
|
||||
}
|
||||
|
|
@ -95,7 +85,6 @@ impl RootContext for FilterContext {
|
|||
.as_ref()
|
||||
.expect("LLM Providers must exist when Streams are being created"),
|
||||
),
|
||||
Arc::clone(&self.traces_queue),
|
||||
Rc::clone(&self.overrides),
|
||||
)))
|
||||
}
|
||||
|
|
@ -108,34 +97,6 @@ impl RootContext for FilterContext {
|
|||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
let _ = self.traces_queue.try_lock().map(|mut traces_queue| {
|
||||
while let Some(trace) = traces_queue.pop_front() {
|
||||
let trace_str = serde_json::to_string(&trace).unwrap();
|
||||
trace!("trace details: {}", trace_str);
|
||||
let call_args = CallArgs::new(
|
||||
OTEL_COLLECTOR_HTTP,
|
||||
OTEL_POST_PATH,
|
||||
vec![
|
||||
(":method", http::Method::POST.as_str()),
|
||||
(":path", OTEL_POST_PATH),
|
||||
(":authority", OTEL_COLLECTOR_HTTP),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(trace_str.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(60),
|
||||
);
|
||||
if let Err(error) = self.http_call(call_args, CallContext {}) {
|
||||
warn!(
|
||||
"failed to schedule http call to otel-collector: {:?}",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
|
|
|
|||
|
|
@ -4,10 +4,8 @@ use log::{debug, info, warn};
|
|||
use proxy_wasm::hostcalls::get_current_time;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
|
@ -20,13 +18,13 @@ use common::errors::ServerError;
|
|||
use common::llm_providers::LlmProviders;
|
||||
use common::ratelimit::Header;
|
||||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::apis::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
use hermesllm::apis::anthropic::{MessagesContentBlock, MessagesStreamEvent};
|
||||
use hermesllm::apis::sse::{SseEvent, SseStreamIter};
|
||||
use hermesllm::clients::endpoints::SupportedAPIs;
|
||||
use hermesllm::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
use hermesllm::apis::streaming_shapes::sse::{SseEvent, SseStreamBuffer, SseStreamBufferTrait};
|
||||
use hermesllm::apis::streaming_shapes::sse_chunk_processor::SseChunkProcessor;
|
||||
use hermesllm::clients::endpoints::SupportedAPIsFromClient;
|
||||
use hermesllm::providers::response::ProviderResponse;
|
||||
use hermesllm::providers::streaming_response::ProviderStreamResponse;
|
||||
use hermesllm::{
|
||||
DecodedFrame, ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType,
|
||||
ProviderStreamResponseType,
|
||||
|
|
@ -38,7 +36,7 @@ pub struct StreamContext {
|
|||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
/// The API that is requested by the client (before compatibility mapping)
|
||||
client_api: Option<SupportedAPIs>,
|
||||
client_api: Option<SupportedAPIsFromClient>,
|
||||
/// The API that should be used for the upstream provider (after compatibility mapping)
|
||||
resolved_api: Option<SupportedUpstreamAPIs>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
|
|
@ -49,20 +47,20 @@ pub struct StreamContext {
|
|||
ttft_time: Option<u128>,
|
||||
traceparent: Option<String>,
|
||||
request_body_sent_time: Option<u128>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
user_message: Option<String>,
|
||||
upstream_status_code: Option<StatusCode>,
|
||||
binary_frame_decoder: Option<BedrockBinaryFrameDecoder<bytes::BytesMut>>,
|
||||
http_method: Option<String>,
|
||||
http_protocol: Option<String>,
|
||||
sse_buffer: Option<SseStreamBuffer>,
|
||||
sse_chunk_processor: Option<SseChunkProcessor>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
pub fn new(
|
||||
metrics: Rc<Metrics>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
|
|
@ -80,13 +78,14 @@ impl StreamContext {
|
|||
ttft_duration: None,
|
||||
traceparent: None,
|
||||
ttft_time: None,
|
||||
traces_queue,
|
||||
request_body_sent_time: None,
|
||||
user_message: None,
|
||||
upstream_status_code: None,
|
||||
binary_frame_decoder: None,
|
||||
http_method: None,
|
||||
http_protocol: None,
|
||||
sse_buffer: None,
|
||||
sse_chunk_processor: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -140,7 +139,7 @@ impl StreamContext {
|
|||
));
|
||||
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] PROVIDER_SELECTION: Hint='{}' -> Selected='{}'",
|
||||
"[PLANO_REQ_ID:{}] PROVIDER_SELECTION: Hint='{}' -> Selected='{}'",
|
||||
self.request_identifier(),
|
||||
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.unwrap_or("none".to_string()),
|
||||
|
|
@ -172,7 +171,8 @@ impl StreamContext {
|
|||
Some(
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_)
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverse(_)
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
| SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
|
||||
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
)
|
||||
| None => {
|
||||
// OpenAI and default: use Authorization Bearer token
|
||||
|
|
@ -224,7 +224,7 @@ impl StreamContext {
|
|||
let token_count = tokenizer::token_count(model, json_string).unwrap_or(0);
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] TOKEN_COUNT: model='{}' input_tokens={}",
|
||||
"[PLANO_REQ_ID:{}] TOKEN_COUNT: model='{}' input_tokens={}",
|
||||
self.request_identifier(),
|
||||
model,
|
||||
token_count
|
||||
|
|
@ -238,7 +238,7 @@ impl StreamContext {
|
|||
// Check if rate limiting needs to be applied.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] RATELIMIT_CHECK: model='{}' selector='{}:{}'",
|
||||
"[PLANO_REQ_ID:{}] RATELIMIT_CHECK: model='{}' selector='{}:{}'",
|
||||
self.request_identifier(),
|
||||
model,
|
||||
selector.key,
|
||||
|
|
@ -251,7 +251,7 @@ impl StreamContext {
|
|||
)?;
|
||||
} else {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] RATELIMIT_SKIP: model='{}' (no selector)",
|
||||
"[PLANO_REQ_ID:{}] RATELIMIT_SKIP: model='{}' (no selector)",
|
||||
self.request_identifier(),
|
||||
model
|
||||
);
|
||||
|
|
@ -270,7 +270,7 @@ impl StreamContext {
|
|||
Ok(duration) => {
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] TIME_TO_FIRST_TOKEN: {}ms",
|
||||
"[PLANO_REQ_ID:{}] TIME_TO_FIRST_TOKEN: {}ms",
|
||||
self.request_identifier(),
|
||||
duration_ms
|
||||
);
|
||||
|
|
@ -279,7 +279,7 @@ impl StreamContext {
|
|||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] TIME_MEASUREMENT_ERROR: {:?}",
|
||||
"[PLANO_REQ_ID:{}] TIME_MEASUREMENT_ERROR: {:?}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
|
|
@ -295,7 +295,7 @@ impl StreamContext {
|
|||
// Convert the duration to milliseconds
|
||||
let duration_ms = duration.as_millis();
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] REQUEST_COMPLETE: latency={}ms tokens={}",
|
||||
"[PLANO_REQ_ID:{}] REQUEST_COMPLETE: latency={}ms tokens={}",
|
||||
self.request_identifier(),
|
||||
duration_ms,
|
||||
self.response_tokens
|
||||
|
|
@ -311,7 +311,7 @@ impl StreamContext {
|
|||
self.metrics.time_per_output_token.record(tpot);
|
||||
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] TOKEN_THROUGHPUT: time_per_token={}ms tokens_per_second={}",
|
||||
"[PLANO_REQ_ID:{}] TOKEN_THROUGHPUT: time_per_token={}ms tokens_per_second={}",
|
||||
self.request_identifier(),
|
||||
tpot,
|
||||
1000 / tpot
|
||||
|
|
@ -328,75 +328,13 @@ impl StreamContext {
|
|||
self.metrics
|
||||
.output_sequence_length
|
||||
.record(self.response_tokens as u64);
|
||||
|
||||
if let Some(traceparent) = self.traceparent.as_ref() {
|
||||
let current_time_ns = current_time_ns();
|
||||
|
||||
match Traceparent::try_from(traceparent.to_string()) {
|
||||
Err(e) => {
|
||||
warn!("traceparent header is invalid: {}", e);
|
||||
}
|
||||
Ok(traceparent) => {
|
||||
let service_name = match &self.resolved_api {
|
||||
Some(api) => {
|
||||
let api_display = api.to_string();
|
||||
format!("archgw.{}", api_display)
|
||||
}
|
||||
None => "archgw".to_string(),
|
||||
};
|
||||
|
||||
let mut trace_data =
|
||||
common::tracing::TraceData::new_with_service_name(service_name);
|
||||
let mut llm_span = Span::new(
|
||||
self.llm_provider().name.to_string(),
|
||||
Some(traceparent.trace_id),
|
||||
Some(traceparent.parent_id),
|
||||
self.request_body_sent_time.unwrap(),
|
||||
current_time_ns,
|
||||
);
|
||||
llm_span
|
||||
.add_attribute("model".to_string(), self.llm_provider().name.to_string());
|
||||
|
||||
if let Some(user_message) = &self.user_message {
|
||||
llm_span.add_attribute("message".to_string(), user_message.clone());
|
||||
}
|
||||
|
||||
// Add HTTP attributes
|
||||
if let Some(method) = &self.http_method {
|
||||
llm_span.add_attribute("http.method".to_string(), method.clone());
|
||||
}
|
||||
if let Some(protocol) = &self.http_protocol {
|
||||
llm_span.add_attribute("http.protocol".to_string(), protocol.clone());
|
||||
}
|
||||
if let Some(status_code) = &self.upstream_status_code {
|
||||
llm_span.add_attribute(
|
||||
"http.status_code".to_string(),
|
||||
status_code.as_u16().to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
// Add request ID attribute
|
||||
llm_span
|
||||
.add_attribute("http.request_id".to_string(), self.request_identifier());
|
||||
|
||||
if self.ttft_time.is_some() {
|
||||
llm_span.add_event(Event::new(
|
||||
"time_to_first_token".to_string(),
|
||||
self.ttft_time.unwrap(),
|
||||
));
|
||||
}
|
||||
trace_data.add_span(llm_span);
|
||||
self.traces_queue.lock().unwrap().push_back(trace_data);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn read_raw_response_body(&mut self, body_size: usize) -> Result<Vec<u8>, Action> {
|
||||
if self.streaming_response {
|
||||
let chunk_size = body_size;
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_CHUNK: streaming=true chunk_size={}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_CHUNK: streaming=true chunk_size={}",
|
||||
self.request_identifier(),
|
||||
chunk_size
|
||||
);
|
||||
|
|
@ -404,7 +342,7 @@ impl StreamContext {
|
|||
Some(chunk) => chunk,
|
||||
None => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_ERROR: empty chunk, size={}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_ERROR: empty chunk, size={}",
|
||||
self.request_identifier(),
|
||||
chunk_size
|
||||
);
|
||||
|
|
@ -414,7 +352,7 @@ impl StreamContext {
|
|||
|
||||
if streaming_chunk.len() != chunk_size {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_MISMATCH: expected={} actual={}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_MISMATCH: expected={} actual={}",
|
||||
self.request_identifier(),
|
||||
chunk_size,
|
||||
streaming_chunk.len()
|
||||
|
|
@ -426,7 +364,7 @@ impl StreamContext {
|
|||
return Err(Action::Continue);
|
||||
}
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_COMPLETE: streaming=false body_size={}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_COMPLETE: streaming=false body_size={}",
|
||||
self.request_identifier(),
|
||||
body_size
|
||||
);
|
||||
|
|
@ -446,7 +384,7 @@ impl StreamContext {
|
|||
provider_id: ProviderId,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] STREAMING_PROCESS: client={:?} provider_id={:?} chunk_size={}",
|
||||
"[PLANO_REQ_ID:{}] STREAMING_PROCESS: client={:?} provider_id={:?} chunk_size={}",
|
||||
self.request_identifier(),
|
||||
self.client_api,
|
||||
provider_id,
|
||||
|
|
@ -466,30 +404,59 @@ impl StreamContext {
|
|||
return self.handle_bedrock_binary_stream(body, &client_api, &upstream_api);
|
||||
}
|
||||
|
||||
// Parse body into SSE iterator using TryFrom
|
||||
let sse_iter: SseStreamIter<std::vec::IntoIter<String>> =
|
||||
match SseStreamIter::try_from(body) {
|
||||
Ok(iter) => iter,
|
||||
// Initialize SSE chunk processor if not present
|
||||
if self.sse_chunk_processor.is_none() {
|
||||
self.sse_chunk_processor = Some(SseChunkProcessor::new());
|
||||
}
|
||||
|
||||
// Initialize SSE buffer if not present
|
||||
if self.sse_buffer.is_none() {
|
||||
self.sse_buffer = match SseStreamBuffer::try_from((&client_api, &upstream_api))
|
||||
{
|
||||
Ok(buffer) => Some(buffer),
|
||||
Err(e) => {
|
||||
warn!("Failed to parse body into SSE iterator: {}", e);
|
||||
warn!("Failed to create SSE buffer: {}", e);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let mut response_buffer = Vec::new();
|
||||
// Process chunk through SSE processor (handles incomplete events)
|
||||
let transformed_events = match self.sse_chunk_processor.as_mut() {
|
||||
Some(processor) => {
|
||||
let result = processor.process_chunk(body, &client_api, &upstream_api);
|
||||
let has_buffered = processor.has_buffered_data();
|
||||
let buffered_size = processor.buffered_size();
|
||||
|
||||
// Process each SSE event
|
||||
for sse_event in sse_iter {
|
||||
// Transform event if upstream API != client API
|
||||
let transformed_event: SseEvent =
|
||||
match SseEvent::try_from((sse_event, &client_api, &upstream_api)) {
|
||||
Ok(event) => event,
|
||||
match result {
|
||||
Ok(events) => {
|
||||
if has_buffered {
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] SSE_INCOMPLETE_BUFFERED: {} bytes buffered for next chunk",
|
||||
self.request_identifier(),
|
||||
buffered_size
|
||||
);
|
||||
}
|
||||
events
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to transform SSE event: {}", e);
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] SSE_CHUNK_PROCESS_ERROR: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("SSE chunk processor unexpectedly missing");
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
|
||||
// Process each successfully transformed SSE event
|
||||
for transformed_event in transformed_events {
|
||||
// Extract ProviderStreamResponse for processing (token counting, etc.)
|
||||
if !transformed_event.is_done() && !transformed_event.is_event_only() {
|
||||
match transformed_event.provider_response() {
|
||||
|
|
@ -498,7 +465,7 @@ impl StreamContext {
|
|||
|
||||
if provider_response.is_final() {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] STREAMING_FINAL_CHUNK: total_tokens={}",
|
||||
"[PLANO_REQ_ID:{}] STREAMING_FINAL_CHUNK: total_tokens={}",
|
||||
self.request_identifier(),
|
||||
self.response_tokens
|
||||
);
|
||||
|
|
@ -508,7 +475,7 @@ impl StreamContext {
|
|||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] STREAMING_TOKEN_UPDATE: delta_chars={} estimated_tokens={} total_tokens={}",
|
||||
"[PLANO_REQ_ID:{}] STREAMING_TOKEN_UPDATE: delta_chars={} estimated_tokens={} total_tokens={}",
|
||||
self.request_identifier(),
|
||||
content.len(),
|
||||
estimated_tokens.max(1),
|
||||
|
|
@ -518,7 +485,7 @@ impl StreamContext {
|
|||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] STREAMING_CHUNK_ERROR: {}",
|
||||
"[PLANO_REQ_ID:{}] STREAMING_CHUNK_ERROR: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
|
|
@ -527,12 +494,32 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
// Add transformed event to response buffer
|
||||
let bytes: Vec<u8> = transformed_event.into();
|
||||
response_buffer.extend_from_slice(&bytes);
|
||||
// Add transformed event to buffer (buffer may inject lifecycle events)
|
||||
if let Some(buffer) = self.sse_buffer.as_mut() {
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response_buffer)
|
||||
// Get accumulated bytes from buffer and return
|
||||
match self.sse_buffer.as_mut() {
|
||||
Some(buffer) => {
|
||||
let bytes = buffer.into_bytes();
|
||||
if !bytes.is_empty() {
|
||||
let content = String::from_utf8_lossy(&bytes);
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_TRANSFORMED_CLIENT_RESPONSE: size={} content={}",
|
||||
self.request_identifier(),
|
||||
bytes.len(),
|
||||
content
|
||||
);
|
||||
}
|
||||
Ok(bytes)
|
||||
}
|
||||
None => {
|
||||
warn!("SSE buffer unexpectedly missing after initialization");
|
||||
Err(Action::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("Missing client_api for non-streaming response");
|
||||
|
|
@ -544,7 +531,7 @@ impl StreamContext {
|
|||
fn handle_bedrock_binary_stream(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
client_api: &SupportedAPIs,
|
||||
client_api: &SupportedAPIsFromClient,
|
||||
upstream_api: &SupportedUpstreamAPIs,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
// Initialize decoder if not present
|
||||
|
|
@ -552,87 +539,61 @@ impl StreamContext {
|
|||
self.binary_frame_decoder = Some(BedrockBinaryFrameDecoder::from_bytes(&[]));
|
||||
}
|
||||
|
||||
// Add incoming bytes to buffer
|
||||
// Initialize SSE buffer if not present
|
||||
if self.sse_buffer.is_none() {
|
||||
self.sse_buffer = match SseStreamBuffer::try_from((client_api, upstream_api)) {
|
||||
Ok(buffer) => Some(buffer),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] BEDROCK_BUFFER_INIT_ERROR: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Add incoming bytes to decoder buffer
|
||||
let decoder = self.binary_frame_decoder.as_mut().unwrap();
|
||||
decoder.buffer_mut().extend_from_slice(body);
|
||||
|
||||
let mut response_buffer = Vec::new();
|
||||
// Process all complete frames
|
||||
loop {
|
||||
let decoded_frame = self.binary_frame_decoder.as_mut().unwrap().decode_frame();
|
||||
match decoded_frame {
|
||||
Some(DecodedFrame::Complete(ref frame_ref)) => {
|
||||
let frame = DecodedFrame::Complete(frame_ref.clone());
|
||||
|
||||
// Convert frame to provider response type
|
||||
match ProviderStreamResponseType::try_from((&frame, client_api, upstream_api)) {
|
||||
Ok(provider_response) => {
|
||||
self.record_ttft_if_needed();
|
||||
|
||||
// Handle ContentBlockStart and ContentBlockDelta events
|
||||
match &provider_response {
|
||||
ProviderStreamResponseType::MessagesStreamEvent(evt) => {
|
||||
match evt {
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index, ..
|
||||
} => {
|
||||
// Mark that we've seen ContentBlockStart for this index
|
||||
self.binary_frame_decoder
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.set_content_block_start_sent(*index as i32);
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_CONTENT_BLOCK_START_TRACKED: index={}",
|
||||
self.request_identifier(),
|
||||
*index
|
||||
);
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockDelta {
|
||||
index, ..
|
||||
} => {
|
||||
// Check if ContentBlockStart was sent for this index
|
||||
let needs_start = !self
|
||||
.binary_frame_decoder
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.has_content_block_start_been_sent(*index as i32);
|
||||
|
||||
if needs_start {
|
||||
// Emit empty ContentBlockStart before delta
|
||||
let content_block_start =
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
index: *index,
|
||||
content_block: MessagesContentBlock::Text {
|
||||
text: String::new(),
|
||||
cache_control: None,
|
||||
},
|
||||
};
|
||||
let start_sse: String = content_block_start.into();
|
||||
response_buffer
|
||||
.extend_from_slice(start_sse.as_bytes());
|
||||
|
||||
// Mark that we've now sent it
|
||||
self.binary_frame_decoder
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.set_content_block_start_sent(*index as i32);
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_INJECTED_CONTENT_BLOCK_START: index={}",
|
||||
self.request_identifier(),
|
||||
*index
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
// Track token usage
|
||||
if let Some(content) = provider_response.content_delta() {
|
||||
let estimated_tokens = content.len() / 4;
|
||||
self.response_tokens += estimated_tokens.max(1);
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] BEDROCK_TOKEN_UPDATE: delta_chars={} estimated_tokens={} total_tokens={}",
|
||||
self.request_identifier(),
|
||||
content.len(),
|
||||
estimated_tokens.max(1),
|
||||
self.response_tokens
|
||||
);
|
||||
}
|
||||
|
||||
let sse_string: String = provider_response.into();
|
||||
response_buffer.extend_from_slice(sse_string.as_bytes());
|
||||
// Create SseEvent from provider response
|
||||
let event = SseEvent::from_provider_response(provider_response);
|
||||
|
||||
// Add to buffer (buffer handles all shim logic including ContentBlockStart injection)
|
||||
if let Some(buffer) = self.sse_buffer.as_mut() {
|
||||
buffer.add_transformed_event(event);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}",
|
||||
"[PLANO_REQ_ID:{}] BEDROCK_FRAME_CONVERSION_ERROR: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
|
|
@ -642,7 +603,7 @@ impl StreamContext {
|
|||
Some(DecodedFrame::Incomplete) => {
|
||||
// Incomplete frame - buffer retains partial data, wait for more bytes
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data",
|
||||
"[PLANO_REQ_ID:{}] BEDROCK_INCOMPLETE_FRAME: waiting for more data",
|
||||
self.request_identifier()
|
||||
);
|
||||
break;
|
||||
|
|
@ -650,7 +611,7 @@ impl StreamContext {
|
|||
None => {
|
||||
// Decode error
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] BEDROCK_DECODE_ERROR",
|
||||
"[PLANO_REQ_ID:{}] BEDROCK_DECODE_ERROR",
|
||||
self.request_identifier()
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
|
|
@ -658,8 +619,29 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
// Return accumulated complete frames (may be empty if all frames incomplete)
|
||||
Ok(response_buffer)
|
||||
// Get accumulated bytes from buffer and return
|
||||
match self.sse_buffer.as_mut() {
|
||||
Some(buffer) => {
|
||||
let bytes = buffer.into_bytes();
|
||||
if !bytes.is_empty() {
|
||||
let content = String::from_utf8_lossy(&bytes);
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_TRANSFORMED_CLIENT_RESPONSE: size={} content={}",
|
||||
self.request_identifier(),
|
||||
bytes.len(),
|
||||
content
|
||||
);
|
||||
}
|
||||
Ok(bytes)
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] BEDROCK_BUFFER_MISSING",
|
||||
self.request_identifier()
|
||||
);
|
||||
Err(Action::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_non_streaming_response(
|
||||
|
|
@ -668,7 +650,7 @@ impl StreamContext {
|
|||
provider_id: ProviderId,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] NON_STREAMING_PROCESS: provider_id={:?} body_size={}",
|
||||
"[PLANO_REQ_ID:{}] NON_STREAMING_PROCESS: provider_id={:?} body_size={}",
|
||||
self.request_identifier(),
|
||||
provider_id,
|
||||
body.len()
|
||||
|
|
@ -680,7 +662,7 @@ impl StreamContext {
|
|||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_PARSE_ERROR: {} | body: {}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_PARSE_ERROR: {} | body: {}",
|
||||
self.request_identifier(),
|
||||
e,
|
||||
String::from_utf8_lossy(body)
|
||||
|
|
@ -695,7 +677,7 @@ impl StreamContext {
|
|||
}
|
||||
None => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_ERROR: missing client_api",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_ERROR: missing client_api",
|
||||
self.request_identifier()
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
|
|
@ -707,7 +689,7 @@ impl StreamContext {
|
|||
response.extract_usage_counts()
|
||||
{
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] RESPONSE_USAGE: prompt_tokens={} completion_tokens={} total_tokens={}",
|
||||
"[PLANO_REQ_ID:{}] RESPONSE_USAGE: prompt_tokens={} completion_tokens={} total_tokens={}",
|
||||
self.request_identifier(),
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
|
|
@ -716,7 +698,7 @@ impl StreamContext {
|
|||
self.response_tokens = completion_tokens;
|
||||
} else {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] RESPONSE_USAGE: no usage information found",
|
||||
"[PLANO_REQ_ID:{}] RESPONSE_USAGE: no usage information found",
|
||||
self.request_identifier()
|
||||
);
|
||||
}
|
||||
|
|
@ -724,7 +706,7 @@ impl StreamContext {
|
|||
match serde_json::to_vec(&response) {
|
||||
Ok(bytes) => {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] CLIENT_RESPONSE_PAYLOAD: {}",
|
||||
"[PLANO_REQ_ID:{}] CLIENT_RESPONSE_PAYLOAD: {}",
|
||||
self.request_identifier(),
|
||||
String::from_utf8_lossy(&bytes)
|
||||
);
|
||||
|
|
@ -782,13 +764,14 @@ impl HttpContext for StreamContext {
|
|||
self.select_llm_provider();
|
||||
|
||||
// Check if this is a supported API endpoint
|
||||
if SupportedAPIs::from_endpoint(&request_path).is_none() {
|
||||
if SupportedAPIsFromClient::from_endpoint(&request_path).is_none() {
|
||||
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Get the SupportedApi for routing decisions
|
||||
let supported_api: Option<SupportedAPIs> = SupportedAPIs::from_endpoint(&request_path);
|
||||
let supported_api: Option<SupportedAPIsFromClient> =
|
||||
SupportedAPIsFromClient::from_endpoint(&request_path);
|
||||
self.client_api = supported_api;
|
||||
|
||||
// Debug: log provider, client API, resolved API, and request path
|
||||
|
|
@ -800,7 +783,7 @@ impl HttpContext for StreamContext {
|
|||
Some(provider_id.compatible_api_for_client(api, self.streaming_response));
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] ROUTING_INFO: provider='{}' client_api={:?} resolved_api={:?} request_path='{}'",
|
||||
"[PLANO_REQ_ID:{}] ROUTING_INFO: provider='{}' client_api={:?} resolved_api={:?} request_path='{}'",
|
||||
self.request_identifier(),
|
||||
provider.to_provider_id(),
|
||||
api,
|
||||
|
|
@ -853,7 +836,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] REQUEST_BODY_CHUNK: bytes={} end_stream={}",
|
||||
"[PLANO_REQ_ID:{}] REQUEST_BODY_CHUNK: bytes={} end_stream={}",
|
||||
self.request_identifier(),
|
||||
body_size,
|
||||
end_of_stream
|
||||
|
|
@ -892,14 +875,14 @@ impl HttpContext for StreamContext {
|
|||
let mut deserialized_client_request: ProviderRequestType = match self.client_api.as_ref() {
|
||||
Some(the_client_api) => {
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_RECEIVED: api={:?} body_size={}",
|
||||
"[PLANO_REQ_ID:{}] CLIENT_REQUEST_RECEIVED: api={:?} body_size={}",
|
||||
self.request_identifier(),
|
||||
the_client_api,
|
||||
body_bytes.len()
|
||||
);
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_PAYLOAD: {}",
|
||||
"[PLANO_REQ_ID:{}] CLIENT_REQUEST_PAYLOAD: {}",
|
||||
self.request_identifier(),
|
||||
String::from_utf8_lossy(&body_bytes)
|
||||
);
|
||||
|
|
@ -908,7 +891,7 @@ impl HttpContext for StreamContext {
|
|||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] CLIENT_REQUEST_PARSE_ERROR: {} | body: {}",
|
||||
"[PLANO_REQ_ID:{}] CLIENT_REQUEST_PARSE_ERROR: {} | body: {}",
|
||||
self.request_identifier(),
|
||||
e,
|
||||
String::from_utf8_lossy(&body_bytes)
|
||||
|
|
@ -951,7 +934,7 @@ impl HttpContext for StreamContext {
|
|||
"agent_orchestrator".to_string()
|
||||
} else {
|
||||
warn!(
|
||||
"[ARCHGW_REQ_ID:{}] MODEL_RESOLUTION_ERROR: no model specified | req_model='{}' provider='{}' config_model={:?}",
|
||||
"[PLANO_REQ_ID:{}] MODEL_RESOLUTION_ERROR: no model specified | req_model='{}' provider='{}' config_model={:?}",
|
||||
self.request_identifier(),
|
||||
model_requested,
|
||||
self.llm_provider().name,
|
||||
|
|
@ -980,7 +963,7 @@ impl HttpContext for StreamContext {
|
|||
self.user_message = deserialized_client_request.get_recent_user_message();
|
||||
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] MODEL_RESOLUTION: req_model='{}' -> resolved_model='{}' provider='{}' streaming={}",
|
||||
"[PLANO_REQ_ID:{}] MODEL_RESOLUTION: req_model='{}' -> resolved_model='{}' provider='{}' streaming={}",
|
||||
self.request_identifier(),
|
||||
model_requested,
|
||||
resolved_model,
|
||||
|
|
@ -1011,14 +994,14 @@ impl HttpContext for StreamContext {
|
|||
match self.resolved_api.as_ref() {
|
||||
Some(upstream) => {
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_TRANSFORM: client_api={:?} -> upstream_api={:?}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_TRANSFORM: client_api={:?} -> upstream_api={:?}",
|
||||
self.request_identifier(), self.client_api, upstream
|
||||
);
|
||||
|
||||
match ProviderRequestType::try_from((deserialized_client_request, upstream)) {
|
||||
Ok(request) => {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_REQUEST_PAYLOAD: {}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_REQUEST_PAYLOAD: {}",
|
||||
self.request_identifier(),
|
||||
String::from_utf8_lossy(&request.to_bytes().unwrap_or_default())
|
||||
);
|
||||
|
|
@ -1069,7 +1052,7 @@ impl HttpContext for StreamContext {
|
|||
self.upstream_status_code = StatusCode::from_u16(status_code).ok();
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_RESPONSE_STATUS: {}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_RESPONSE_STATUS: {}",
|
||||
self.request_identifier(),
|
||||
status_code
|
||||
);
|
||||
|
|
@ -1096,7 +1079,7 @@ impl HttpContext for StreamContext {
|
|||
let current_time = get_current_time().unwrap();
|
||||
if end_of_stream && body_size == 0 {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] RESPONSE_BODY_COMPLETE: total_bytes={}",
|
||||
"[PLANO_REQ_ID:{}] RESPONSE_BODY_COMPLETE: total_bytes={}",
|
||||
self.request_identifier(),
|
||||
body_size
|
||||
);
|
||||
|
|
@ -1108,7 +1091,7 @@ impl HttpContext for StreamContext {
|
|||
if let Some(status_code) = &self.upstream_status_code {
|
||||
if status_code.is_client_error() || status_code.is_server_error() {
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_ERROR_RESPONSE: status={} body_size={}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_ERROR_RESPONSE: status={} body_size={}",
|
||||
self.request_identifier(),
|
||||
status_code.as_u16(),
|
||||
body_size
|
||||
|
|
@ -1118,7 +1101,7 @@ impl HttpContext for StreamContext {
|
|||
if body_size > 0 {
|
||||
if let Ok(body) = self.read_raw_response_body(body_size) {
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_ERROR_BODY: {}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_ERROR_BODY: {}",
|
||||
self.request_identifier(),
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
|
|
@ -1131,15 +1114,16 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
match self.client_api {
|
||||
Some(SupportedAPIs::OpenAIChatCompletions(_)) => {}
|
||||
Some(SupportedAPIs::AnthropicMessagesAPI(_)) => {}
|
||||
Some(SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {}
|
||||
Some(SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {}
|
||||
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {}
|
||||
_ => {
|
||||
let api_info = match &self.client_api {
|
||||
Some(api) => format!("{}", api),
|
||||
None => "None".to_string(),
|
||||
};
|
||||
info!(
|
||||
"[ARCHGW_REQ_ID:{}], UNSUPPORTED API: {}",
|
||||
"[PLANO_REQ_ID:{}], UNSUPPORTED API: {}",
|
||||
self.request_identifier(),
|
||||
api_info
|
||||
);
|
||||
|
|
@ -1153,7 +1137,7 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
debug!(
|
||||
"[ARCHGW_REQ_ID:{}] UPSTREAM_RAW_RESPONSE: body_size={} content={}",
|
||||
"[PLANO_REQ_ID:{}] UPSTREAM_RAW_RESPONSE: body_size={} content={}",
|
||||
self.request_identifier(),
|
||||
body.len(),
|
||||
String::from_utf8_lossy(&body)
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue