mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Compare commits
70 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
473ec70b5c | ||
|
|
dafd245332 | ||
|
|
897fda2deb | ||
|
|
5a652eb666 | ||
|
|
b81eb7266c | ||
|
|
78dc4edad9 | ||
|
|
aa726b1bba | ||
|
|
c8079ac971 | ||
|
|
6701195a5d | ||
|
|
22f332f62d | ||
|
|
9812540602 | ||
|
|
c3c213b2fd | ||
|
|
78d8c90184 | ||
|
|
ffea891dba | ||
|
|
e7464b817a | ||
|
|
254d2b03bc | ||
|
|
95a7beaab3 | ||
|
|
37600fd07a | ||
|
|
0f67b2c806 | ||
|
|
1f701258cb | ||
|
|
711e4dd07d | ||
|
|
743d074184 | ||
|
|
d39d7ddd1c | ||
|
|
90b926c2ce | ||
|
|
980faef6be | ||
|
|
128059e7c1 | ||
|
|
8dedf0bec1 | ||
|
|
978b1ea722 | ||
|
|
9406af3a09 | ||
|
|
aa16a6dc4b | ||
|
|
7606c55b4b | ||
|
|
1d3f4d6c05 | ||
|
|
5d79e7a7d4 | ||
|
|
76ff353c1e | ||
|
|
39b430d74b | ||
|
|
0857cfafbf | ||
|
|
f68c21f8df | ||
|
|
3dbda9741e | ||
|
|
d8f4fd76e3 | ||
|
|
36fa42b364 | ||
|
|
82f34f82f2 | ||
|
|
f019f05738 | ||
|
|
af98c11a6d | ||
|
|
e5751d6b13 | ||
|
|
3a531ce22a | ||
|
|
406fa92802 | ||
|
|
69df124c47 | ||
|
|
180a9cb748 | ||
|
|
cdad02c5ee | ||
|
|
1ad3e0f64e | ||
|
|
1f23c573bf | ||
|
|
de2d8847f3 | ||
|
|
5388c6777f | ||
|
|
f1b8c03e2f | ||
|
|
4bb5c6404f | ||
|
|
5d85829bc2 | ||
|
|
bc059aed4d | ||
|
|
785bf7e021 | ||
|
|
2f52774c0e | ||
|
|
5400b0a2fa | ||
|
|
b4313d93a4 | ||
|
|
6610097659 | ||
|
|
5189f7907a | ||
|
|
97b7a390ef | ||
|
|
028a2cd196 | ||
|
|
b9f01c8471 | ||
|
|
065328e11c | ||
|
|
780a0af132 | ||
|
|
a1508f4de1 | ||
|
|
8f7a8a8a17 |
331 changed files with 40807 additions and 17606 deletions
12
.claude/skills/build-brightstaff/SKILL.md
Normal file
12
.claude/skills/build-brightstaff/SKILL.md
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
---
|
||||
name: build-brightstaff
|
||||
description: Build the brightstaff native binary. Use when brightstaff code changes.
|
||||
---
|
||||
|
||||
Build brightstaff:
|
||||
|
||||
```
|
||||
cd crates && cargo build --release -p brightstaff
|
||||
```
|
||||
|
||||
If the build fails, diagnose and fix the errors.
|
||||
10
.claude/skills/build-cli/SKILL.md
Normal file
10
.claude/skills/build-cli/SKILL.md
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
---
|
||||
name: build-cli
|
||||
description: Build and install the Python CLI (planoai). Use after making changes to cli/ code to install locally.
|
||||
---
|
||||
|
||||
1. `cd cli && uv sync` — ensure dependencies are installed
|
||||
2. `cd cli && uv tool install --editable .` — install the CLI locally
|
||||
3. Verify the installation: `cd cli && uv run planoai --help`
|
||||
|
||||
If the build or install fails, diagnose and fix the issues.
|
||||
12
.claude/skills/build-wasm/SKILL.md
Normal file
12
.claude/skills/build-wasm/SKILL.md
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
---
|
||||
name: build-wasm
|
||||
description: Build the WASM plugins for Envoy. Use when WASM plugin code changes.
|
||||
---
|
||||
|
||||
Build the WASM plugins:
|
||||
|
||||
```
|
||||
cd crates && cargo build --release --target=wasm32-wasip1 -p llm_gateway -p prompt_gateway
|
||||
```
|
||||
|
||||
If the build fails, diagnose and fix the errors.
|
||||
12
.claude/skills/check/SKILL.md
Normal file
12
.claude/skills/check/SKILL.md
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
---
|
||||
name: check
|
||||
description: Run Rust fmt, clippy, and unit tests. Use after making Rust code changes.
|
||||
---
|
||||
|
||||
Run all local checks in order:
|
||||
|
||||
1. `cd crates && cargo fmt --all -- --check` — if formatting fails, run `cargo fmt --all` to fix it
|
||||
2. `cd crates && cargo clippy --locked --all-targets --all-features -- -D warnings` — fix any warnings
|
||||
3. `cd crates && cargo test --lib` — ensure all unit tests pass
|
||||
|
||||
Report a summary of what passed/failed.
|
||||
17
.claude/skills/new-provider/SKILL.md
Normal file
17
.claude/skills/new-provider/SKILL.md
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
---
|
||||
name: new-provider
|
||||
description: Add a new LLM provider to hermesllm. Use when integrating a new AI provider.
|
||||
disable-model-invocation: true
|
||||
user-invocable: true
|
||||
---
|
||||
|
||||
Add a new LLM provider to hermesllm. The user will provide the provider name as $ARGUMENTS.
|
||||
|
||||
1. Add a new variant to `ProviderId` enum in `crates/hermesllm/src/providers/id.rs`
|
||||
2. Implement string parsing in the `TryFrom<&str>` impl for the new provider
|
||||
3. If the provider uses a non-OpenAI API format, create request/response types in `crates/hermesllm/src/apis/`
|
||||
4. Add variant to `ProviderRequestType` and `ProviderResponseType` enums and update all match arms
|
||||
5. Add model list to `crates/hermesllm/src/providers/provider_models.yaml`
|
||||
6. Update `SupportedUpstreamAPIs` mapping if needed
|
||||
|
||||
After making changes, run `cd crates && cargo test --lib` to verify everything compiles and tests pass.
|
||||
16
.claude/skills/pr/SKILL.md
Normal file
16
.claude/skills/pr/SKILL.md
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
---
|
||||
name: pr
|
||||
description: Create a feature branch and open a pull request for the current changes.
|
||||
disable-model-invocation: true
|
||||
user-invocable: true
|
||||
---
|
||||
|
||||
Create a pull request for the current changes:
|
||||
|
||||
1. Determine the GitHub username via `gh api user --jq .login`. If the login is `adilhafeez`, use `adil` instead.
|
||||
2. Create a feature branch using format `<username>/<feature_name>` — infer the feature name from the changes
|
||||
3. Run `cd crates && cargo fmt --all -- --check` and `cd crates && cargo clippy --locked --all-targets --all-features -- -D warnings` to verify Rust code is clean
|
||||
4. Commit all changes with a short, concise commit message (one line, no Co-Authored-By)
|
||||
5. Push the branch and create a PR targeting `main`
|
||||
|
||||
Keep the PR title short (under 70 chars). Include a brief summary in the body. Never include a "Test plan" section or any "Generated with Claude Code" attribution.
|
||||
30
.claude/skills/release/SKILL.md
Normal file
30
.claude/skills/release/SKILL.md
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
---
|
||||
name: release
|
||||
description: Bump the Plano version across all required files. Use when preparing a release.
|
||||
disable-model-invocation: true
|
||||
user-invocable: true
|
||||
---
|
||||
|
||||
Prepare a release version bump. The user may provide the new version number as $ARGUMENTS (e.g., `/release 0.4.12`), or a bump type (`major`, `minor`, `patch`).
|
||||
|
||||
If no argument is provided, read the current version from `cli/planoai/__init__.py`, auto-increment the patch version (e.g., `0.4.11` → `0.4.12`), and confirm with the user before proceeding.
|
||||
|
||||
Update the version string in ALL of these files:
|
||||
|
||||
- `.github/workflows/ci.yml`
|
||||
- `cli/planoai/__init__.py`
|
||||
- `cli/planoai/consts.py`
|
||||
- `cli/pyproject.toml`
|
||||
- `build_filter_image.sh`
|
||||
- `config/validate_plano_config.sh`
|
||||
- `docs/source/conf.py`
|
||||
- `docs/source/get_started/quickstart.rst`
|
||||
- `docs/source/resources/deployment.rst`
|
||||
- `apps/www/src/components/Hero.tsx`
|
||||
- `demos/llm_routing/preference_based_routing/README.md`
|
||||
|
||||
Do NOT change version strings in `*.lock` files or `Cargo.lock`.
|
||||
|
||||
After updating all version strings, run `cd cli && uv lock` to update the lock file with the new version.
|
||||
|
||||
After making changes, show a summary of all files modified and the old → new version.
|
||||
9
.claude/skills/test-python/SKILL.md
Normal file
9
.claude/skills/test-python/SKILL.md
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
---
|
||||
name: test-python
|
||||
description: Run Python CLI tests. Use after making changes to cli/ code.
|
||||
---
|
||||
|
||||
1. `cd cli && uv sync` — ensure dependencies are installed
|
||||
2. `cd cli && uv run pytest -v` — run all tests
|
||||
|
||||
If tests fail, diagnose and fix the issues.
|
||||
10
.github/workflows/ci.yml
vendored
10
.github/workflows/ci.yml
vendored
|
|
@ -133,13 +133,13 @@ jobs:
|
|||
load: true
|
||||
tags: |
|
||||
${{ env.PLANO_DOCKER_IMAGE }}
|
||||
${{ env.DOCKER_IMAGE }}:0.4.10
|
||||
${{ env.DOCKER_IMAGE }}:0.4.21
|
||||
${{ env.DOCKER_IMAGE }}:latest
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Save image as artifact
|
||||
run: docker save ${{ env.PLANO_DOCKER_IMAGE }} ${{ env.DOCKER_IMAGE }}:0.4.10 ${{ env.DOCKER_IMAGE }}:latest -o /tmp/plano-image.tar
|
||||
run: docker save ${{ env.PLANO_DOCKER_IMAGE }} ${{ env.DOCKER_IMAGE }}:0.4.21 ${{ env.DOCKER_IMAGE }}:latest -o /tmp/plano-image.tar
|
||||
|
||||
- name: Upload image artifact
|
||||
uses: actions/upload-artifact@v6
|
||||
|
|
@ -163,7 +163,7 @@ jobs:
|
|||
python-version: "3.14"
|
||||
|
||||
- name: Install planoai
|
||||
run: pip install ./cli
|
||||
run: pip install -e ./cli
|
||||
|
||||
- name: Validate plano config
|
||||
run: bash config/validate_plano_config.sh
|
||||
|
|
@ -477,7 +477,7 @@ jobs:
|
|||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
run: |
|
||||
source venv/bin/activate
|
||||
cd demos/shared/test_runner && sh run_demo_tests.sh llm_routing/preference_based_routing
|
||||
cd demos/shared/test_runner && bash run_demo_tests.sh llm_routing/preference_based_routing
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# E2E: demo — currency conversion
|
||||
|
|
@ -527,4 +527,4 @@ jobs:
|
|||
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
|
||||
run: |
|
||||
source venv/bin/activate
|
||||
cd demos/shared/test_runner && sh run_demo_tests.sh advanced/currency_exchange
|
||||
cd demos/shared/test_runner && bash run_demo_tests.sh advanced/currency_exchange
|
||||
|
|
|
|||
80
.github/workflows/docker-push-release.yml
vendored
80
.github/workflows/docker-push-release.yml
vendored
|
|
@ -3,6 +3,8 @@ name: Publish docker image (release)
|
|||
env:
|
||||
DOCKER_IMAGE: katanemo/plano
|
||||
GHCR_IMAGE: ghcr.io/${{ github.repository_owner }}/plano
|
||||
DOCR_IMAGE: registry.digitalocean.com/genai-prod/plano
|
||||
DOCR_PREVIEW_IMAGE: registry.digitalocean.com/genai-preview/plano
|
||||
|
||||
on:
|
||||
release:
|
||||
|
|
@ -33,6 +35,14 @@ jobs:
|
|||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install doctl
|
||||
uses: digitalocean/action-doctl@v2
|
||||
with:
|
||||
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN }}
|
||||
|
||||
- name: Log in to DOCR (prod)
|
||||
run: doctl registry login
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
|
|
@ -51,6 +61,21 @@ jobs:
|
|||
tags: |
|
||||
${{ steps.meta.outputs.tags }}-arm64
|
||||
${{ env.GHCR_IMAGE }}:${{ github.event.release.tag_name }}-arm64
|
||||
${{ env.DOCR_IMAGE }}:${{ github.event.release.tag_name }}-arm64
|
||||
|
||||
- name: Switch to DOCR preview
|
||||
uses: digitalocean/action-doctl@v2
|
||||
with:
|
||||
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN_PREVIEW }}
|
||||
|
||||
- name: Log in to DOCR (preview)
|
||||
run: doctl registry login
|
||||
|
||||
- name: Push to DOCR Preview
|
||||
run: |
|
||||
docker buildx imagetools create \
|
||||
-t ${{ env.DOCR_PREVIEW_IMAGE }}:${{ github.event.release.tag_name }}-arm64 \
|
||||
${{ steps.meta.outputs.tags }}-arm64
|
||||
|
||||
# Build AMD64 image on GitHub's AMD64 runner — push to both registries
|
||||
build-amd64:
|
||||
|
|
@ -72,6 +97,14 @@ jobs:
|
|||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install doctl
|
||||
uses: digitalocean/action-doctl@v2
|
||||
with:
|
||||
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN }}
|
||||
|
||||
- name: Log in to DOCR (prod)
|
||||
run: doctl registry login
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
|
|
@ -90,6 +123,21 @@ jobs:
|
|||
tags: |
|
||||
${{ steps.meta.outputs.tags }}-amd64
|
||||
${{ env.GHCR_IMAGE }}:${{ github.event.release.tag_name }}-amd64
|
||||
${{ env.DOCR_IMAGE }}:${{ github.event.release.tag_name }}-amd64
|
||||
|
||||
- name: Switch to DOCR preview
|
||||
uses: digitalocean/action-doctl@v2
|
||||
with:
|
||||
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN_PREVIEW }}
|
||||
|
||||
- name: Log in to DOCR (preview)
|
||||
run: doctl registry login
|
||||
|
||||
- name: Push to DOCR Preview
|
||||
run: |
|
||||
docker buildx imagetools create \
|
||||
-t ${{ env.DOCR_PREVIEW_IMAGE }}:${{ github.event.release.tag_name }}-amd64 \
|
||||
${{ steps.meta.outputs.tags }}-amd64
|
||||
|
||||
# Combine ARM64 and AMD64 images into multi-arch manifests for both registries
|
||||
create-manifest:
|
||||
|
|
@ -109,6 +157,14 @@ jobs:
|
|||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install doctl
|
||||
uses: digitalocean/action-doctl@v2
|
||||
with:
|
||||
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN }}
|
||||
|
||||
- name: Log in to DOCR (prod)
|
||||
run: doctl registry login
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
|
|
@ -131,3 +187,27 @@ jobs:
|
|||
-t ${{ env.GHCR_IMAGE }}:${TAG} \
|
||||
${{ env.GHCR_IMAGE }}:${TAG}-arm64 \
|
||||
${{ env.GHCR_IMAGE }}:${TAG}-amd64
|
||||
|
||||
- name: Create DOCR Prod Multi-Arch Manifest
|
||||
run: |
|
||||
TAG=${{ github.event.release.tag_name }}
|
||||
docker buildx imagetools create \
|
||||
-t ${{ env.DOCR_IMAGE }}:${TAG} \
|
||||
${{ env.DOCR_IMAGE }}:${TAG}-arm64 \
|
||||
${{ env.DOCR_IMAGE }}:${TAG}-amd64
|
||||
|
||||
- name: Switch to DOCR preview
|
||||
uses: digitalocean/action-doctl@v2
|
||||
with:
|
||||
token: ${{ secrets.PLATFORM_DIGITALOCEAN_TOKEN_PREVIEW }}
|
||||
|
||||
- name: Log in to DOCR (preview)
|
||||
run: doctl registry login
|
||||
|
||||
- name: Create DOCR Preview Multi-Arch Manifest
|
||||
run: |
|
||||
TAG=${{ github.event.release.tag_name }}
|
||||
docker buildx imagetools create \
|
||||
-t ${{ env.DOCR_PREVIEW_IMAGE }}:${TAG} \
|
||||
${{ steps.meta.outputs.tags }}-arm64 \
|
||||
${{ steps.meta.outputs.tags }}-amd64
|
||||
|
|
|
|||
2
.github/workflows/publish-pypi.yml
vendored
2
.github/workflows/publish-pypi.yml
vendored
|
|
@ -30,7 +30,7 @@ jobs:
|
|||
enable-cache: true
|
||||
|
||||
- name: Build package
|
||||
run: uv build
|
||||
run: uv build --wheel
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -152,3 +152,4 @@ apps/*/dist/
|
|||
|
||||
.cursor/
|
||||
.agents
|
||||
docs/do/
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ repos:
|
|||
hooks:
|
||||
- id: check-yaml
|
||||
exclude: config/envoy.template*
|
||||
args: [--allow-multiple-documents]
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: local
|
||||
|
|
@ -36,7 +37,7 @@ repos:
|
|||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.1.0
|
||||
rev: 26.3.1
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3
|
||||
|
|
|
|||
152
CLAUDE.md
152
CLAUDE.md
|
|
@ -1,152 +1,106 @@
|
|||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
Plano is an AI-native proxy server and data plane for agentic applications, built on Envoy proxy. It centralizes agent orchestration, LLM routing, observability, and safety guardrails as an out-of-process dataplane.
|
||||
|
||||
## Build & Test Commands
|
||||
|
||||
### Rust (crates/)
|
||||
|
||||
```bash
|
||||
# Build WASM plugins (must target wasm32-wasip1)
|
||||
# Rust — WASM plugins (must target wasm32-wasip1)
|
||||
cd crates && cargo build --release --target=wasm32-wasip1 -p llm_gateway -p prompt_gateway
|
||||
|
||||
# Build brightstaff binary (native target)
|
||||
# Rust — brightstaff binary (native target)
|
||||
cd crates && cargo build --release -p brightstaff
|
||||
|
||||
# Run unit tests
|
||||
# Rust — tests, format, lint
|
||||
cd crates && cargo test --lib
|
||||
|
||||
# Format check
|
||||
cd crates && cargo fmt --all -- --check
|
||||
|
||||
# Lint
|
||||
cd crates && cargo clippy --locked --all-targets --all-features -- -D warnings
|
||||
```
|
||||
|
||||
### Python CLI (cli/)
|
||||
# Python CLI
|
||||
cd cli && uv sync && uv run pytest -v
|
||||
|
||||
```bash
|
||||
cd cli && uv sync # Install dependencies
|
||||
cd cli && uv run pytest -v # Run tests
|
||||
cd cli && uv run planoai --help # Run CLI
|
||||
```
|
||||
# JS/TS (Turbo monorepo)
|
||||
npm run build && npm run lint && npm run typecheck
|
||||
|
||||
### JavaScript/TypeScript (apps/, packages/)
|
||||
|
||||
```bash
|
||||
npm run build # Build all (via Turbo)
|
||||
npm run lint # Lint all
|
||||
npm run dev # Dev servers
|
||||
npm run typecheck # Type check
|
||||
```
|
||||
|
||||
### Pre-commit (runs fmt, clippy, cargo test, black, yaml checks)
|
||||
|
||||
```bash
|
||||
# Pre-commit (fmt, clippy, cargo test, black, yaml)
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Docker
|
||||
|
||||
```bash
|
||||
# Docker
|
||||
docker build -t katanemo/plano:latest .
|
||||
```
|
||||
|
||||
### E2E Tests (tests/e2e/)
|
||||
|
||||
E2E tests require a built Docker image and API keys. They run via `tests/e2e/run_e2e_tests.sh` which executes four test suites: `test_prompt_gateway.py`, `test_model_alias_routing.py`, `test_openai_responses_api_client.py`, and `test_openai_responses_api_client_with_state.py`.
|
||||
E2E tests require a Docker image and API keys: `tests/e2e/run_e2e_tests.sh`
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Data Flow
|
||||
|
||||
Requests flow through Envoy proxy with two WASM filter plugins, backed by a native Rust binary:
|
||||
|
||||
```
|
||||
Client → Envoy (prompt_gateway.wasm → llm_gateway.wasm) → Agents/LLM Providers
|
||||
↕
|
||||
brightstaff (native binary: state, routing, signals, tracing)
|
||||
```
|
||||
|
||||
### Rust Crates (crates/)
|
||||
### Crates (crates/)
|
||||
|
||||
All crates share a Cargo workspace. Two compile to `wasm32-wasip1` for Envoy, the rest are native:
|
||||
|
||||
- **prompt_gateway** (WASM) — Proxy-WASM filter for prompt/message processing, guardrails, and filter chains
|
||||
- **prompt_gateway** (WASM) — Proxy-WASM filter for prompt processing, guardrails, filter chains
|
||||
- **llm_gateway** (WASM) — Proxy-WASM filter for LLM request/response handling and routing
|
||||
- **brightstaff** (native binary) — Core application server: handlers, router, signals, state management, tracing
|
||||
- **common** (library) — Shared across all crates: configuration, LLM provider abstractions, HTTP utilities, routing logic, rate limiting, tokenizer, PII detection, tracing
|
||||
- **hermesllm** (library) — Translates LLM API formats between providers (OpenAI, Anthropic, Gemini, Mistral, Grok, AWS Bedrock, Azure, together.ai). Key types: `ProviderId`, `ProviderRequest`, `ProviderResponse`, `ProviderStreamResponse`
|
||||
- **brightstaff** (native) — Core server: handlers, router, signals, state, tracing
|
||||
- **common** (lib) — Shared: config, HTTP, routing, rate limiting, tokenizer, PII, tracing
|
||||
- **hermesllm** (lib) — LLM API translation between providers. Key types: `ProviderId`, `ProviderRequest`, `ProviderResponse`, `ProviderStreamResponse`
|
||||
|
||||
### Python CLI (cli/planoai/)
|
||||
|
||||
The `planoai` CLI manages the Plano lifecycle. Key commands:
|
||||
- `planoai up <config.yaml>` — Validate config, check API keys, start Docker container
|
||||
- `planoai down` — Stop container
|
||||
- `planoai build` — Build Docker image from repo root
|
||||
- `planoai logs` — Stream access/debug logs
|
||||
- `planoai trace` — OTEL trace collection and analysis
|
||||
- `planoai init` — Initialize new project
|
||||
- `planoai cli_agent` — Start a CLI agent connected to Plano
|
||||
- `planoai generate_prompt_targets` — Generate prompt_targets from python methods
|
||||
Entry point: `main.py`. Built with `rich-click`. Commands: `up`, `down`, `build`, `logs`, `trace`, `init`, `cli_agent`, `generate_prompt_targets`.
|
||||
|
||||
Entry point: `cli/planoai/main.py`. Container lifecycle in `core.py`. Docker operations in `docker_cli.py`.
|
||||
### Config (config/)
|
||||
|
||||
### Configuration System (config/)
|
||||
- `plano_config_schema.yaml` — JSON Schema for validating user configs
|
||||
- `envoy.template.yaml` — Jinja2 template → Envoy config
|
||||
- `supervisord.conf` — Process supervisor for Envoy + brightstaff
|
||||
|
||||
- `plano_config_schema.yaml` — JSON Schema (draft-07) for validating user config files
|
||||
- `envoy.template.yaml` — Jinja2 template rendered into Envoy proxy config
|
||||
- `supervisord.conf` — Process supervisor for Envoy + brightstaff in the container
|
||||
### JS Apps (apps/, packages/)
|
||||
|
||||
User configs define: `agents` (id + url), `model_providers` (model + access_key), `listeners` (type: agent/model/prompt, with router strategy), `filters` (filter chains), and `tracing` settings.
|
||||
Turbo monorepo with Next.js 16 / React 19. Not part of the core proxy.
|
||||
|
||||
### JavaScript Apps (apps/, packages/)
|
||||
## WASM Plugin Rules
|
||||
|
||||
Turbo monorepo with Next.js 16 / React 19 applications and shared packages (UI components, Tailwind config, TypeScript config). Not part of the core proxy — these are web applications.
|
||||
Code in `prompt_gateway` and `llm_gateway` runs in Envoy's WASM sandbox:
|
||||
|
||||
- **No std networking/filesystem** — use proxy-wasm host calls only
|
||||
- **No tokio/async** — synchronous, callback-driven. `Action::Pause` / `Action::Continue` for flow control
|
||||
- **Lifecycle**: `RootContext` → `on_configure`, `create_http_context`; `HttpContext` → `on_http_request/response_headers/body`
|
||||
- **HTTP callouts**: `dispatch_http_call()` → store context in `callouts: RefCell<HashMap<u32, CallContext>>` → match in `on_http_call_response()`
|
||||
- **Config**: `Rc`-wrapped, loaded once in `on_configure()` via `serde_yaml::from_slice()`
|
||||
- **Dependencies must be no_std compatible** (e.g., `governor` with `features = ["no_std"]`)
|
||||
- **Crate type**: `cdylib` → produces `.wasm`
|
||||
|
||||
## Adding a New LLM Provider
|
||||
|
||||
1. Add variant to `ProviderId` in `crates/hermesllm/src/providers/id.rs` + `TryFrom<&str>`
|
||||
2. Create request/response types in `crates/hermesllm/src/apis/` if non-OpenAI format
|
||||
3. Add variant to `ProviderRequestType`/`ProviderResponseType` enums, update all match arms
|
||||
4. Add models to `crates/hermesllm/src/providers/provider_models.yaml`
|
||||
5. Update `SupportedUpstreamAPIs` mapping if needed
|
||||
|
||||
## Release Process
|
||||
|
||||
To prepare a release (e.g., bumping from `0.4.6` to `0.4.7`), update the version string in all of the following files:
|
||||
Update version (e.g., `0.4.11` → `0.4.12`) in all of these files:
|
||||
|
||||
**CI Workflow:**
|
||||
- `.github/workflows/ci.yml` — docker build/save tags
|
||||
- `.github/workflows/ci.yml`, `build_filter_image.sh`, `config/validate_plano_config.sh`
|
||||
- `cli/planoai/__init__.py`, `cli/planoai/consts.py`, `cli/pyproject.toml`
|
||||
- `docs/source/conf.py`, `docs/source/get_started/quickstart.rst`, `docs/source/resources/deployment.rst`
|
||||
- `apps/www/src/components/Hero.tsx`, `demos/llm_routing/preference_based_routing/README.md`
|
||||
|
||||
**CLI:**
|
||||
- `cli/planoai/__init__.py` — `__version__`
|
||||
- `cli/planoai/consts.py` — `PLANO_DOCKER_IMAGE` default
|
||||
- `cli/pyproject.toml` — `version`
|
||||
|
||||
**Build & Config:**
|
||||
- `build_filter_image.sh` — docker build tag
|
||||
- `config/validate_plano_config.sh` — docker image tag
|
||||
|
||||
**Docs:**
|
||||
- `docs/source/conf.py` — `release`
|
||||
- `docs/source/get_started/quickstart.rst` — install commands and example output
|
||||
- `docs/source/resources/deployment.rst` — docker image tag
|
||||
|
||||
**Website & Demos:**
|
||||
- `apps/www/src/components/Hero.tsx` — version badge
|
||||
- `demos/llm_routing/preference_based_routing/README.md` — example output
|
||||
|
||||
**Important:** Do NOT change `0.4.6` references in `*.lock` files or `Cargo.lock` — those refer to the `colorama` and `http-body` dependency versions, not Plano.
|
||||
|
||||
Commit message format: `release X.Y.Z`
|
||||
Do NOT change version strings in `*.lock` files or `Cargo.lock`. Commit message: `release X.Y.Z`
|
||||
|
||||
## Workflow Preferences
|
||||
|
||||
- **Git commits:** Do NOT add `Co-Authored-By` lines. Keep commit messages short and concise (one line, no verbose descriptions). NEVER commit and push directly to `main`—always use a feature branch and PR.
|
||||
- **Git branches:** Use the format `<github_username>/<feature_name>` when creating branches for PRs. Determine the username from `gh api user --jq .login`.
|
||||
- **GitHub issues:** When a GitHub issue URL is pasted, fetch all requirements and context from the issue first. The end goal is always a PR with all tests passing.
|
||||
- **Commits:** No `Co-Authored-By`. Short one-line messages. Never push directly to `main` — always feature branch + PR.
|
||||
- **Branches:** Use `adil/<feature_name>` format.
|
||||
- **Issues:** When a GitHub issue URL is pasted, fetch all context first. Goal is always a PR with passing tests.
|
||||
|
||||
## Key Conventions
|
||||
|
||||
- Rust edition 2021, formatted with `cargo fmt`, linted with `cargo clippy -D warnings`
|
||||
- Python formatted with Black
|
||||
- WASM plugins must target `wasm32-wasip1` — they run inside Envoy, not as native binaries
|
||||
- The Docker image bundles Envoy + WASM plugins + brightstaff + Python CLI into a single container managed by supervisord
|
||||
- API keys come from environment variables or `.env` files, never hardcoded
|
||||
- Rust edition 2021, `cargo fmt`, `cargo clippy -D warnings`
|
||||
- Python: Black. Rust errors: `thiserror` with `#[from]`
|
||||
- API keys from env vars or `.env`, never hardcoded
|
||||
- Provider dispatch: `ProviderRequestType`/`ProviderResponseType` enums implementing `ProviderRequest`/`ProviderResponse` traits
|
||||
|
|
|
|||
|
|
@ -49,7 +49,8 @@ FROM python:3.14-slim AS arch
|
|||
|
||||
RUN set -eux; \
|
||||
apt-get update; \
|
||||
apt-get install -y --no-install-recommends gettext-base curl; \
|
||||
apt-get upgrade -y; \
|
||||
apt-get install -y --no-install-recommends gettext-base curl procps; \
|
||||
apt-get clean; rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install --no-cache-dir supervisor
|
||||
|
|
@ -72,7 +73,7 @@ COPY cli/README.md ./
|
|||
COPY config/plano_config_schema.yaml /config/plano_config_schema.yaml
|
||||
COPY config/envoy.template.yaml /config/envoy.template.yaml
|
||||
|
||||
RUN uv run pip install --no-cache-dir .
|
||||
RUN pip install --no-cache-dir -e .
|
||||
|
||||
COPY cli/planoai planoai/
|
||||
COPY config/envoy.template.yaml .
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ Plano pulls rote plumbing out of your framework so you can stay focused on what
|
|||
**Jump to our [docs](https://docs.planoai.dev)** to learn how you can use Plano to improve the speed, safety and obervability of your agentic applications.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Plano and the Arch family of LLMs (like Plano-Orchestrator-4B, Arch-Router, etc) are hosted free of charge in the US-central region to give you a great first-run developer experience of Plano. To scale and run in production, you can either run these LLMs locally or contact us on [Discord](https://discord.gg/pGZf2gcwEc) for API keys.
|
||||
> Plano and the Plano family of LLMs (like Plano-Orchestrator) are hosted free of charge in the US-central region to give you a great first-run developer experience of Plano. To scale and run in production, you can either run these LLMs locally or contact us on [Discord](https://discord.gg/pGZf2gcwEc) for API keys.
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
"clean": "rm -rf .next"
|
||||
},
|
||||
"dependencies": {
|
||||
"@heroicons/react": "^2.2.0",
|
||||
"@katanemo/shared-styles": "*",
|
||||
"@katanemo/ui": "*",
|
||||
"next": "^16.1.6",
|
||||
|
|
|
|||
|
|
@ -66,7 +66,9 @@ export default function RootLayout({
|
|||
}>) {
|
||||
return (
|
||||
<html lang="en">
|
||||
<body className={`${ibmPlexSans.variable} antialiased text-white`}>
|
||||
<body
|
||||
className={`${ibmPlexSans.variable} overflow-hidden antialiased text-white`}
|
||||
>
|
||||
{/* Google tag (gtag.js) */}
|
||||
<Script
|
||||
src="https://www.googletagmanager.com/gtag/js?id=G-RLD5BDNW5N"
|
||||
|
|
@ -80,7 +82,9 @@ export default function RootLayout({
|
|||
gtag('config', 'G-RLD5BDNW5N');
|
||||
`}
|
||||
</Script>
|
||||
<div className="min-h-screen">{children}</div>
|
||||
<div className="h-screen overflow-hidden">
|
||||
{children}
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,11 +1,23 @@
|
|||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import LogoSlider from "../components/LogoSlider";
|
||||
import { ArrowRightIcon } from "@heroicons/react/16/solid";
|
||||
|
||||
export default function HomePage() {
|
||||
return (
|
||||
<main className="relative flex min-h-screen items-center justify-center overflow-hidden px-6 pt-12 pb-16 font-sans sm:pt-20 lg:items-start lg:justify-start lg:pt-24">
|
||||
<main className="relative flex h-full items-center justify-center overflow-hidden px-6 pt-12 pb-16 font-sans sm:pt-20 lg:items-start lg:justify-start lg:pt-24">
|
||||
<div className="relative mx-auto w-full max-w-6xl flex flex-col items-center justify-center text-left lg:items-start lg:justify-start">
|
||||
<Link
|
||||
href="https://digitalocean.com/blog/digitalocean-acquires-katanemo-labs-inc"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="mb-7 inline-flex max-w-[20rem] items-center gap-1 self-start rounded-full border border-[#22A875]/30 bg-[#22A875]/30 px-2.5 py-1 text-left text-[12px] leading-tight font-medium text-white transition-opacity hover:opacity-90 lg:hidden"
|
||||
>
|
||||
<span>
|
||||
DigitalOcean acquires Katanemo Labs, Inc.
|
||||
</span>
|
||||
<ArrowRightIcon aria-hidden className="h-3 w-3 shrink-0 text-white/90" />
|
||||
</Link>
|
||||
<div className="pointer-events-none mb-6 w-full self-start lg:hidden">
|
||||
<Image
|
||||
src="/KatanemoLogo.svg"
|
||||
|
|
@ -17,6 +29,20 @@ export default function HomePage() {
|
|||
/>
|
||||
</div>
|
||||
<div className="relative z-10 max-w-xl sm:max-w-2xl lg:max-w-2xl xl:max-w-8xl lg:pr-[26vw] xl:pr-[2vw] sm:right-0 md:right-0 lg:right-0 xl:right-20 2xl:right-50 sm:mt-36 mt-0">
|
||||
<Link
|
||||
href="https://digitalocean.com/blog/digitalocean-acquires-katanemo-labs-inc"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="mb-4 hidden max-w-full items-center gap-2 rounded-full border border-[#22A875]/70 bg-[#22A875]/50 px-4 py-1 text-left text-sm font-medium text-white transition-opacity hover:opacity-90 lg:inline-flex"
|
||||
>
|
||||
<span>
|
||||
DigitalOcean acquires Katanemo Labs, Inc.
|
||||
</span>
|
||||
<ArrowRightIcon
|
||||
aria-hidden
|
||||
className="h-4 w-4 shrink-0 text-white/90"
|
||||
/>
|
||||
</Link>
|
||||
<h1 className="text-3xl sm:text-4xl md:text-5xl lg:text-6xl font-sans font-medium leading-tight tracking-tight text-white">
|
||||
Forward-deployed AI infrastructure engineers.
|
||||
</h1>
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
"migrate:blogs": "tsx scripts/migrate-blogs.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@heroicons/react": "^2.2.0",
|
||||
"@katanemo/shared-styles": "*",
|
||||
"@katanemo/ui": "*",
|
||||
"@portabletext/react": "^5.0.0",
|
||||
|
|
@ -32,8 +33,8 @@
|
|||
"next": "^16.1.6",
|
||||
"next-sanity": "^11.6.9",
|
||||
"papaparse": "^5.5.3",
|
||||
"react": "19.2.0",
|
||||
"react-dom": "19.2.0",
|
||||
"react": "19.2.3",
|
||||
"react-dom": "19.2.3",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-syntax-highlighter": "^16.1.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
|
|
|
|||
|
|
@ -39,6 +39,10 @@ function loadFont(fileName: string, baseUrl: string) {
|
|||
}
|
||||
|
||||
async function getBlogPost(slug: string) {
|
||||
if (!client) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const query = `*[_type == "blog" && slug.current == $slug && published == true][0] {
|
||||
_id,
|
||||
title,
|
||||
|
|
@ -53,8 +57,13 @@ async function getBlogPost(slug: string) {
|
|||
}
|
||||
}`;
|
||||
|
||||
const post = await client.fetch(query, { slug });
|
||||
return post;
|
||||
try {
|
||||
const post = await client.fetch(query, { slug });
|
||||
return post;
|
||||
} catch (error) {
|
||||
console.error("Error fetching blog post for OG image:", error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function formatDate(dateString: string): string {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,10 @@ interface BlogPost {
|
|||
}
|
||||
|
||||
async function getBlogPost(slug: string): Promise<BlogPost | null> {
|
||||
if (!client) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const query = `*[_type == "blog" && slug.current == $slug && published == true][0] {
|
||||
_id,
|
||||
title,
|
||||
|
|
@ -26,8 +30,13 @@ async function getBlogPost(slug: string): Promise<BlogPost | null> {
|
|||
author
|
||||
}`;
|
||||
|
||||
const post = await client.fetch(query, { slug });
|
||||
return post || null;
|
||||
try {
|
||||
const post = await client.fetch(query, { slug });
|
||||
return post || null;
|
||||
} catch (error) {
|
||||
console.error("Error fetching blog post metadata:", error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export async function generateMetadata({
|
||||
|
|
|
|||
|
|
@ -25,6 +25,10 @@ interface BlogPost {
|
|||
}
|
||||
|
||||
async function getBlogPost(slug: string): Promise<BlogPost | null> {
|
||||
if (!client) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const query = `*[_type == "blog" && slug.current == $slug && published == true][0] {
|
||||
_id,
|
||||
title,
|
||||
|
|
@ -51,17 +55,31 @@ async function getBlogPost(slug: string): Promise<BlogPost | null> {
|
|||
author
|
||||
}`;
|
||||
|
||||
const post = await client.fetch(query, { slug });
|
||||
return post || null;
|
||||
try {
|
||||
const post = await client.fetch(query, { slug });
|
||||
return post || null;
|
||||
} catch (error) {
|
||||
console.error("Error fetching blog post:", error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function getAllBlogSlugs(): Promise<string[]> {
|
||||
if (!client) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const query = `*[_type == "blog" && published == true] {
|
||||
"slug": slug.current
|
||||
}`;
|
||||
|
||||
const posts = await client.fetch(query);
|
||||
return posts.map((post: { slug: string }) => post.slug);
|
||||
try {
|
||||
const posts = await client.fetch(query);
|
||||
return posts.map((post: { slug: string }) => post.slug);
|
||||
} catch (error) {
|
||||
console.error("Error fetching blog slugs:", error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
export async function generateStaticParams() {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import { BlogSectionHeader } from "@/components/BlogSectionHeader";
|
|||
import { pageMetadata } from "@/lib/metadata";
|
||||
|
||||
export const metadata: Metadata = pageMetadata.blog;
|
||||
export const dynamic = "force-dynamic";
|
||||
|
||||
interface BlogPost {
|
||||
_id: string;
|
||||
|
|
@ -44,6 +45,10 @@ function formatDate(dateString: string): string {
|
|||
}
|
||||
|
||||
async function getBlogPosts(): Promise<BlogPost[]> {
|
||||
if (!client) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const query = `*[_type == "blog" && published == true] | order(publishedAt desc) {
|
||||
_id,
|
||||
title,
|
||||
|
|
@ -58,12 +63,48 @@ async function getBlogPosts(): Promise<BlogPost[]> {
|
|||
featured
|
||||
}`;
|
||||
|
||||
return await client.fetch(query);
|
||||
try {
|
||||
return await client.fetch(query);
|
||||
} catch (error) {
|
||||
console.error("Error fetching blog posts:", error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
async function getFeaturedBlogPost(): Promise<BlogPost | null> {
|
||||
if (!client) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const query = `*[_type == "blog" && published == true && featured == true] | order(_updatedAt desc, publishedAt desc)[0] {
|
||||
_id,
|
||||
title,
|
||||
slug,
|
||||
summary,
|
||||
publishedAt,
|
||||
mainImage,
|
||||
mainImageUrl,
|
||||
thumbnailImage,
|
||||
thumbnailImageUrl,
|
||||
author,
|
||||
featured
|
||||
}`;
|
||||
|
||||
try {
|
||||
const post = await client.fetch(query);
|
||||
return post || null;
|
||||
} catch (error) {
|
||||
console.error("Error fetching featured blog post:", error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export default async function BlogPage() {
|
||||
const posts = await getBlogPosts();
|
||||
const featuredPost = posts.find((post) => post.featured) || posts[0];
|
||||
const [posts, featuredCandidate] = await Promise.all([
|
||||
getBlogPosts(),
|
||||
getFeaturedBlogPost(),
|
||||
]);
|
||||
const featuredPost = featuredCandidate || posts[0];
|
||||
const recentPosts = posts
|
||||
.filter((post) => post._id !== featuredPost?._id)
|
||||
.slice(0, 3);
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import type { Metadata } from "next";
|
||||
import { ArrowRightIcon } from "@heroicons/react/16/solid";
|
||||
import Link from "next/link";
|
||||
import Script from "next/script";
|
||||
import "@katanemo/shared-styles/globals.css";
|
||||
import { Analytics } from "@vercel/analytics/next";
|
||||
|
|
@ -35,6 +37,27 @@ export default function RootLayout({
|
|||
gtag('config', 'G-ML7B1X9HY2');
|
||||
`}
|
||||
</Script>
|
||||
<Link
|
||||
href="https://digitalocean.com/blog/digitalocean-acquires-katanemo-labs-inc"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="block w-full bg-[#7780D9] py-3 text-white transition-opacity"
|
||||
>
|
||||
<div className="mx-auto flex max-w-[85rem] items-center justify-center gap-4 px-6 text-center md:justify-between md:text-left lg:px-8">
|
||||
<span className="w-full text-xs font-medium leading-snug md:w-auto md:text-base flex items-center">
|
||||
DigitalOcean acquires Katanemo Labs, Inc. to accelerate AI
|
||||
development
|
||||
<ArrowRightIcon
|
||||
aria-hidden
|
||||
className="ml-1 inline-block h-3 w-3 align-[-1px] text-white/90 md:hidden"
|
||||
/>
|
||||
</span>
|
||||
<span className="hidden shrink-0 items-center gap-1 text-base font-medium tracking-[-0.989px] font-mono leading-snug opacity-70 transition-opacity hover:opacity-100 md:inline-flex">
|
||||
Read the announcement
|
||||
<ArrowRightIcon aria-hidden className="h-3.5 w-3.5 text-white/70" />
|
||||
</span>
|
||||
</div>
|
||||
</Link>
|
||||
<ConditionalLayout>{children}</ConditionalLayout>
|
||||
<Analytics />
|
||||
</body>
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@ interface BlogPost {
|
|||
}
|
||||
|
||||
async function getBlogPosts(): Promise<BlogPost[]> {
|
||||
if (!client) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const query = `*[_type == "blog" && published == true] | order(publishedAt desc) {
|
||||
slug,
|
||||
publishedAt,
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ export function Hero() {
|
|||
>
|
||||
<div className="inline-flex flex-wrap items-center gap-1.5 sm:gap-2 px-3 sm:px-4 py-1 rounded-full bg-[rgba(185,191,255,0.4)] border border-[var(--secondary)] shadow backdrop-blur hover:bg-[rgba(185,191,255,0.6)] transition-colors cursor-pointer">
|
||||
<span className="text-xs sm:text-sm font-medium text-black/65">
|
||||
v0.4.10
|
||||
v0.4.21
|
||||
</span>
|
||||
<span className="text-xs sm:text-sm font-medium text-black ">
|
||||
—
|
||||
|
|
|
|||
|
|
@ -2,19 +2,33 @@ import { createClient } from "@sanity/client";
|
|||
import imageUrlBuilder from "@sanity/image-url";
|
||||
import type { SanityImageSource } from "@sanity/image-url/lib/types/types";
|
||||
|
||||
const projectId = process.env.NEXT_PUBLIC_SANITY_PROJECT_ID;
|
||||
const dataset = process.env.NEXT_PUBLIC_SANITY_DATASET;
|
||||
const apiVersion = process.env.NEXT_PUBLIC_SANITY_API_VERSION;
|
||||
const projectId =
|
||||
process.env.NEXT_PUBLIC_SANITY_PROJECT_ID ||
|
||||
"71ny25bn";
|
||||
const dataset =
|
||||
process.env.NEXT_PUBLIC_SANITY_DATASET ||
|
||||
"production";
|
||||
const apiVersion =
|
||||
process.env.NEXT_PUBLIC_SANITY_API_VERSION ||
|
||||
"2025-01-01";
|
||||
|
||||
export const client = createClient({
|
||||
projectId,
|
||||
dataset,
|
||||
apiVersion,
|
||||
useCdn: true, // Set to false if statically generating pages, using ISR or using the on-demand revalidation API
|
||||
});
|
||||
export const hasSanityConfig = Boolean(projectId && dataset && apiVersion);
|
||||
|
||||
const builder = imageUrlBuilder(client);
|
||||
export const client = hasSanityConfig
|
||||
? createClient({
|
||||
projectId,
|
||||
dataset,
|
||||
apiVersion,
|
||||
// Keep blog/admin updates visible immediately after publishing.
|
||||
useCdn: false,
|
||||
})
|
||||
: null;
|
||||
|
||||
const builder = client ? imageUrlBuilder(client) : null;
|
||||
|
||||
export function urlFor(source: SanityImageSource) {
|
||||
if (!builder) {
|
||||
throw new Error("Sanity client is not configured.");
|
||||
}
|
||||
return builder.image(source);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
docker build -f Dockerfile . -t katanemo/plano -t katanemo/plano:0.4.10
|
||||
docker build -f Dockerfile . -t katanemo/plano -t katanemo/plano:0.4.21
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
"""Plano CLI - Intelligent Prompt Gateway."""
|
||||
|
||||
__version__ = "0.4.10"
|
||||
__version__ = "0.4.21"
|
||||
|
|
|
|||
290
cli/planoai/chatgpt_auth.py
Normal file
290
cli/planoai/chatgpt_auth.py
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
"""
|
||||
ChatGPT subscription OAuth device-flow authentication.
|
||||
|
||||
Implements the device code flow used by OpenAI Codex CLI to authenticate
|
||||
with a ChatGPT Plus/Pro subscription. Tokens are stored locally in
|
||||
~/.plano/chatgpt/auth.json and auto-refreshed when expired.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from planoai.consts import PLANO_HOME
|
||||
|
||||
# OAuth + API constants (derived from openai/codex)
|
||||
CHATGPT_AUTH_BASE = "https://auth.openai.com"
|
||||
CHATGPT_DEVICE_CODE_URL = f"{CHATGPT_AUTH_BASE}/api/accounts/deviceauth/usercode"
|
||||
CHATGPT_DEVICE_TOKEN_URL = f"{CHATGPT_AUTH_BASE}/api/accounts/deviceauth/token"
|
||||
CHATGPT_OAUTH_TOKEN_URL = f"{CHATGPT_AUTH_BASE}/oauth/token"
|
||||
CHATGPT_DEVICE_VERIFY_URL = f"{CHATGPT_AUTH_BASE}/codex/device"
|
||||
CHATGPT_API_BASE = "https://chatgpt.com/backend-api/codex"
|
||||
CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
|
||||
# Local storage
|
||||
CHATGPT_AUTH_DIR = os.path.join(PLANO_HOME, "chatgpt")
|
||||
CHATGPT_AUTH_FILE = os.path.join(CHATGPT_AUTH_DIR, "auth.json")
|
||||
|
||||
# Timeouts
|
||||
TOKEN_EXPIRY_SKEW_SECONDS = 60
|
||||
DEVICE_CODE_TIMEOUT_SECONDS = 15 * 60
|
||||
DEVICE_CODE_POLL_SECONDS = 5
|
||||
|
||||
|
||||
def _ensure_auth_dir():
|
||||
os.makedirs(CHATGPT_AUTH_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def load_auth() -> Optional[Dict[str, Any]]:
|
||||
"""Load auth data from disk."""
|
||||
try:
|
||||
with open(CHATGPT_AUTH_FILE, "r") as f:
|
||||
return json.load(f)
|
||||
except (IOError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
|
||||
def save_auth(data: Dict[str, Any]):
|
||||
"""Save auth data to disk."""
|
||||
_ensure_auth_dir()
|
||||
fd = os.open(CHATGPT_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def delete_auth():
|
||||
"""Remove stored credentials."""
|
||||
try:
|
||||
os.remove(CHATGPT_AUTH_FILE)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def _decode_jwt_claims(token: str) -> Dict[str, Any]:
|
||||
"""Decode JWT payload without verification."""
|
||||
try:
|
||||
parts = token.split(".")
|
||||
if len(parts) < 2:
|
||||
return {}
|
||||
payload_b64 = parts[1]
|
||||
payload_b64 += "=" * (-len(payload_b64) % 4)
|
||||
return json.loads(base64.urlsafe_b64decode(payload_b64).decode("utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _get_expires_at(token: str) -> Optional[int]:
|
||||
"""Extract expiration time from JWT."""
|
||||
claims = _decode_jwt_claims(token)
|
||||
exp = claims.get("exp")
|
||||
return int(exp) if isinstance(exp, (int, float)) else None
|
||||
|
||||
|
||||
def _extract_account_id(token: Optional[str]) -> Optional[str]:
|
||||
"""Extract ChatGPT account ID from JWT claims."""
|
||||
if not token:
|
||||
return None
|
||||
claims = _decode_jwt_claims(token)
|
||||
auth_claims = claims.get("https://api.openai.com/auth")
|
||||
if isinstance(auth_claims, dict):
|
||||
account_id = auth_claims.get("chatgpt_account_id")
|
||||
if isinstance(account_id, str) and account_id:
|
||||
return account_id
|
||||
return None
|
||||
|
||||
|
||||
def _is_token_expired(auth_data: Dict[str, Any]) -> bool:
|
||||
"""Check if the access token is expired."""
|
||||
expires_at = auth_data.get("expires_at")
|
||||
if expires_at is None:
|
||||
access_token = auth_data.get("access_token")
|
||||
if access_token:
|
||||
expires_at = _get_expires_at(access_token)
|
||||
if expires_at:
|
||||
auth_data["expires_at"] = expires_at
|
||||
save_auth(auth_data)
|
||||
if expires_at is None:
|
||||
return True
|
||||
return time.time() >= float(expires_at) - TOKEN_EXPIRY_SKEW_SECONDS
|
||||
|
||||
|
||||
def _refresh_tokens(refresh_token: str) -> Dict[str, str]:
|
||||
"""Refresh the access token using the refresh token."""
|
||||
resp = requests.post(
|
||||
CHATGPT_OAUTH_TOKEN_URL,
|
||||
json={
|
||||
"client_id": CHATGPT_CLIENT_ID,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"scope": "openid profile email",
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
access_token = data.get("access_token")
|
||||
id_token = data.get("id_token")
|
||||
if not access_token or not id_token:
|
||||
raise RuntimeError(f"Refresh response missing fields: {data}")
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": data.get("refresh_token", refresh_token),
|
||||
"id_token": id_token,
|
||||
}
|
||||
|
||||
|
||||
def _build_auth_record(tokens: Dict[str, str]) -> Dict[str, Any]:
|
||||
"""Build the auth record to persist."""
|
||||
access_token = tokens.get("access_token")
|
||||
id_token = tokens.get("id_token")
|
||||
expires_at = _get_expires_at(access_token) if access_token else None
|
||||
account_id = _extract_account_id(id_token or access_token)
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": tokens.get("refresh_token"),
|
||||
"id_token": id_token,
|
||||
"expires_at": expires_at,
|
||||
"account_id": account_id,
|
||||
}
|
||||
|
||||
|
||||
def request_device_code() -> Dict[str, str]:
|
||||
"""Request a device code from OpenAI's device auth endpoint."""
|
||||
resp = requests.post(
|
||||
CHATGPT_DEVICE_CODE_URL,
|
||||
json={"client_id": CHATGPT_CLIENT_ID},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
device_auth_id = data.get("device_auth_id")
|
||||
user_code = data.get("user_code") or data.get("usercode")
|
||||
interval = data.get("interval")
|
||||
if not device_auth_id or not user_code:
|
||||
raise RuntimeError(f"Device code response missing fields: {data}")
|
||||
|
||||
return {
|
||||
"device_auth_id": device_auth_id,
|
||||
"user_code": user_code,
|
||||
"interval": str(interval or "5"),
|
||||
}
|
||||
|
||||
|
||||
def poll_for_authorization(device_code: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Poll until the user completes authorization. Returns code_data."""
|
||||
interval = int(device_code.get("interval", "5"))
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < DEVICE_CODE_TIMEOUT_SECONDS:
|
||||
try:
|
||||
resp = requests.post(
|
||||
CHATGPT_DEVICE_TOKEN_URL,
|
||||
json={
|
||||
"device_auth_id": device_code["device_auth_id"],
|
||||
"user_code": device_code["user_code"],
|
||||
},
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
if all(
|
||||
key in data
|
||||
for key in ("authorization_code", "code_challenge", "code_verifier")
|
||||
):
|
||||
return data
|
||||
if resp.status_code in (403, 404):
|
||||
time.sleep(max(interval, DEVICE_CODE_POLL_SECONDS))
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
except requests.HTTPError as exc:
|
||||
if exc.response is not None and exc.response.status_code in (403, 404):
|
||||
time.sleep(max(interval, DEVICE_CODE_POLL_SECONDS))
|
||||
continue
|
||||
raise RuntimeError(f"Polling failed: {exc}") from exc
|
||||
|
||||
time.sleep(max(interval, DEVICE_CODE_POLL_SECONDS))
|
||||
|
||||
raise RuntimeError("Timed out waiting for device authorization")
|
||||
|
||||
|
||||
def exchange_code_for_tokens(code_data: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Exchange the authorization code for access/refresh/id tokens."""
|
||||
redirect_uri = f"{CHATGPT_AUTH_BASE}/deviceauth/callback"
|
||||
body = (
|
||||
"grant_type=authorization_code"
|
||||
f"&code={code_data['authorization_code']}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&client_id={CHATGPT_CLIENT_ID}"
|
||||
f"&code_verifier={code_data['code_verifier']}"
|
||||
)
|
||||
resp = requests.post(
|
||||
CHATGPT_OAUTH_TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data=body,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if not all(key in data for key in ("access_token", "refresh_token", "id_token")):
|
||||
raise RuntimeError(f"Token exchange response missing fields: {data}")
|
||||
|
||||
return {
|
||||
"access_token": data["access_token"],
|
||||
"refresh_token": data["refresh_token"],
|
||||
"id_token": data["id_token"],
|
||||
}
|
||||
|
||||
|
||||
def login() -> Dict[str, Any]:
|
||||
"""Run the full device code login flow. Returns the auth record."""
|
||||
device_code = request_device_code()
|
||||
auth_record = _build_auth_record({})
|
||||
auth_record["device_code_requested_at"] = time.time()
|
||||
save_auth(auth_record)
|
||||
|
||||
print(
|
||||
"\nSign in with your ChatGPT account:\n"
|
||||
f" 1) Visit: {CHATGPT_DEVICE_VERIFY_URL}\n"
|
||||
f" 2) Enter code: {device_code['user_code']}\n\n"
|
||||
"Device codes are a common phishing target. Never share this code.\n",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
code_data = poll_for_authorization(device_code)
|
||||
tokens = exchange_code_for_tokens(code_data)
|
||||
auth_record = _build_auth_record(tokens)
|
||||
save_auth(auth_record)
|
||||
return auth_record
|
||||
|
||||
|
||||
def get_access_token() -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
Get a valid access token and account ID.
|
||||
Refreshes automatically if expired. Raises if no auth data exists.
|
||||
Returns (access_token, account_id).
|
||||
"""
|
||||
auth_data = load_auth()
|
||||
if not auth_data:
|
||||
raise RuntimeError(
|
||||
"No ChatGPT credentials found. Run 'planoai chatgpt login' first."
|
||||
)
|
||||
|
||||
access_token = auth_data.get("access_token")
|
||||
if access_token and not _is_token_expired(auth_data):
|
||||
return access_token, auth_data.get("account_id")
|
||||
|
||||
# Try refresh
|
||||
refresh_token = auth_data.get("refresh_token")
|
||||
if refresh_token:
|
||||
tokens = _refresh_tokens(refresh_token)
|
||||
auth_record = _build_auth_record(tokens)
|
||||
save_auth(auth_record)
|
||||
return auth_record["access_token"], auth_record.get("account_id")
|
||||
|
||||
raise RuntimeError(
|
||||
"ChatGPT token expired and refresh failed. Run 'planoai chatgpt login' again."
|
||||
)
|
||||
86
cli/planoai/chatgpt_cmd.py
Normal file
86
cli/planoai/chatgpt_cmd.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""
|
||||
CLI commands for ChatGPT subscription management.
|
||||
|
||||
Usage:
|
||||
planoai chatgpt login - Authenticate with ChatGPT via device code flow
|
||||
planoai chatgpt status - Check authentication status
|
||||
planoai chatgpt logout - Remove stored credentials
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
|
||||
from planoai import chatgpt_auth
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@click.group()
|
||||
def chatgpt():
|
||||
"""ChatGPT subscription management."""
|
||||
pass
|
||||
|
||||
|
||||
@chatgpt.command()
|
||||
def login():
|
||||
"""Authenticate with your ChatGPT subscription using device code flow."""
|
||||
try:
|
||||
auth_record = chatgpt_auth.login()
|
||||
account_id = auth_record.get("account_id", "unknown")
|
||||
console.print(
|
||||
f"\n[green]Successfully authenticated with ChatGPT![/green]"
|
||||
f"\nAccount ID: {account_id}"
|
||||
f"\nCredentials saved to: {chatgpt_auth.CHATGPT_AUTH_FILE}"
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(f"\n[red]Authentication failed:[/red] {e}")
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
@chatgpt.command()
|
||||
def status():
|
||||
"""Check ChatGPT authentication status."""
|
||||
auth_data = chatgpt_auth.load_auth()
|
||||
if not auth_data or not auth_data.get("access_token"):
|
||||
console.print(
|
||||
"[yellow]Not authenticated.[/yellow] Run 'planoai chatgpt login'."
|
||||
)
|
||||
return
|
||||
|
||||
account_id = auth_data.get("account_id", "unknown")
|
||||
expires_at = auth_data.get("expires_at")
|
||||
|
||||
if expires_at:
|
||||
expiry_time = datetime.datetime.fromtimestamp(
|
||||
expires_at, tz=datetime.timezone.utc
|
||||
)
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
if expiry_time > now:
|
||||
remaining = expiry_time - now
|
||||
console.print(
|
||||
f"[green]Authenticated[/green]"
|
||||
f"\n Account ID: {account_id}"
|
||||
f"\n Token expires: {expiry_time.strftime('%Y-%m-%d %H:%M:%S UTC')}"
|
||||
f" ({remaining.seconds // 60}m remaining)"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"[yellow]Token expired[/yellow]"
|
||||
f"\n Account ID: {account_id}"
|
||||
f"\n Expired at: {expiry_time.strftime('%Y-%m-%d %H:%M:%S UTC')}"
|
||||
f"\n Will auto-refresh on next use, or run 'planoai chatgpt login'."
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"[green]Authenticated[/green] (no expiry info)"
|
||||
f"\n Account ID: {account_id}"
|
||||
)
|
||||
|
||||
|
||||
@chatgpt.command()
|
||||
def logout():
|
||||
"""Remove stored ChatGPT credentials."""
|
||||
chatgpt_auth.delete_auth()
|
||||
console.print("[green]ChatGPT credentials removed.[/green]")
|
||||
|
|
@ -1,20 +1,20 @@
|
|||
import json
|
||||
import os
|
||||
import uuid
|
||||
from planoai.utils import convert_legacy_listeners
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
import yaml
|
||||
from jsonschema import validate
|
||||
from jsonschema import validate, ValidationError
|
||||
from urllib.parse import urlparse
|
||||
from copy import deepcopy
|
||||
from planoai.consts import DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
|
||||
|
||||
|
||||
SUPPORTED_PROVIDERS_WITH_BASE_URL = [
|
||||
"azure_openai",
|
||||
"ollama",
|
||||
"qwen",
|
||||
"amazon_bedrock",
|
||||
"arch",
|
||||
"plano",
|
||||
]
|
||||
|
||||
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
|
||||
|
|
@ -22,14 +22,23 @@ SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
|
|||
"groq",
|
||||
"mistral",
|
||||
"openai",
|
||||
"xiaomi",
|
||||
"gemini",
|
||||
"anthropic",
|
||||
"together_ai",
|
||||
"xai",
|
||||
"moonshotai",
|
||||
"zhipu",
|
||||
"chatgpt",
|
||||
"digitalocean",
|
||||
"vercel",
|
||||
"openrouter",
|
||||
]
|
||||
|
||||
CHATGPT_API_BASE = "https://chatgpt.com/backend-api/codex"
|
||||
CHATGPT_DEFAULT_ORIGINATOR = "codex_cli_rs"
|
||||
CHATGPT_DEFAULT_USER_AGENT = "codex_cli_rs/0.0.0 (Unknown 0; unknown) unknown"
|
||||
|
||||
SUPPORTED_PROVIDERS = (
|
||||
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL + SUPPORTED_PROVIDERS_WITH_BASE_URL
|
||||
)
|
||||
|
|
@ -49,6 +58,110 @@ def get_endpoint_and_port(endpoint, protocol):
|
|||
return endpoint, port
|
||||
|
||||
|
||||
def migrate_inline_routing_preferences(config_yaml):
|
||||
"""Lift v0.3.0-style inline ``routing_preferences`` under each
|
||||
``model_providers`` entry to the v0.4.0 top-level ``routing_preferences``
|
||||
list with ``models: [...]``.
|
||||
|
||||
This function is a no-op for configs whose ``version`` is already
|
||||
``v0.4.0`` or newer — those are assumed to be on the canonical
|
||||
top-level shape and are passed through untouched.
|
||||
|
||||
For older configs, the version is bumped to ``v0.4.0`` up front so
|
||||
brightstaff's v0.4.0 gate for top-level ``routing_preferences``
|
||||
accepts the rendered config, then inline preferences under each
|
||||
provider are lifted into the top-level list. Preferences with the
|
||||
same ``name`` across multiple providers are merged into a single
|
||||
top-level entry whose ``models`` list contains every provider's
|
||||
full ``<provider>/<model>`` string in declaration order. The first
|
||||
``description`` encountered wins; conflicts are warned, not errored,
|
||||
so existing v0.3.0 configs keep compiling. Any top-level preference
|
||||
already defined by the user is preserved as-is.
|
||||
"""
|
||||
current_version = str(config_yaml.get("version", ""))
|
||||
if _version_tuple(current_version) >= (0, 4, 0):
|
||||
return
|
||||
|
||||
config_yaml["version"] = "v0.4.0"
|
||||
|
||||
model_providers = config_yaml.get("model_providers") or []
|
||||
if not model_providers:
|
||||
return
|
||||
|
||||
migrated = {}
|
||||
for model_provider in model_providers:
|
||||
inline_prefs = model_provider.get("routing_preferences")
|
||||
if not inline_prefs:
|
||||
continue
|
||||
|
||||
full_model_name = model_provider.get("model")
|
||||
if not full_model_name:
|
||||
continue
|
||||
|
||||
if "/" in full_model_name and full_model_name.split("/")[-1].strip() == "*":
|
||||
raise Exception(
|
||||
f"Model {full_model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards."
|
||||
)
|
||||
|
||||
for pref in inline_prefs:
|
||||
name = pref.get("name")
|
||||
description = pref.get("description", "")
|
||||
if not name:
|
||||
continue
|
||||
if name in migrated:
|
||||
entry = migrated[name]
|
||||
if description and description != entry["description"]:
|
||||
print(
|
||||
f"WARNING: routing preference '{name}' has conflicting descriptions across providers; keeping the first one."
|
||||
)
|
||||
if full_model_name not in entry["models"]:
|
||||
entry["models"].append(full_model_name)
|
||||
else:
|
||||
migrated[name] = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"models": [full_model_name],
|
||||
}
|
||||
|
||||
if not migrated:
|
||||
return
|
||||
|
||||
for model_provider in model_providers:
|
||||
if "routing_preferences" in model_provider:
|
||||
del model_provider["routing_preferences"]
|
||||
|
||||
existing_top_level = config_yaml.get("routing_preferences") or []
|
||||
existing_names = {entry.get("name") for entry in existing_top_level}
|
||||
merged = list(existing_top_level)
|
||||
for name, entry in migrated.items():
|
||||
if name in existing_names:
|
||||
continue
|
||||
merged.append(entry)
|
||||
config_yaml["routing_preferences"] = merged
|
||||
|
||||
print(
|
||||
"WARNING: inline routing_preferences under model_providers is deprecated "
|
||||
"and has been auto-migrated to top-level routing_preferences. Update your "
|
||||
"config to v0.4.0 top-level form. See docs/routing-api.md"
|
||||
)
|
||||
|
||||
|
||||
def _version_tuple(version_string):
|
||||
stripped = version_string.strip().lstrip("vV")
|
||||
if not stripped:
|
||||
return (0, 0, 0)
|
||||
parts = stripped.split("-", 1)[0].split(".")
|
||||
out = []
|
||||
for part in parts[:3]:
|
||||
try:
|
||||
out.append(int(part))
|
||||
except ValueError:
|
||||
out.append(0)
|
||||
while len(out) < 3:
|
||||
out.append(0)
|
||||
return tuple(out)
|
||||
|
||||
|
||||
def validate_and_render_schema():
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
|
||||
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
|
||||
|
|
@ -92,6 +205,8 @@ def validate_and_render_schema():
|
|||
config_yaml["model_providers"] = config_yaml["llm_providers"]
|
||||
del config_yaml["llm_providers"]
|
||||
|
||||
migrate_inline_routing_preferences(config_yaml)
|
||||
|
||||
listeners, llm_gateway, prompt_gateway = convert_legacy_listeners(
|
||||
config_yaml.get("listeners"), config_yaml.get("model_providers")
|
||||
)
|
||||
|
|
@ -191,7 +306,16 @@ def validate_and_render_schema():
|
|||
model_provider_name_set = set()
|
||||
llms_with_usage = []
|
||||
model_name_keys = set()
|
||||
model_usage_name_keys = set()
|
||||
|
||||
top_level_preferences = config_yaml.get("routing_preferences") or []
|
||||
seen_pref_names = set()
|
||||
for pref in top_level_preferences:
|
||||
pref_name = pref.get("name")
|
||||
if pref_name in seen_pref_names:
|
||||
raise Exception(
|
||||
f'Duplicate routing preference name "{pref_name}", please provide unique name for each routing preference'
|
||||
)
|
||||
seen_pref_names.add(pref_name)
|
||||
|
||||
print("listeners: ", listeners)
|
||||
|
||||
|
|
@ -250,10 +374,6 @@ def validate_and_render_schema():
|
|||
raise Exception(
|
||||
f"Model {model_name} is configured as default but uses wildcard (*). Default models cannot be wildcards."
|
||||
)
|
||||
if model_provider.get("routing_preferences"):
|
||||
raise Exception(
|
||||
f"Model {model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards."
|
||||
)
|
||||
|
||||
# Validate azure_openai and ollama provider requires base_url
|
||||
if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get(
|
||||
|
|
@ -302,13 +422,6 @@ def validate_and_render_schema():
|
|||
)
|
||||
model_name_keys.add(model_id)
|
||||
|
||||
for routing_preference in model_provider.get("routing_preferences", []):
|
||||
if routing_preference.get("name") in model_usage_name_keys:
|
||||
raise Exception(
|
||||
f'Duplicate routing preference name "{routing_preference.get("name")}", please provide unique name for each routing preference'
|
||||
)
|
||||
model_usage_name_keys.add(routing_preference.get("name"))
|
||||
|
||||
# Warn if both passthrough_auth and access_key are configured
|
||||
if model_provider.get("passthrough_auth") and model_provider.get(
|
||||
"access_key"
|
||||
|
|
@ -331,6 +444,25 @@ def validate_and_render_schema():
|
|||
provider = model_provider["provider"]
|
||||
model_provider["provider_interface"] = provider
|
||||
del model_provider["provider"]
|
||||
|
||||
# Auto-wire ChatGPT provider: inject base_url, passthrough_auth, and extra headers
|
||||
if provider == "chatgpt":
|
||||
if not model_provider.get("base_url"):
|
||||
model_provider["base_url"] = CHATGPT_API_BASE
|
||||
if not model_provider.get("access_key") and not model_provider.get(
|
||||
"passthrough_auth"
|
||||
):
|
||||
model_provider["passthrough_auth"] = True
|
||||
headers = model_provider.get("headers", {})
|
||||
headers.setdefault(
|
||||
"ChatGPT-Account-Id",
|
||||
os.environ.get("CHATGPT_ACCOUNT_ID", ""),
|
||||
)
|
||||
headers.setdefault("originator", CHATGPT_DEFAULT_ORIGINATOR)
|
||||
headers.setdefault("user-agent", CHATGPT_DEFAULT_USER_AGENT)
|
||||
headers.setdefault("session_id", str(uuid.uuid4()))
|
||||
model_provider["headers"] = headers
|
||||
|
||||
updated_model_providers.append(model_provider)
|
||||
|
||||
if model_provider.get("base_url", None):
|
||||
|
|
@ -368,47 +500,51 @@ def validate_and_render_schema():
|
|||
llms_with_endpoint.append(model_provider)
|
||||
llms_with_endpoint_cluster_names.add(cluster_name)
|
||||
|
||||
if len(model_usage_name_keys) > 0:
|
||||
routing_model_provider = config_yaml.get("routing", {}).get(
|
||||
"model_provider", None
|
||||
overrides_config = config_yaml.get("overrides", {})
|
||||
# Build lookup of model names (already prefix-stripped by config processing)
|
||||
model_name_set = {mp.get("model") for mp in updated_model_providers}
|
||||
|
||||
# Auto-add plano-orchestrator provider if routing preferences exist and no provider matches the routing model
|
||||
router_model = overrides_config.get("llm_routing_model", "Plano-Orchestrator")
|
||||
router_model_id = (
|
||||
router_model.split("/", 1)[1] if "/" in router_model else router_model
|
||||
)
|
||||
if len(seen_pref_names) > 0 and router_model_id not in model_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "plano-orchestrator",
|
||||
"provider_interface": "plano",
|
||||
"model": router_model_id,
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
if (
|
||||
routing_model_provider
|
||||
and routing_model_provider not in model_provider_name_set
|
||||
):
|
||||
raise Exception(
|
||||
f"Routing model_provider {routing_model_provider} is not defined in model_providers"
|
||||
)
|
||||
if (
|
||||
routing_model_provider is None
|
||||
and "arch-router" not in model_provider_name_set
|
||||
):
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "arch-router",
|
||||
"provider_interface": "arch",
|
||||
"model": config_yaml.get("routing", {}).get("model", "Arch-Router"),
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Always add arch-function model provider if not already defined
|
||||
if "arch-function" not in model_provider_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "arch-function",
|
||||
"provider_interface": "arch",
|
||||
"provider_interface": "plano",
|
||||
"model": "Arch-Function",
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
|
||||
if "plano-orchestrator" not in model_provider_name_set:
|
||||
# Auto-add plano-orchestrator provider if no provider matches the orchestrator model
|
||||
orchestrator_model = overrides_config.get(
|
||||
"agent_orchestration_model", "Plano-Orchestrator"
|
||||
)
|
||||
orchestrator_model_id = (
|
||||
orchestrator_model.split("/", 1)[1]
|
||||
if "/" in orchestrator_model
|
||||
else orchestrator_model
|
||||
)
|
||||
if orchestrator_model_id not in model_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "plano-orchestrator",
|
||||
"provider_interface": "arch",
|
||||
"model": "Plano-Orchestrator",
|
||||
"name": "plano/orchestrator",
|
||||
"provider_interface": "plano",
|
||||
"model": orchestrator_model_id,
|
||||
"internal": True,
|
||||
}
|
||||
)
|
||||
|
|
@ -426,6 +562,16 @@ def validate_and_render_schema():
|
|||
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
||||
)
|
||||
|
||||
# Validate input_filters IDs on listeners reference valid agent/filter IDs
|
||||
for listener in listeners:
|
||||
listener_input_filters = listener.get("input_filters", [])
|
||||
for fc_id in listener_input_filters:
|
||||
if fc_id not in agent_id_keys:
|
||||
raise Exception(
|
||||
f"Listener '{listener.get('name', 'unknown')}' references input_filters id '{fc_id}' "
|
||||
f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}"
|
||||
)
|
||||
|
||||
# Validate model aliases if present
|
||||
if "model_aliases" in config_yaml:
|
||||
model_aliases = config_yaml["model_aliases"]
|
||||
|
|
@ -503,11 +649,15 @@ def validate_prompt_config(plano_config_file, plano_config_schema_file):
|
|||
|
||||
try:
|
||||
validate(config_yaml, config_schema_yaml)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error validating plano_config file: {plano_config_file}, schema file: {plano_config_schema_file}, error: {e}"
|
||||
except ValidationError as e:
|
||||
path = (
|
||||
" → ".join(str(p) for p in e.absolute_path) if e.absolute_path else "root"
|
||||
)
|
||||
raise e
|
||||
raise ValidationError(
|
||||
f"{e.message}\n Location: {path}\n Value: {e.instance}"
|
||||
) from None
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ PLANO_COLOR = "#969FF4"
|
|||
|
||||
SERVICE_NAME_ARCHGW = "plano"
|
||||
PLANO_DOCKER_NAME = "plano"
|
||||
PLANO_DOCKER_IMAGE = os.getenv("PLANO_DOCKER_IMAGE", "katanemo/plano:0.4.10")
|
||||
PLANO_DOCKER_IMAGE = os.getenv("PLANO_DOCKER_IMAGE", "katanemo/plano:0.4.21")
|
||||
DEFAULT_OTEL_TRACING_GRPC_ENDPOINT = "http://localhost:4317"
|
||||
|
||||
# Native mode constants
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from planoai.consts import (
|
|||
PLANO_DOCKER_IMAGE,
|
||||
PLANO_DOCKER_NAME,
|
||||
)
|
||||
import subprocess
|
||||
from planoai.docker_cli import (
|
||||
docker_container_status,
|
||||
docker_remove_container,
|
||||
|
|
@ -20,7 +19,6 @@ from planoai.docker_cli import (
|
|||
stream_gateway_logs,
|
||||
)
|
||||
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -147,26 +145,48 @@ def stop_docker_container(service=PLANO_DOCKER_NAME):
|
|||
log.info(f"Failed to shut down services: {str(e)}")
|
||||
|
||||
|
||||
def start_cli_agent(plano_config_file=None, settings_json="{}"):
|
||||
"""Start a CLI client connected to Plano."""
|
||||
|
||||
with open(plano_config_file, "r") as file:
|
||||
plano_config = file.read()
|
||||
plano_config_yaml = yaml.safe_load(plano_config)
|
||||
|
||||
# Get egress listener configuration
|
||||
egress_config = plano_config_yaml.get("listeners", {}).get("egress_traffic", {})
|
||||
host = egress_config.get("host", "127.0.0.1")
|
||||
port = egress_config.get("port", 12000)
|
||||
|
||||
# Parse additional settings from command line
|
||||
def _parse_cli_agent_settings(settings_json: str) -> dict:
|
||||
try:
|
||||
additional_settings = json.loads(settings_json) if settings_json else {}
|
||||
return json.loads(settings_json) if settings_json else {}
|
||||
except json.JSONDecodeError:
|
||||
log.error("Settings must be valid JSON")
|
||||
sys.exit(1)
|
||||
|
||||
# Set up environment variables
|
||||
|
||||
def _resolve_cli_agent_endpoint(plano_config_yaml: dict) -> tuple[str, int]:
|
||||
listeners = plano_config_yaml.get("listeners")
|
||||
|
||||
if isinstance(listeners, dict):
|
||||
egress_config = listeners.get("egress_traffic", {})
|
||||
host = egress_config.get("host") or egress_config.get("address") or "0.0.0.0"
|
||||
port = egress_config.get("port", 12000)
|
||||
return host, port
|
||||
|
||||
if isinstance(listeners, list):
|
||||
for listener in listeners:
|
||||
if listener.get("type") == "model":
|
||||
host = listener.get("host") or listener.get("address") or "0.0.0.0"
|
||||
port = listener.get("port", 12000)
|
||||
return host, port
|
||||
|
||||
return "0.0.0.0", 12000
|
||||
|
||||
|
||||
def _apply_non_interactive_env(env: dict, additional_settings: dict) -> None:
|
||||
if additional_settings.get("NON_INTERACTIVE_MODE", False):
|
||||
env.update(
|
||||
{
|
||||
"CI": "true",
|
||||
"FORCE_COLOR": "0",
|
||||
"NODE_NO_READLINE": "1",
|
||||
"TERM": "dumb",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _start_claude_cli_agent(
|
||||
host: str, port: int, plano_config_yaml: dict, additional_settings: dict
|
||||
) -> None:
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
|
|
@ -186,7 +206,6 @@ def start_cli_agent(plano_config_file=None, settings_json="{}"):
|
|||
"ANTHROPIC_SMALL_FAST_MODEL"
|
||||
]
|
||||
else:
|
||||
# Check if arch.claude.code.small.fast alias exists in model_aliases
|
||||
model_aliases = plano_config_yaml.get("model_aliases", {})
|
||||
if "arch.claude.code.small.fast" in model_aliases:
|
||||
env["ANTHROPIC_SMALL_FAST_MODEL"] = "arch.claude.code.small.fast"
|
||||
|
|
@ -196,23 +215,10 @@ def start_cli_agent(plano_config_file=None, settings_json="{}"):
|
|||
)
|
||||
log.info("Or provide ANTHROPIC_SMALL_FAST_MODEL in --settings JSON")
|
||||
|
||||
# Non-interactive mode configuration from additional_settings only
|
||||
if additional_settings.get("NON_INTERACTIVE_MODE", False):
|
||||
env.update(
|
||||
{
|
||||
"CI": "true",
|
||||
"FORCE_COLOR": "0",
|
||||
"NODE_NO_READLINE": "1",
|
||||
"TERM": "dumb",
|
||||
}
|
||||
)
|
||||
_apply_non_interactive_env(env, additional_settings)
|
||||
|
||||
# Build claude command arguments
|
||||
claude_args = []
|
||||
|
||||
# Add settings if provided, excluding those already handled as environment variables
|
||||
if additional_settings:
|
||||
# Filter out settings that are already processed as environment variables
|
||||
claude_settings = {
|
||||
k: v
|
||||
for k, v in additional_settings.items()
|
||||
|
|
@ -221,10 +227,8 @@ def start_cli_agent(plano_config_file=None, settings_json="{}"):
|
|||
if claude_settings:
|
||||
claude_args.append(f"--settings={json.dumps(claude_settings)}")
|
||||
|
||||
# Use claude from PATH
|
||||
claude_path = "claude"
|
||||
log.info(f"Connecting Claude Code Agent to Plano at {host}:{port}")
|
||||
|
||||
try:
|
||||
subprocess.run([claude_path] + claude_args, env=env, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
|
|
@ -235,3 +239,61 @@ def start_cli_agent(plano_config_file=None, settings_json="{}"):
|
|||
f"{claude_path} not found. Make sure Claude Code is installed: npm install -g @anthropic-ai/claude-code"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _start_codex_cli_agent(host: str, port: int, additional_settings: dict) -> None:
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"OPENAI_API_KEY": "test", # Use test token for plano
|
||||
"OPENAI_BASE_URL": f"http://{host}:{port}/v1",
|
||||
"NO_PROXY": host,
|
||||
"DISABLE_TELEMETRY": "true",
|
||||
}
|
||||
)
|
||||
_apply_non_interactive_env(env, additional_settings)
|
||||
|
||||
codex_model = additional_settings.get("CODEX_MODEL", "gpt-5.3-codex")
|
||||
codex_path = "codex"
|
||||
codex_args = ["--model", codex_model]
|
||||
|
||||
log.info(
|
||||
f"Connecting Codex CLI Agent to Plano at {host}:{port} (default model: {codex_model})"
|
||||
)
|
||||
try:
|
||||
subprocess.run([codex_path] + codex_args, env=env, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.error(f"Error starting codex: {e}")
|
||||
sys.exit(1)
|
||||
except FileNotFoundError:
|
||||
log.error(
|
||||
f"{codex_path} not found. Make sure Codex CLI is installed: npm install -g @openai/codex"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def start_cli_agent(
|
||||
plano_config_file=None, cli_agent_type="claude", settings_json="{}"
|
||||
):
|
||||
"""Start a CLI client connected to Plano."""
|
||||
|
||||
with open(plano_config_file, "r") as file:
|
||||
plano_config = file.read()
|
||||
plano_config_yaml = yaml.safe_load(plano_config)
|
||||
|
||||
host, port = _resolve_cli_agent_endpoint(plano_config_yaml)
|
||||
|
||||
additional_settings = _parse_cli_agent_settings(settings_json)
|
||||
|
||||
if cli_agent_type == "claude":
|
||||
_start_claude_cli_agent(host, port, plano_config_yaml, additional_settings)
|
||||
return
|
||||
|
||||
if cli_agent_type == "codex":
|
||||
_start_codex_cli_agent(host, port, additional_settings)
|
||||
return
|
||||
|
||||
log.error(
|
||||
f"Unsupported cli agent type '{cli_agent_type}'. Supported values: claude, codex"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
|
|
|||
178
cli/planoai/defaults.py
Normal file
178
cli/planoai/defaults.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
"""Default config synthesizer for zero-config ``planoai up``.
|
||||
|
||||
When the user runs ``planoai up`` in a directory with no ``config.yaml`` /
|
||||
``plano_config.yaml``, we synthesize a pass-through config that covers the
|
||||
common LLM providers and auto-wires OTel export to ``localhost:4317`` so
|
||||
``planoai obs`` works out of the box.
|
||||
|
||||
Auth handling:
|
||||
- If the provider's env var is set, bind ``access_key: $ENV_VAR``.
|
||||
- Otherwise set ``passthrough_auth: true`` so the client's own Authorization
|
||||
header is forwarded. No env var is required to start the proxy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
DEFAULT_LLM_LISTENER_PORT = 12000
|
||||
# plano_config validation requires an http:// scheme on the OTLP endpoint.
|
||||
DEFAULT_OTLP_ENDPOINT = "http://localhost:4317"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderDefault:
|
||||
name: str
|
||||
env_var: str
|
||||
base_url: str
|
||||
model_pattern: str
|
||||
# Only set for providers whose prefix in the model pattern is NOT one of the
|
||||
# built-in SUPPORTED_PROVIDERS in cli/planoai/config_generator.py. For
|
||||
# built-ins, the validator infers the interface from the model prefix and
|
||||
# rejects configs that set this field explicitly.
|
||||
provider_interface: str | None = None
|
||||
|
||||
|
||||
# Keep ordering stable so synthesized configs diff cleanly across runs.
|
||||
PROVIDER_DEFAULTS: list[ProviderDefault] = [
|
||||
ProviderDefault(
|
||||
name="openai",
|
||||
env_var="OPENAI_API_KEY",
|
||||
base_url="https://api.openai.com/v1",
|
||||
model_pattern="openai/*",
|
||||
),
|
||||
ProviderDefault(
|
||||
name="anthropic",
|
||||
env_var="ANTHROPIC_API_KEY",
|
||||
base_url="https://api.anthropic.com/v1",
|
||||
model_pattern="anthropic/*",
|
||||
),
|
||||
ProviderDefault(
|
||||
name="gemini",
|
||||
env_var="GEMINI_API_KEY",
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta",
|
||||
model_pattern="gemini/*",
|
||||
),
|
||||
ProviderDefault(
|
||||
name="groq",
|
||||
env_var="GROQ_API_KEY",
|
||||
base_url="https://api.groq.com/openai/v1",
|
||||
model_pattern="groq/*",
|
||||
),
|
||||
ProviderDefault(
|
||||
name="deepseek",
|
||||
env_var="DEEPSEEK_API_KEY",
|
||||
base_url="https://api.deepseek.com/v1",
|
||||
model_pattern="deepseek/*",
|
||||
),
|
||||
ProviderDefault(
|
||||
name="mistral",
|
||||
env_var="MISTRAL_API_KEY",
|
||||
base_url="https://api.mistral.ai/v1",
|
||||
model_pattern="mistral/*",
|
||||
),
|
||||
# DigitalOcean Gradient is a first-class provider post-#889 — the
|
||||
# `digitalocean/` model prefix routes to the built-in Envoy cluster, no
|
||||
# base_url needed at runtime.
|
||||
ProviderDefault(
|
||||
name="digitalocean",
|
||||
env_var="DO_API_KEY",
|
||||
base_url="https://inference.do-ai.run/v1",
|
||||
model_pattern="digitalocean/*",
|
||||
),
|
||||
ProviderDefault(
|
||||
name="vercel",
|
||||
env_var="AI_GATEWAY_API_KEY",
|
||||
base_url="https://ai-gateway.vercel.sh/v1",
|
||||
model_pattern="vercel/*",
|
||||
),
|
||||
# OpenRouter is a first-class provider — the `openrouter/` model prefix is
|
||||
# accepted by the schema and brightstaff's ProviderId parser, so no
|
||||
# provider_interface override is needed.
|
||||
ProviderDefault(
|
||||
name="openrouter",
|
||||
env_var="OPENROUTER_API_KEY",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_pattern="openrouter/*",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
with_keys: list[ProviderDefault]
|
||||
passthrough: list[ProviderDefault]
|
||||
|
||||
@property
|
||||
def summary(self) -> str:
|
||||
parts = []
|
||||
if self.with_keys:
|
||||
parts.append("env-keyed: " + ", ".join(p.name for p in self.with_keys))
|
||||
if self.passthrough:
|
||||
parts.append("pass-through: " + ", ".join(p.name for p in self.passthrough))
|
||||
return " | ".join(parts) if parts else "no providers"
|
||||
|
||||
|
||||
def detect_providers(env: dict[str, str] | None = None) -> DetectionResult:
|
||||
env = env if env is not None else dict(os.environ)
|
||||
with_keys: list[ProviderDefault] = []
|
||||
passthrough: list[ProviderDefault] = []
|
||||
for p in PROVIDER_DEFAULTS:
|
||||
val = env.get(p.env_var)
|
||||
if val:
|
||||
with_keys.append(p)
|
||||
else:
|
||||
passthrough.append(p)
|
||||
return DetectionResult(with_keys=with_keys, passthrough=passthrough)
|
||||
|
||||
|
||||
def synthesize_default_config(
|
||||
env: dict[str, str] | None = None,
|
||||
*,
|
||||
listener_port: int = DEFAULT_LLM_LISTENER_PORT,
|
||||
otel_endpoint: str = DEFAULT_OTLP_ENDPOINT,
|
||||
) -> dict:
|
||||
"""Build a pass-through config dict suitable for validation + envoy rendering.
|
||||
|
||||
The returned dict can be dumped to YAML and handed to the existing `planoai up`
|
||||
pipeline unchanged.
|
||||
"""
|
||||
detection = detect_providers(env)
|
||||
|
||||
def _entry(p: ProviderDefault, base: dict) -> dict:
|
||||
row: dict = {"name": p.name, "model": p.model_pattern, "base_url": p.base_url}
|
||||
if p.provider_interface is not None:
|
||||
row["provider_interface"] = p.provider_interface
|
||||
row.update(base)
|
||||
return row
|
||||
|
||||
model_providers: list[dict] = []
|
||||
for p in detection.with_keys:
|
||||
model_providers.append(_entry(p, {"access_key": f"${p.env_var}"}))
|
||||
for p in detection.passthrough:
|
||||
model_providers.append(_entry(p, {"passthrough_auth": True}))
|
||||
|
||||
# No explicit `default: true` entry is synthesized: the plano config
|
||||
# validator rejects wildcard models as defaults, and brightstaff already
|
||||
# registers bare model names as lookup keys during wildcard expansion
|
||||
# (crates/common/src/llm_providers.rs), so `{"model": "gpt-4o-mini"}`
|
||||
# without a prefix resolves via the openai wildcard without needing
|
||||
# `default: true`. See discussion on #890.
|
||||
|
||||
return {
|
||||
"version": "v0.4.0",
|
||||
"listeners": [
|
||||
{
|
||||
"name": "llm",
|
||||
"type": "model",
|
||||
"port": listener_port,
|
||||
"address": "0.0.0.0",
|
||||
}
|
||||
],
|
||||
"model_providers": model_providers,
|
||||
"tracing": {
|
||||
"random_sampling": 100,
|
||||
"opentracing_grpc_endpoint": otel_endpoint,
|
||||
},
|
||||
}
|
||||
|
|
@ -1,9 +1,18 @@
|
|||
import json
|
||||
import os
|
||||
import multiprocessing
|
||||
import subprocess
|
||||
import sys
|
||||
import contextlib
|
||||
import logging
|
||||
import rich_click as click
|
||||
import yaml
|
||||
from planoai import targets
|
||||
from planoai.defaults import (
|
||||
DEFAULT_LLM_LISTENER_PORT,
|
||||
detect_providers,
|
||||
synthesize_default_config,
|
||||
)
|
||||
|
||||
# Brand color - Plano purple
|
||||
PLANO_COLOR = "#969FF4"
|
||||
|
|
@ -28,9 +37,13 @@ from planoai.core import (
|
|||
)
|
||||
from planoai.init_cmd import init as init_cmd
|
||||
from planoai.trace_cmd import trace as trace_cmd, start_trace_listener_background
|
||||
from planoai.chatgpt_cmd import chatgpt as chatgpt_cmd
|
||||
from planoai.obs_cmd import obs as obs_cmd
|
||||
from planoai.consts import (
|
||||
DEFAULT_OTEL_TRACING_GRPC_ENDPOINT,
|
||||
DEFAULT_NATIVE_OTEL_TRACING_GRPC_ENDPOINT,
|
||||
NATIVE_PID_FILE,
|
||||
PLANO_RUN_DIR,
|
||||
PLANO_DOCKER_IMAGE,
|
||||
PLANO_DOCKER_NAME,
|
||||
)
|
||||
|
|
@ -40,6 +53,30 @@ from planoai.versioning import check_version_status, get_latest_version, get_ver
|
|||
log = getLogger(__name__)
|
||||
|
||||
|
||||
def _is_native_plano_running() -> bool:
|
||||
if not os.path.exists(NATIVE_PID_FILE):
|
||||
return False
|
||||
try:
|
||||
with open(NATIVE_PID_FILE, "r") as f:
|
||||
pids = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return False
|
||||
|
||||
envoy_pid = pids.get("envoy_pid")
|
||||
brightstaff_pid = pids.get("brightstaff_pid")
|
||||
if not isinstance(envoy_pid, int) or not isinstance(brightstaff_pid, int):
|
||||
return False
|
||||
|
||||
for pid in (envoy_pid, brightstaff_pid):
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
except PermissionError:
|
||||
continue
|
||||
return True
|
||||
|
||||
|
||||
def _is_port_in_use(port: int) -> bool:
|
||||
"""Check if a TCP port is already bound on localhost."""
|
||||
import socket
|
||||
|
|
@ -75,6 +112,42 @@ def _print_cli_header(console) -> None:
|
|||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _temporary_cli_log_level(level: str | None):
|
||||
if level is None:
|
||||
yield
|
||||
return
|
||||
|
||||
current_level = logging.getLevelName(logging.getLogger().level).lower()
|
||||
set_log_level(level)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
set_log_level(current_level)
|
||||
|
||||
|
||||
def _inject_chatgpt_tokens_if_needed(config, env, console):
|
||||
"""If config uses chatgpt providers, resolve tokens from ~/.plano/chatgpt/auth.json."""
|
||||
providers = config.get("model_providers") or config.get("llm_providers") or []
|
||||
has_chatgpt = any(str(p.get("model", "")).startswith("chatgpt/") for p in providers)
|
||||
if not has_chatgpt:
|
||||
return
|
||||
|
||||
try:
|
||||
from planoai.chatgpt_auth import get_access_token
|
||||
|
||||
access_token, account_id = get_access_token()
|
||||
env["CHATGPT_ACCESS_TOKEN"] = access_token
|
||||
if account_id:
|
||||
env["CHATGPT_ACCOUNT_ID"] = account_id
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"\n[red]ChatGPT auth error:[/red] {e}\n"
|
||||
f"[dim]Run 'planoai chatgpt login' to authenticate.[/dim]\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _print_missing_keys(console, missing_keys: list[str]) -> None:
|
||||
console.print(f"\n[red]✗[/red] [red]Missing API keys![/red]\n")
|
||||
for key in missing_keys:
|
||||
|
|
@ -179,7 +252,7 @@ def build(docker):
|
|||
cwd=crates_dir,
|
||||
check=True,
|
||||
)
|
||||
console.print("[green]✓[/green] WASM plugins built")
|
||||
log.info("WASM plugins built")
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]✗[/red] WASM build failed: {e}")
|
||||
sys.exit(1)
|
||||
|
|
@ -197,7 +270,7 @@ def build(docker):
|
|||
cwd=crates_dir,
|
||||
check=True,
|
||||
)
|
||||
console.print("[green]✓[/green] brightstaff built")
|
||||
log.info("brightstaff built")
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]✗[/red] brightstaff build failed: {e}")
|
||||
sys.exit(1)
|
||||
|
|
@ -267,163 +340,239 @@ def build(docker):
|
|||
help="Run Plano inside Docker instead of natively.",
|
||||
is_flag=True,
|
||||
)
|
||||
def up(file, path, foreground, with_tracing, tracing_port, docker):
|
||||
@click.option(
|
||||
"--verbose",
|
||||
"-v",
|
||||
default=False,
|
||||
help="Show detailed startup logs with timestamps.",
|
||||
is_flag=True,
|
||||
)
|
||||
@click.option(
|
||||
"--listener-port",
|
||||
default=DEFAULT_LLM_LISTENER_PORT,
|
||||
type=int,
|
||||
show_default=True,
|
||||
help="Override the LLM listener port when running without a config file. Ignored when a config file is present.",
|
||||
)
|
||||
def up(
|
||||
file,
|
||||
path,
|
||||
foreground,
|
||||
with_tracing,
|
||||
tracing_port,
|
||||
docker,
|
||||
verbose,
|
||||
listener_port,
|
||||
):
|
||||
"""Starts Plano."""
|
||||
from rich.status import Status
|
||||
|
||||
console = _console()
|
||||
_print_cli_header(console)
|
||||
|
||||
# Use the utility function to find config file
|
||||
plano_config_file = find_config_file(path, file)
|
||||
with _temporary_cli_log_level("warning" if not verbose else None):
|
||||
# Use the utility function to find config file
|
||||
plano_config_file = find_config_file(path, file)
|
||||
|
||||
# Check if the file exists
|
||||
if not os.path.exists(plano_config_file):
|
||||
console.print(
|
||||
f"[red]✗[/red] Config file not found: [dim]{plano_config_file}[/dim]"
|
||||
# Zero-config fallback: when no user config is present, synthesize a
|
||||
# pass-through config that covers the common LLM providers and
|
||||
# auto-wires OTel export to ``planoai obs``. See cli/planoai/defaults.py.
|
||||
if not os.path.exists(plano_config_file):
|
||||
detection = detect_providers()
|
||||
cfg_dict = synthesize_default_config(listener_port=listener_port)
|
||||
|
||||
default_dir = os.path.expanduser("~/.plano")
|
||||
os.makedirs(default_dir, exist_ok=True)
|
||||
synthesized_path = os.path.join(default_dir, "default_config.yaml")
|
||||
with open(synthesized_path, "w") as fh:
|
||||
yaml.safe_dump(cfg_dict, fh, sort_keys=False)
|
||||
plano_config_file = synthesized_path
|
||||
console.print(
|
||||
f"[dim]No plano config found; using defaults ({detection.summary}). "
|
||||
f"Listening on :{listener_port}, tracing -> http://localhost:4317.[/dim]"
|
||||
)
|
||||
|
||||
if not docker:
|
||||
from planoai.native_runner import native_validate_config
|
||||
|
||||
with Status(
|
||||
"[dim]Validating configuration[/dim]",
|
||||
spinner="dots",
|
||||
spinner_style="dim",
|
||||
):
|
||||
try:
|
||||
native_validate_config(plano_config_file)
|
||||
except SystemExit:
|
||||
console.print(f"[red]✗[/red] Validation failed")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗[/red] Validation failed")
|
||||
console.print(f" [dim]{str(e).strip()}[/dim]")
|
||||
sys.exit(1)
|
||||
else:
|
||||
with Status(
|
||||
"[dim]Validating configuration (Docker)[/dim]",
|
||||
spinner="dots",
|
||||
spinner_style="dim",
|
||||
):
|
||||
(
|
||||
validation_return_code,
|
||||
_,
|
||||
validation_stderr,
|
||||
) = docker_validate_plano_schema(plano_config_file)
|
||||
|
||||
if validation_return_code != 0:
|
||||
console.print(f"[red]✗[/red] Validation failed")
|
||||
if validation_stderr:
|
||||
console.print(f" [dim]{validation_stderr.strip()}[/dim]")
|
||||
sys.exit(1)
|
||||
|
||||
log.info("Configuration valid")
|
||||
|
||||
# Set up environment
|
||||
default_otel = (
|
||||
DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
|
||||
if docker
|
||||
else DEFAULT_NATIVE_OTEL_TRACING_GRPC_ENDPOINT
|
||||
)
|
||||
sys.exit(1)
|
||||
env_stage = {
|
||||
"OTEL_TRACING_GRPC_ENDPOINT": default_otel,
|
||||
}
|
||||
env = os.environ.copy()
|
||||
env.pop("PATH", None)
|
||||
|
||||
if not docker:
|
||||
from planoai.native_runner import native_validate_config
|
||||
import yaml
|
||||
|
||||
with Status(
|
||||
"[dim]Validating configuration[/dim]",
|
||||
spinner="dots",
|
||||
spinner_style="dim",
|
||||
):
|
||||
try:
|
||||
native_validate_config(plano_config_file)
|
||||
except SystemExit:
|
||||
console.print(f"[red]✗[/red] Validation failed")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗[/red] Validation failed")
|
||||
console.print(f" [dim]{str(e).strip()}[/dim]")
|
||||
sys.exit(1)
|
||||
else:
|
||||
with Status(
|
||||
"[dim]Validating configuration (Docker)[/dim]",
|
||||
spinner="dots",
|
||||
spinner_style="dim",
|
||||
):
|
||||
(
|
||||
validation_return_code,
|
||||
_,
|
||||
validation_stderr,
|
||||
) = docker_validate_plano_schema(plano_config_file)
|
||||
with open(plano_config_file, "r") as f:
|
||||
plano_config = yaml.safe_load(f)
|
||||
|
||||
if validation_return_code != 0:
|
||||
console.print(f"[red]✗[/red] Validation failed")
|
||||
if validation_stderr:
|
||||
console.print(f" [dim]{validation_stderr.strip()}[/dim]")
|
||||
# Inject ChatGPT tokens from ~/.plano/chatgpt/auth.json if any provider needs them
|
||||
_inject_chatgpt_tokens_if_needed(plano_config, env, console)
|
||||
|
||||
# Check access keys
|
||||
access_keys = get_llm_provider_access_keys(plano_config_file=plano_config_file)
|
||||
access_keys = set(access_keys)
|
||||
access_keys = [
|
||||
item[1:] if item.startswith("$") else item for item in access_keys
|
||||
]
|
||||
|
||||
missing_keys = []
|
||||
if access_keys:
|
||||
if file:
|
||||
app_env_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(file)), ".env"
|
||||
)
|
||||
else:
|
||||
app_env_file = os.path.abspath(os.path.join(path, ".env"))
|
||||
|
||||
if not os.path.exists(app_env_file):
|
||||
for access_key in access_keys:
|
||||
if env.get(access_key) is None:
|
||||
missing_keys.append(access_key)
|
||||
else:
|
||||
env_stage[access_key] = env.get(access_key)
|
||||
else:
|
||||
env_file_dict = load_env_file_to_dict(app_env_file)
|
||||
for access_key in access_keys:
|
||||
if env_file_dict.get(access_key) is None:
|
||||
missing_keys.append(access_key)
|
||||
else:
|
||||
env_stage[access_key] = env_file_dict[access_key]
|
||||
|
||||
if missing_keys:
|
||||
_print_missing_keys(console, missing_keys)
|
||||
sys.exit(1)
|
||||
|
||||
console.print(f"[green]✓[/green] Configuration valid")
|
||||
# Pass log level to the Docker container — supervisord uses LOG_LEVEL
|
||||
# to set RUST_LOG (brightstaff) and envoy component log levels
|
||||
env_stage["LOG_LEVEL"] = os.environ.get("LOG_LEVEL", "info")
|
||||
|
||||
# Set up environment
|
||||
default_otel = (
|
||||
DEFAULT_OTEL_TRACING_GRPC_ENDPOINT
|
||||
if docker
|
||||
else DEFAULT_NATIVE_OTEL_TRACING_GRPC_ENDPOINT
|
||||
)
|
||||
env_stage = {
|
||||
"OTEL_TRACING_GRPC_ENDPOINT": default_otel,
|
||||
}
|
||||
env = os.environ.copy()
|
||||
env.pop("PATH", None)
|
||||
# Start the local OTLP trace collector if --with-tracing is set
|
||||
trace_server = None
|
||||
if with_tracing:
|
||||
if _is_port_in_use(tracing_port):
|
||||
# A listener is already running (e.g. `planoai trace listen`)
|
||||
console.print(
|
||||
f"[green]✓[/green] Trace collector already running on port [cyan]{tracing_port}[/cyan]"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
trace_server = start_trace_listener_background(
|
||||
grpc_port=tracing_port
|
||||
)
|
||||
console.print(
|
||||
f"[green]✓[/green] Trace collector listening on [cyan]0.0.0.0:{tracing_port}[/cyan]"
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[red]✗[/red] Failed to start trace collector on port {tracing_port}: {e}"
|
||||
)
|
||||
console.print(
|
||||
f"\n[dim]Check if another process is using port {tracing_port}:[/dim]"
|
||||
)
|
||||
console.print(f" [cyan]lsof -i :{tracing_port}[/cyan]")
|
||||
console.print(f"\n[dim]Or use a different port:[/dim]")
|
||||
console.print(
|
||||
f" [cyan]planoai up --with-tracing --tracing-port 4318[/cyan]\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Check access keys
|
||||
access_keys = get_llm_provider_access_keys(plano_config_file=plano_config_file)
|
||||
access_keys = set(access_keys)
|
||||
access_keys = [item[1:] if item.startswith("$") else item for item in access_keys]
|
||||
# Update the OTEL endpoint so the gateway sends traces to the right port
|
||||
tracing_host = "host.docker.internal" if docker else "localhost"
|
||||
otel_endpoint = f"http://{tracing_host}:{tracing_port}"
|
||||
env_stage["OTEL_TRACING_GRPC_ENDPOINT"] = otel_endpoint
|
||||
|
||||
missing_keys = []
|
||||
if access_keys:
|
||||
if file:
|
||||
app_env_file = os.path.join(os.path.dirname(os.path.abspath(file)), ".env")
|
||||
else:
|
||||
app_env_file = os.path.abspath(os.path.join(path, ".env"))
|
||||
env.update(env_stage)
|
||||
try:
|
||||
if not docker:
|
||||
from planoai.native_runner import start_native
|
||||
|
||||
if not os.path.exists(app_env_file):
|
||||
for access_key in access_keys:
|
||||
if env.get(access_key) is None:
|
||||
missing_keys.append(access_key)
|
||||
if not verbose:
|
||||
with Status(
|
||||
f"[{PLANO_COLOR}]Starting Plano...[/{PLANO_COLOR}]",
|
||||
spinner="dots",
|
||||
) as status:
|
||||
|
||||
def _update_progress(message: str) -> None:
|
||||
console.print(f"[dim]{message}[/dim]")
|
||||
|
||||
start_native(
|
||||
plano_config_file,
|
||||
env,
|
||||
foreground=foreground,
|
||||
with_tracing=with_tracing,
|
||||
progress_callback=_update_progress,
|
||||
)
|
||||
console.print("")
|
||||
console.print(
|
||||
"[green]✓[/green] Plano is running! [dim](native mode)[/dim]"
|
||||
)
|
||||
log_dir = os.path.join(PLANO_RUN_DIR, "logs")
|
||||
console.print(f"Logs: {log_dir}")
|
||||
console.print("Run 'planoai down' to stop.")
|
||||
else:
|
||||
env_stage[access_key] = env.get(access_key)
|
||||
else:
|
||||
env_file_dict = load_env_file_to_dict(app_env_file)
|
||||
for access_key in access_keys:
|
||||
if env_file_dict.get(access_key) is None:
|
||||
missing_keys.append(access_key)
|
||||
else:
|
||||
env_stage[access_key] = env_file_dict[access_key]
|
||||
start_native(
|
||||
plano_config_file,
|
||||
env,
|
||||
foreground=foreground,
|
||||
with_tracing=with_tracing,
|
||||
)
|
||||
else:
|
||||
start_plano(plano_config_file, env, foreground=foreground)
|
||||
|
||||
if missing_keys:
|
||||
_print_missing_keys(console, missing_keys)
|
||||
sys.exit(1)
|
||||
|
||||
# Pass log level to the Docker container — supervisord uses LOG_LEVEL
|
||||
# to set RUST_LOG (brightstaff) and envoy component log levels
|
||||
env_stage["LOG_LEVEL"] = os.environ.get("LOG_LEVEL", "info")
|
||||
|
||||
# Start the local OTLP trace collector if --with-tracing is set
|
||||
trace_server = None
|
||||
if with_tracing:
|
||||
if _is_port_in_use(tracing_port):
|
||||
# A listener is already running (e.g. `planoai trace listen`)
|
||||
console.print(
|
||||
f"[green]✓[/green] Trace collector already running on port [cyan]{tracing_port}[/cyan]"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
trace_server = start_trace_listener_background(grpc_port=tracing_port)
|
||||
# When tracing is enabled but --foreground is not, keep the process
|
||||
# alive so the OTLP collector continues to receive spans.
|
||||
if trace_server is not None and not foreground:
|
||||
console.print(
|
||||
f"[green]✓[/green] Trace collector listening on [cyan]0.0.0.0:{tracing_port}[/cyan]"
|
||||
f"[dim]Plano is running. Trace collector active on port {tracing_port}. Press Ctrl+C to stop.[/dim]"
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[red]✗[/red] Failed to start trace collector on port {tracing_port}: {e}"
|
||||
)
|
||||
console.print(
|
||||
f"\n[dim]Check if another process is using port {tracing_port}:[/dim]"
|
||||
)
|
||||
console.print(f" [cyan]lsof -i :{tracing_port}[/cyan]")
|
||||
console.print(f"\n[dim]Or use a different port:[/dim]")
|
||||
console.print(
|
||||
f" [cyan]planoai up --with-tracing --tracing-port 4318[/cyan]\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Update the OTEL endpoint so the gateway sends traces to the right port
|
||||
tracing_host = "host.docker.internal" if docker else "localhost"
|
||||
otel_endpoint = f"http://{tracing_host}:{tracing_port}"
|
||||
env_stage["OTEL_TRACING_GRPC_ENDPOINT"] = otel_endpoint
|
||||
|
||||
env.update(env_stage)
|
||||
try:
|
||||
if not docker:
|
||||
from planoai.native_runner import start_native
|
||||
|
||||
start_native(
|
||||
plano_config_file, env, foreground=foreground, with_tracing=with_tracing
|
||||
)
|
||||
else:
|
||||
start_plano(plano_config_file, env, foreground=foreground)
|
||||
|
||||
# When tracing is enabled but --foreground is not, keep the process
|
||||
# alive so the OTLP collector continues to receive spans.
|
||||
if trace_server is not None and not foreground:
|
||||
console.print(
|
||||
f"[dim]Plano is running. Trace collector active on port {tracing_port}. Press Ctrl+C to stop.[/dim]"
|
||||
)
|
||||
trace_server.wait_for_termination()
|
||||
except KeyboardInterrupt:
|
||||
if trace_server is not None:
|
||||
console.print(f"\n[dim]Stopping trace collector...[/dim]")
|
||||
finally:
|
||||
if trace_server is not None:
|
||||
trace_server.stop(grace=2)
|
||||
trace_server.wait_for_termination()
|
||||
except KeyboardInterrupt:
|
||||
if trace_server is not None:
|
||||
console.print(f"\n[dim]Stopping trace collector...[/dim]")
|
||||
finally:
|
||||
if trace_server is not None:
|
||||
trace_server.stop(grace=2)
|
||||
|
||||
|
||||
@click.command()
|
||||
|
|
@ -433,25 +582,46 @@ def up(file, path, foreground, with_tracing, tracing_port, docker):
|
|||
help="Stop a Docker-based Plano instance.",
|
||||
is_flag=True,
|
||||
)
|
||||
def down(docker):
|
||||
@click.option(
|
||||
"--verbose",
|
||||
"-v",
|
||||
default=False,
|
||||
help="Show detailed shutdown logs with timestamps.",
|
||||
is_flag=True,
|
||||
)
|
||||
def down(docker, verbose):
|
||||
"""Stops Plano."""
|
||||
console = _console()
|
||||
_print_cli_header(console)
|
||||
|
||||
if not docker:
|
||||
from planoai.native_runner import stop_native
|
||||
with _temporary_cli_log_level("warning" if not verbose else None):
|
||||
if not docker:
|
||||
from planoai.native_runner import stop_native
|
||||
|
||||
with console.status(
|
||||
f"[{PLANO_COLOR}]Shutting down Plano...[/{PLANO_COLOR}]",
|
||||
spinner="dots",
|
||||
):
|
||||
stop_native()
|
||||
else:
|
||||
with console.status(
|
||||
f"[{PLANO_COLOR}]Shutting down Plano (Docker)...[/{PLANO_COLOR}]",
|
||||
spinner="dots",
|
||||
):
|
||||
stop_docker_container()
|
||||
with console.status(
|
||||
f"[{PLANO_COLOR}]Shutting down Plano...[/{PLANO_COLOR}]",
|
||||
spinner="dots",
|
||||
):
|
||||
stopped = stop_native()
|
||||
if not verbose:
|
||||
if stopped:
|
||||
console.print(
|
||||
"[green]✓[/green] Plano stopped! [dim](native mode)[/dim]"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"[dim]No Plano instance was running (native mode)[/dim]"
|
||||
)
|
||||
else:
|
||||
with console.status(
|
||||
f"[{PLANO_COLOR}]Shutting down Plano (Docker)...[/{PLANO_COLOR}]",
|
||||
spinner="dots",
|
||||
):
|
||||
stop_docker_container()
|
||||
if not verbose:
|
||||
console.print(
|
||||
"[green]✓[/green] Plano stopped! [dim](docker mode)[/dim]"
|
||||
)
|
||||
|
||||
|
||||
@click.command()
|
||||
|
|
@ -483,9 +653,21 @@ def generate_prompt_targets(file):
|
|||
is_flag=True,
|
||||
)
|
||||
@click.option("--follow", help="Follow the logs", is_flag=True)
|
||||
def logs(debug, follow):
|
||||
@click.option(
|
||||
"--docker",
|
||||
default=False,
|
||||
help="Stream logs from a Docker-based Plano instance.",
|
||||
is_flag=True,
|
||||
)
|
||||
def logs(debug, follow, docker):
|
||||
"""Stream logs from access logs services."""
|
||||
|
||||
if not docker:
|
||||
from planoai.native_runner import native_logs
|
||||
|
||||
native_logs(debug=debug, follow=follow)
|
||||
return
|
||||
|
||||
plano_process = None
|
||||
try:
|
||||
if debug:
|
||||
|
|
@ -511,7 +693,7 @@ def logs(debug, follow):
|
|||
|
||||
|
||||
@click.command()
|
||||
@click.argument("type", type=click.Choice(["claude"]), required=True)
|
||||
@click.argument("type", type=click.Choice(["claude", "codex"]), required=True)
|
||||
@click.argument("file", required=False) # Optional file argument
|
||||
@click.option(
|
||||
"--path", default=".", help="Path to the directory containing plano_config.yaml"
|
||||
|
|
@ -524,14 +706,19 @@ def logs(debug, follow):
|
|||
def cli_agent(type, file, path, settings):
|
||||
"""Start a CLI agent connected to Plano.
|
||||
|
||||
CLI_AGENT: The type of CLI agent to start (currently only 'claude' is supported)
|
||||
CLI_AGENT: The type of CLI agent to start ('claude' or 'codex')
|
||||
"""
|
||||
|
||||
# Check if plano docker container is running
|
||||
plano_status = docker_container_status(PLANO_DOCKER_NAME)
|
||||
if plano_status != "running":
|
||||
log.error(f"plano docker container is not running (status: {plano_status})")
|
||||
log.error("Please start plano using the 'planoai up' command.")
|
||||
native_running = _is_native_plano_running()
|
||||
docker_running = False
|
||||
if not native_running:
|
||||
docker_running = docker_container_status(PLANO_DOCKER_NAME) == "running"
|
||||
|
||||
if not (native_running or docker_running):
|
||||
log.error("Plano is not running.")
|
||||
log.error(
|
||||
"Start Plano first using 'planoai up <config.yaml>' (native or --docker mode)."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Determine plano_config.yaml path
|
||||
|
|
@ -541,7 +728,7 @@ def cli_agent(type, file, path, settings):
|
|||
sys.exit(1)
|
||||
|
||||
try:
|
||||
start_cli_agent(plano_config_file, settings)
|
||||
start_cli_agent(plano_config_file, type, settings)
|
||||
except SystemExit:
|
||||
# Re-raise SystemExit to preserve exit codes
|
||||
raise
|
||||
|
|
@ -559,6 +746,8 @@ main.add_command(cli_agent)
|
|||
main.add_command(generate_prompt_targets)
|
||||
main.add_command(init_cmd, name="init")
|
||||
main.add_command(trace_cmd, name="trace")
|
||||
main.add_command(chatgpt_cmd, name="chatgpt")
|
||||
main.add_command(obs_cmd, name="obs")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -100,15 +100,13 @@ def ensure_envoy_binary():
|
|||
with open(version_path, "r") as f:
|
||||
cached_version = f.read().strip()
|
||||
if cached_version == ENVOY_VERSION:
|
||||
log.info(f"Envoy {ENVOY_VERSION} found at {envoy_path}")
|
||||
log.info(f"Envoy {ENVOY_VERSION} (cached)")
|
||||
return envoy_path
|
||||
print(
|
||||
log.info(
|
||||
f"Envoy version changed ({cached_version} → {ENVOY_VERSION}), re-downloading..."
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"Envoy binary found at {envoy_path} (unknown version, re-downloading...)"
|
||||
)
|
||||
log.info("Envoy binary found (unknown version, re-downloading...)")
|
||||
|
||||
slug = _get_platform_slug()
|
||||
url = (
|
||||
|
|
@ -123,6 +121,7 @@ def ensure_envoy_binary():
|
|||
|
||||
try:
|
||||
_download_file(url, tmp_path, label=f"Envoy {ENVOY_VERSION}")
|
||||
log.info(f"Extracting Envoy {ENVOY_VERSION}...")
|
||||
with tarfile.open(tmp_path, "r:xz") as tar:
|
||||
# Find the envoy binary inside the archive
|
||||
envoy_member = None
|
||||
|
|
@ -150,7 +149,6 @@ def ensure_envoy_binary():
|
|||
os.chmod(envoy_path, 0o755)
|
||||
with open(version_path, "w") as f:
|
||||
f.write(ENVOY_VERSION)
|
||||
log.info(f"Envoy {ENVOY_VERSION} installed at {envoy_path}")
|
||||
return envoy_path
|
||||
|
||||
finally:
|
||||
|
|
@ -187,7 +185,7 @@ def ensure_wasm_plugins():
|
|||
# 1. Local source build (inside repo)
|
||||
local = _find_local_wasm_plugins()
|
||||
if local:
|
||||
log.info(f"Using locally-built WASM plugins: {local[0]}")
|
||||
log.info("Using locally-built WASM plugins")
|
||||
return local
|
||||
|
||||
# 2. Cached download
|
||||
|
|
@ -201,9 +199,9 @@ def ensure_wasm_plugins():
|
|||
with open(version_path, "r") as f:
|
||||
cached_version = f.read().strip()
|
||||
if cached_version == version:
|
||||
log.info(f"WASM plugins {version} found at {PLANO_PLUGINS_DIR}")
|
||||
log.info(f"WASM plugins {version} (cached)")
|
||||
return prompt_gw_path, llm_gw_path
|
||||
print(
|
||||
log.info(
|
||||
f"WASM plugins version changed ({cached_version} → {version}), re-downloading..."
|
||||
)
|
||||
else:
|
||||
|
|
@ -220,6 +218,7 @@ def ensure_wasm_plugins():
|
|||
url = f"{PLANO_RELEASE_BASE_URL}/{version}/{gz_name}"
|
||||
gz_dest = dest + ".gz"
|
||||
_download_file(url, gz_dest, label=f"{name} ({version})")
|
||||
log.info(f"Decompressing {name}...")
|
||||
with gzip.open(gz_dest, "rb") as f_in, open(dest, "wb") as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
os.unlink(gz_dest)
|
||||
|
|
@ -235,7 +234,7 @@ def ensure_brightstaff_binary():
|
|||
# 1. Local source build (inside repo)
|
||||
local = _find_local_brightstaff()
|
||||
if local:
|
||||
log.info(f"Using locally-built brightstaff: {local}")
|
||||
log.info("Using locally-built brightstaff")
|
||||
return local
|
||||
|
||||
# 2. Cached download
|
||||
|
|
@ -248,9 +247,9 @@ def ensure_brightstaff_binary():
|
|||
with open(version_path, "r") as f:
|
||||
cached_version = f.read().strip()
|
||||
if cached_version == version:
|
||||
log.info(f"brightstaff {version} found at {brightstaff_path}")
|
||||
log.info(f"brightstaff {version} (cached)")
|
||||
return brightstaff_path
|
||||
print(
|
||||
log.info(
|
||||
f"brightstaff version changed ({cached_version} → {version}), re-downloading..."
|
||||
)
|
||||
else:
|
||||
|
|
@ -265,6 +264,7 @@ def ensure_brightstaff_binary():
|
|||
|
||||
gz_path = brightstaff_path + ".gz"
|
||||
_download_file(url, gz_path, label=f"brightstaff ({version}, {slug})")
|
||||
log.info("Decompressing brightstaff...")
|
||||
with gzip.open(gz_path, "rb") as f_in, open(brightstaff_path, "wb") as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
os.unlink(gz_path)
|
||||
|
|
@ -272,7 +272,6 @@ def ensure_brightstaff_binary():
|
|||
os.chmod(brightstaff_path, 0o755)
|
||||
with open(version_path, "w") as f:
|
||||
f.write(version)
|
||||
log.info(f"brightstaff {version} installed at {brightstaff_path}")
|
||||
return brightstaff_path
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import signal
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
from planoai.consts import (
|
||||
NATIVE_PID_FILE,
|
||||
|
|
@ -157,23 +158,20 @@ def render_native_config(plano_config_file, env, with_tracing=False):
|
|||
return envoy_config_path, plano_config_rendered_path
|
||||
|
||||
|
||||
def start_native(plano_config_file, env, foreground=False, with_tracing=False):
|
||||
def start_native(
|
||||
plano_config_file,
|
||||
env,
|
||||
foreground=False,
|
||||
with_tracing=False,
|
||||
progress_callback: Callable[[str], None] | None = None,
|
||||
):
|
||||
"""Start Envoy and brightstaff natively."""
|
||||
from planoai.core import _get_gateway_ports
|
||||
|
||||
console = None
|
||||
try:
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def status_print(msg):
|
||||
if console:
|
||||
console.print(msg)
|
||||
else:
|
||||
print(msg)
|
||||
# Stop any existing instance first
|
||||
if os.path.exists(NATIVE_PID_FILE):
|
||||
log.info("Stopping existing Plano instance...")
|
||||
stop_native()
|
||||
|
||||
envoy_path = ensure_envoy_binary()
|
||||
ensure_wasm_plugins()
|
||||
|
|
@ -182,7 +180,9 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
|
|||
plano_config_file, env, with_tracing=with_tracing
|
||||
)
|
||||
|
||||
status_print(f"[green]✓[/green] Configuration rendered")
|
||||
log.info("Configuration rendered")
|
||||
if progress_callback:
|
||||
progress_callback("Configuration valid...")
|
||||
|
||||
log_dir = os.path.join(PLANO_RUN_DIR, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
|
@ -203,6 +203,8 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
|
|||
os.path.join(log_dir, "brightstaff.log"),
|
||||
)
|
||||
log.info(f"Started brightstaff (PID {brightstaff_pid})")
|
||||
if progress_callback:
|
||||
progress_callback(f"Started brightstaff (PID: {brightstaff_pid})...")
|
||||
|
||||
# Start envoy
|
||||
envoy_pid = _daemon_exec(
|
||||
|
|
@ -219,6 +221,8 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
|
|||
os.path.join(log_dir, "envoy.log"),
|
||||
)
|
||||
log.info(f"Started envoy (PID {envoy_pid})")
|
||||
if progress_callback:
|
||||
progress_callback(f"Started envoy (PID: {envoy_pid})...")
|
||||
|
||||
# Save PIDs
|
||||
os.makedirs(PLANO_RUN_DIR, exist_ok=True)
|
||||
|
|
@ -233,7 +237,9 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
|
|||
|
||||
# Health check
|
||||
gateway_ports = _get_gateway_ports(plano_config_file)
|
||||
status_print(f"[dim]Waiting for listeners to become healthy...[/dim]")
|
||||
log.info("Waiting for listeners to become healthy...")
|
||||
if progress_callback:
|
||||
progress_callback("Waiting for listeners to become healthy...")
|
||||
|
||||
start_time = time.time()
|
||||
timeout = 60
|
||||
|
|
@ -244,35 +250,36 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
|
|||
all_healthy = False
|
||||
|
||||
if all_healthy:
|
||||
status_print(f"[green]✓[/green] Plano is running (native mode)")
|
||||
log.info("Plano is running (native mode)")
|
||||
for port in gateway_ports:
|
||||
status_print(f" [cyan]http://localhost:{port}[/cyan]")
|
||||
log.info(f" http://localhost:{port}")
|
||||
|
||||
break
|
||||
|
||||
# Check if processes are still alive
|
||||
if not _is_pid_alive(brightstaff_pid):
|
||||
status_print("[red]✗[/red] brightstaff exited unexpectedly")
|
||||
status_print(f" Check logs: {os.path.join(log_dir, 'brightstaff.log')}")
|
||||
log.error("brightstaff exited unexpectedly")
|
||||
log.error(f" Check logs: {os.path.join(log_dir, 'brightstaff.log')}")
|
||||
_kill_pid(envoy_pid)
|
||||
sys.exit(1)
|
||||
|
||||
if not _is_pid_alive(envoy_pid):
|
||||
status_print("[red]✗[/red] envoy exited unexpectedly")
|
||||
status_print(f" Check logs: {os.path.join(log_dir, 'envoy.log')}")
|
||||
log.error("envoy exited unexpectedly")
|
||||
log.error(f" Check logs: {os.path.join(log_dir, 'envoy.log')}")
|
||||
_kill_pid(brightstaff_pid)
|
||||
sys.exit(1)
|
||||
|
||||
if time.time() - start_time > timeout:
|
||||
status_print(f"[red]✗[/red] Health check timed out after {timeout}s")
|
||||
status_print(f" Check logs in: {log_dir}")
|
||||
log.error(f"Health check timed out after {timeout}s")
|
||||
log.error(f" Check logs in: {log_dir}")
|
||||
stop_native()
|
||||
sys.exit(1)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
if foreground:
|
||||
status_print(f"[dim]Running in foreground. Press Ctrl+C to stop.[/dim]")
|
||||
status_print(f"[dim]Logs: {log_dir}[/dim]")
|
||||
log.info("Running in foreground. Press Ctrl+C to stop.")
|
||||
log.info(f"Logs: {log_dir}")
|
||||
try:
|
||||
import glob
|
||||
|
||||
|
|
@ -290,13 +297,13 @@ def start_native(plano_config_file, env, foreground=False, with_tracing=False):
|
|||
)
|
||||
tail_proc.wait()
|
||||
except KeyboardInterrupt:
|
||||
status_print(f"\n[dim]Stopping Plano...[/dim]")
|
||||
log.info("Stopping Plano...")
|
||||
if tail_proc.poll() is None:
|
||||
tail_proc.terminate()
|
||||
stop_native()
|
||||
else:
|
||||
status_print(f"[dim]Logs: {log_dir}[/dim]")
|
||||
status_print(f"[dim]Run 'planoai down' to stop.[/dim]")
|
||||
log.info(f"Logs: {log_dir}")
|
||||
log.info("Run 'planoai down' to stop.")
|
||||
|
||||
|
||||
def _daemon_exec(args, env, log_path):
|
||||
|
|
@ -361,11 +368,19 @@ def _kill_pid(pid):
|
|||
pass
|
||||
|
||||
|
||||
def stop_native():
|
||||
"""Stop natively-running Envoy and brightstaff processes."""
|
||||
def stop_native(skip_pids: set | None = None):
|
||||
"""Stop natively-running Envoy, brightstaff, and watchdog processes.
|
||||
|
||||
Args:
|
||||
skip_pids: Set of PIDs to skip (used by the watchdog to avoid self-termination).
|
||||
|
||||
Returns:
|
||||
bool: True if at least one process was running and received a stop signal,
|
||||
False if no running native Plano process was found.
|
||||
"""
|
||||
if not os.path.exists(NATIVE_PID_FILE):
|
||||
print("No native Plano instance found (PID file missing).")
|
||||
return
|
||||
log.info("No native Plano instance found (PID file missing).")
|
||||
return False
|
||||
|
||||
with open(NATIVE_PID_FILE, "r") as f:
|
||||
pids = json.load(f)
|
||||
|
|
@ -373,17 +388,24 @@ def stop_native():
|
|||
envoy_pid = pids.get("envoy_pid")
|
||||
brightstaff_pid = pids.get("brightstaff_pid")
|
||||
|
||||
for name, pid in [("envoy", envoy_pid), ("brightstaff", brightstaff_pid)]:
|
||||
had_running_process = False
|
||||
for name, pid in [
|
||||
("envoy", envoy_pid),
|
||||
("brightstaff", brightstaff_pid),
|
||||
]:
|
||||
if skip_pids and pid in skip_pids:
|
||||
continue
|
||||
if pid is None:
|
||||
continue
|
||||
try:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
log.info(f"Sent SIGTERM to {name} (PID {pid})")
|
||||
had_running_process = True
|
||||
except ProcessLookupError:
|
||||
log.info(f"{name} (PID {pid}) already stopped")
|
||||
continue
|
||||
except PermissionError:
|
||||
log.info(f"Permission denied stopping {name} (PID {pid})")
|
||||
log.error(f"Permission denied stopping {name} (PID {pid})")
|
||||
continue
|
||||
|
||||
# Wait for graceful shutdown
|
||||
|
|
@ -403,7 +425,11 @@ def stop_native():
|
|||
pass
|
||||
|
||||
os.unlink(NATIVE_PID_FILE)
|
||||
print("Plano stopped (native mode).")
|
||||
if had_running_process:
|
||||
log.info("Plano stopped (native mode).")
|
||||
else:
|
||||
log.info("No native Plano instance was running.")
|
||||
return had_running_process
|
||||
|
||||
|
||||
def native_validate_config(plano_config_file):
|
||||
|
|
@ -429,6 +455,51 @@ def native_validate_config(plano_config_file):
|
|||
with _temporary_env(overrides):
|
||||
from planoai.config_generator import validate_and_render_schema
|
||||
|
||||
# Suppress verbose print output from config_generator
|
||||
with contextlib.redirect_stdout(io.StringIO()):
|
||||
validate_and_render_schema()
|
||||
# Suppress verbose print output from config_generator but capture errors
|
||||
captured = io.StringIO()
|
||||
try:
|
||||
with contextlib.redirect_stdout(captured):
|
||||
validate_and_render_schema()
|
||||
except SystemExit:
|
||||
# validate_and_render_schema calls exit(1) on failure after
|
||||
# printing to stdout; re-raise so the caller gets a useful message.
|
||||
output = captured.getvalue().strip()
|
||||
raise Exception(output) if output else Exception("Config validation failed")
|
||||
|
||||
|
||||
def native_logs(debug=False, follow=False):
|
||||
"""Stream logs from native-mode Plano."""
|
||||
import glob as glob_mod
|
||||
|
||||
log_dir = os.path.join(PLANO_RUN_DIR, "logs")
|
||||
if not os.path.isdir(log_dir):
|
||||
log.error(f"No native log directory found at {log_dir}")
|
||||
log.error("Is Plano running? Start it with: planoai up <config.yaml>")
|
||||
sys.exit(1)
|
||||
|
||||
log_files = sorted(glob_mod.glob(os.path.join(log_dir, "access_*.log")))
|
||||
if debug:
|
||||
log_files.extend(
|
||||
[
|
||||
os.path.join(log_dir, "envoy.log"),
|
||||
os.path.join(log_dir, "brightstaff.log"),
|
||||
]
|
||||
)
|
||||
|
||||
# Filter to files that exist
|
||||
log_files = [f for f in log_files if os.path.exists(f)]
|
||||
if not log_files:
|
||||
log.error(f"No log files found in {log_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
tail_args = ["tail"]
|
||||
if follow:
|
||||
tail_args.append("-f")
|
||||
tail_args.extend(log_files)
|
||||
|
||||
try:
|
||||
proc = subprocess.Popen(tail_args, stdout=sys.stdout, stderr=sys.stderr)
|
||||
proc.wait()
|
||||
except KeyboardInterrupt:
|
||||
if proc.poll() is None:
|
||||
proc.terminate()
|
||||
|
|
|
|||
6
cli/planoai/obs/__init__.py
Normal file
6
cli/planoai/obs/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
"""Plano observability console: in-memory live view of LLM traffic."""
|
||||
|
||||
from planoai.obs.collector import LLMCall, LLMCallStore, ObsCollector
|
||||
from planoai.obs.pricing import PricingCatalog
|
||||
|
||||
__all__ = ["LLMCall", "LLMCallStore", "ObsCollector", "PricingCatalog"]
|
||||
266
cli/planoai/obs/collector.py
Normal file
266
cli/planoai/obs/collector.py
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
"""In-memory collector for LLM calls, fed by OTLP/gRPC spans from brightstaff."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections import deque
|
||||
from concurrent import futures
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Iterable
|
||||
|
||||
import grpc
|
||||
from opentelemetry.proto.collector.trace.v1 import (
|
||||
trace_service_pb2,
|
||||
trace_service_pb2_grpc,
|
||||
)
|
||||
|
||||
DEFAULT_GRPC_PORT = 4317
|
||||
DEFAULT_CAPACITY = 1000
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMCall:
|
||||
"""One LLM call as reconstructed from a brightstaff LLM span.
|
||||
|
||||
Fields default to ``None`` when the underlying span attribute was absent.
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
timestamp: datetime
|
||||
model: str
|
||||
provider: str | None = None
|
||||
request_model: str | None = None
|
||||
session_id: str | None = None
|
||||
route_name: str | None = None
|
||||
is_streaming: bool | None = None
|
||||
status_code: int | None = None
|
||||
prompt_tokens: int | None = None
|
||||
completion_tokens: int | None = None
|
||||
total_tokens: int | None = None
|
||||
cached_input_tokens: int | None = None
|
||||
cache_creation_tokens: int | None = None
|
||||
reasoning_tokens: int | None = None
|
||||
ttft_ms: float | None = None
|
||||
duration_ms: float | None = None
|
||||
routing_strategy: str | None = None
|
||||
routing_reason: str | None = None
|
||||
cost_usd: float | None = None
|
||||
|
||||
@property
|
||||
def tpt_ms(self) -> float | None:
|
||||
if self.duration_ms is None or self.completion_tokens in (None, 0):
|
||||
return None
|
||||
ttft = self.ttft_ms or 0.0
|
||||
generate_ms = max(0.0, self.duration_ms - ttft)
|
||||
if generate_ms <= 0:
|
||||
return None
|
||||
return generate_ms / self.completion_tokens
|
||||
|
||||
@property
|
||||
def tokens_per_sec(self) -> float | None:
|
||||
tpt = self.tpt_ms
|
||||
if tpt is None or tpt <= 0:
|
||||
return None
|
||||
return 1000.0 / tpt
|
||||
|
||||
|
||||
class LLMCallStore:
|
||||
"""Thread-safe ring buffer of recent LLM calls."""
|
||||
|
||||
def __init__(self, capacity: int = DEFAULT_CAPACITY) -> None:
|
||||
self._capacity = capacity
|
||||
self._calls: deque[LLMCall] = deque(maxlen=capacity)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def capacity(self) -> int:
|
||||
return self._capacity
|
||||
|
||||
def add(self, call: LLMCall) -> None:
|
||||
with self._lock:
|
||||
self._calls.append(call)
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._calls.clear()
|
||||
|
||||
def snapshot(self) -> list[LLMCall]:
|
||||
with self._lock:
|
||||
return list(self._calls)
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._calls)
|
||||
|
||||
|
||||
# Span attribute keys used below are the canonical OTel / Plano keys emitted by
|
||||
# brightstaff — see crates/brightstaff/src/tracing/constants.rs for the source
|
||||
# of truth.
|
||||
|
||||
|
||||
def _anyvalue_to_python(value: Any) -> Any: # AnyValue from OTLP
|
||||
kind = value.WhichOneof("value")
|
||||
if kind == "string_value":
|
||||
return value.string_value
|
||||
if kind == "bool_value":
|
||||
return value.bool_value
|
||||
if kind == "int_value":
|
||||
return value.int_value
|
||||
if kind == "double_value":
|
||||
return value.double_value
|
||||
return None
|
||||
|
||||
|
||||
def _attrs_to_dict(attrs: Iterable[Any]) -> dict[str, Any]:
|
||||
out: dict[str, Any] = {}
|
||||
for kv in attrs:
|
||||
py = _anyvalue_to_python(kv.value)
|
||||
if py is not None:
|
||||
out[kv.key] = py
|
||||
return out
|
||||
|
||||
|
||||
def _maybe_int(value: Any) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _maybe_float(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def span_to_llm_call(
|
||||
span: Any, service_name: str, pricing: Any | None = None
|
||||
) -> LLMCall | None:
|
||||
"""Convert an OTLP span into an LLMCall, or return None if it isn't one.
|
||||
|
||||
A span is considered an LLM call iff it carries the ``llm.model`` attribute.
|
||||
"""
|
||||
attrs = _attrs_to_dict(span.attributes)
|
||||
model = attrs.get("llm.model")
|
||||
if not model:
|
||||
return None
|
||||
|
||||
# Prefer explicit span attributes; fall back to likely aliases.
|
||||
request_id = next(
|
||||
(
|
||||
str(attrs[key])
|
||||
for key in ("request_id", "http.request_id")
|
||||
if key in attrs and attrs[key] is not None
|
||||
),
|
||||
span.span_id.hex() if span.span_id else "",
|
||||
)
|
||||
start_ns = span.start_time_unix_nano or 0
|
||||
ts = (
|
||||
datetime.fromtimestamp(start_ns / 1_000_000_000, tz=timezone.utc).astimezone()
|
||||
if start_ns
|
||||
else datetime.now().astimezone()
|
||||
)
|
||||
|
||||
call = LLMCall(
|
||||
request_id=str(request_id),
|
||||
timestamp=ts,
|
||||
model=str(model),
|
||||
provider=(
|
||||
str(attrs["llm.provider"]) if "llm.provider" in attrs else service_name
|
||||
),
|
||||
request_model=(
|
||||
str(attrs["model.requested"]) if "model.requested" in attrs else None
|
||||
),
|
||||
session_id=(
|
||||
str(attrs["plano.session_id"]) if "plano.session_id" in attrs else None
|
||||
),
|
||||
route_name=(
|
||||
str(attrs["plano.route.name"]) if "plano.route.name" in attrs else None
|
||||
),
|
||||
is_streaming=(
|
||||
bool(attrs["llm.is_streaming"]) if "llm.is_streaming" in attrs else None
|
||||
),
|
||||
status_code=_maybe_int(attrs.get("http.status_code")),
|
||||
prompt_tokens=_maybe_int(attrs.get("llm.usage.prompt_tokens")),
|
||||
completion_tokens=_maybe_int(attrs.get("llm.usage.completion_tokens")),
|
||||
total_tokens=_maybe_int(attrs.get("llm.usage.total_tokens")),
|
||||
cached_input_tokens=_maybe_int(attrs.get("llm.usage.cached_input_tokens")),
|
||||
cache_creation_tokens=_maybe_int(attrs.get("llm.usage.cache_creation_tokens")),
|
||||
reasoning_tokens=_maybe_int(attrs.get("llm.usage.reasoning_tokens")),
|
||||
ttft_ms=_maybe_float(attrs.get("llm.time_to_first_token")),
|
||||
duration_ms=_maybe_float(attrs.get("llm.duration_ms")),
|
||||
routing_strategy=(
|
||||
str(attrs["routing.strategy"]) if "routing.strategy" in attrs else None
|
||||
),
|
||||
routing_reason=(
|
||||
str(attrs["routing.selection_reason"])
|
||||
if "routing.selection_reason" in attrs
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
if pricing is not None:
|
||||
call.cost_usd = pricing.cost_for_call(call)
|
||||
|
||||
return call
|
||||
|
||||
|
||||
class _ObsServicer(trace_service_pb2_grpc.TraceServiceServicer):
|
||||
def __init__(self, store: LLMCallStore, pricing: Any | None) -> None:
|
||||
self._store = store
|
||||
self._pricing = pricing
|
||||
|
||||
def Export(self, request, context): # noqa: N802 — gRPC generated name
|
||||
for resource_spans in request.resource_spans:
|
||||
service_name = "unknown"
|
||||
for attr in resource_spans.resource.attributes:
|
||||
if attr.key == "service.name":
|
||||
val = _anyvalue_to_python(attr.value)
|
||||
if val is not None:
|
||||
service_name = str(val)
|
||||
break
|
||||
for scope_spans in resource_spans.scope_spans:
|
||||
for span in scope_spans.spans:
|
||||
call = span_to_llm_call(span, service_name, self._pricing)
|
||||
if call is not None:
|
||||
self._store.add(call)
|
||||
return trace_service_pb2.ExportTraceServiceResponse()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObsCollector:
|
||||
"""Owns the OTLP/gRPC server and the in-memory LLMCall ring buffer."""
|
||||
|
||||
store: LLMCallStore = field(default_factory=LLMCallStore)
|
||||
pricing: Any | None = None
|
||||
host: str = "0.0.0.0"
|
||||
port: int = DEFAULT_GRPC_PORT
|
||||
_server: grpc.Server | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def start(self) -> None:
|
||||
if self._server is not None:
|
||||
return
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
||||
trace_service_pb2_grpc.add_TraceServiceServicer_to_server(
|
||||
_ObsServicer(self.store, self.pricing), server
|
||||
)
|
||||
address = f"{self.host}:{self.port}"
|
||||
bound = server.add_insecure_port(address)
|
||||
if bound == 0:
|
||||
raise OSError(
|
||||
f"Failed to bind OTLP listener on {address}: port already in use. "
|
||||
"Stop tracing via `planoai trace down` or pick another port with --port."
|
||||
)
|
||||
server.start()
|
||||
self._server = server
|
||||
|
||||
def stop(self, grace: float = 2.0) -> None:
|
||||
if self._server is not None:
|
||||
self._server.stop(grace)
|
||||
self._server = None
|
||||
321
cli/planoai/obs/pricing.py
Normal file
321
cli/planoai/obs/pricing.py
Normal file
|
|
@ -0,0 +1,321 @@
|
|||
"""DigitalOcean Gradient pricing catalog for the obs console.
|
||||
|
||||
Ported loosely from ``crates/brightstaff/src/router/model_metrics.rs::fetch_do_pricing``.
|
||||
Single-source: one fetch at startup, cached for the life of the process.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
DEFAULT_PRICING_URL = "https://api.digitalocean.com/v2/gen-ai/models/catalog"
|
||||
FETCH_TIMEOUT_SECS = 5.0
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelPrice:
|
||||
"""Input/output $/token rates. Token counts are multiplied by these."""
|
||||
|
||||
input_per_token_usd: float
|
||||
output_per_token_usd: float
|
||||
cached_input_per_token_usd: float | None = None
|
||||
|
||||
|
||||
class PricingCatalog:
|
||||
"""In-memory pricing lookup keyed by model id.
|
||||
|
||||
DO's catalog uses ids like ``openai-gpt-5.4``; Plano's resolved model names
|
||||
may arrive as ``do/openai-gpt-5.4`` or bare ``openai-gpt-5.4``. We strip the
|
||||
leading provider prefix when looking up.
|
||||
"""
|
||||
|
||||
def __init__(self, prices: dict[str, ModelPrice] | None = None) -> None:
|
||||
self._prices: dict[str, ModelPrice] = prices or {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._prices)
|
||||
|
||||
def sample_models(self, n: int = 5) -> list[str]:
|
||||
with self._lock:
|
||||
return list(self._prices.keys())[:n]
|
||||
|
||||
@classmethod
|
||||
def fetch(cls, url: str = DEFAULT_PRICING_URL) -> "PricingCatalog":
|
||||
"""Fetch pricing from DO's catalog endpoint. On failure, returns an
|
||||
empty catalog (cost column will be blank).
|
||||
|
||||
The catalog endpoint is public — no auth required, no signup — so
|
||||
``planoai obs`` gets cost data on first run out of the box.
|
||||
"""
|
||||
try:
|
||||
resp = requests.get(url, timeout=FETCH_TIMEOUT_SECS)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
except Exception as exc: # noqa: BLE001 — best-effort; never fatal
|
||||
logger.warning(
|
||||
"DO pricing fetch failed: %s; cost column will be blank.",
|
||||
exc,
|
||||
)
|
||||
return cls()
|
||||
|
||||
prices = _parse_do_pricing(data)
|
||||
if not prices:
|
||||
# Dump the first entry's raw shape so we can see which fields DO
|
||||
# actually returned — helps when the catalog adds new fields or
|
||||
# the response doesn't match our parser.
|
||||
import json as _json
|
||||
|
||||
sample_items = _coerce_items(data)
|
||||
sample = sample_items[0] if sample_items else data
|
||||
logger.warning(
|
||||
"DO pricing response had no parseable entries; cost column "
|
||||
"will be blank. Sample entry: %s",
|
||||
_json.dumps(sample, default=str)[:400],
|
||||
)
|
||||
return cls(prices)
|
||||
|
||||
def price_for(self, model_name: str | None) -> ModelPrice | None:
|
||||
if not model_name:
|
||||
return None
|
||||
with self._lock:
|
||||
# Try the full name first, then stripped prefix, then lowercased variants.
|
||||
for candidate in _model_key_candidates(model_name):
|
||||
hit = self._prices.get(candidate)
|
||||
if hit is not None:
|
||||
return hit
|
||||
return None
|
||||
|
||||
def cost_for_call(self, call: Any) -> float | None:
|
||||
"""Compute USD cost for an LLMCall. Returns None when pricing is unknown."""
|
||||
price = self.price_for(getattr(call, "model", None)) or self.price_for(
|
||||
getattr(call, "request_model", None)
|
||||
)
|
||||
if price is None:
|
||||
return None
|
||||
prompt = int(getattr(call, "prompt_tokens", 0) or 0)
|
||||
completion = int(getattr(call, "completion_tokens", 0) or 0)
|
||||
cached = int(getattr(call, "cached_input_tokens", 0) or 0)
|
||||
|
||||
# Cached input tokens are priced separately at the cached rate when known;
|
||||
# otherwise they're already counted in prompt tokens at the regular rate.
|
||||
fresh_prompt = prompt
|
||||
if price.cached_input_per_token_usd is not None and cached:
|
||||
fresh_prompt = max(0, prompt - cached)
|
||||
cost_cached = cached * price.cached_input_per_token_usd
|
||||
else:
|
||||
cost_cached = 0.0
|
||||
|
||||
cost = (
|
||||
fresh_prompt * price.input_per_token_usd
|
||||
+ completion * price.output_per_token_usd
|
||||
+ cost_cached
|
||||
)
|
||||
return round(cost, 6)
|
||||
|
||||
|
||||
_DATE_SUFFIX_RE = re.compile(r"-\d{8}$")
|
||||
_PROVIDER_PREFIXES = ("anthropic", "openai", "google", "meta", "cohere", "mistral")
|
||||
_ANTHROPIC_FAMILIES = {"opus", "sonnet", "haiku"}
|
||||
|
||||
|
||||
def _model_key_candidates(model_name: str) -> list[str]:
|
||||
"""Lookup-side variants of a Plano-emitted model name.
|
||||
|
||||
Plano resolves names like ``claude-haiku-4-5-20251001``; the catalog stores
|
||||
them as ``anthropic-claude-haiku-4.5``. We strip the date suffix and the
|
||||
``provider/`` prefix here; the catalog itself registers the dash/dot and
|
||||
family-order aliases at parse time (see :func:`_expand_aliases`).
|
||||
"""
|
||||
base = model_name.strip()
|
||||
out = [base]
|
||||
if "/" in base:
|
||||
out.append(base.split("/", 1)[1])
|
||||
for k in list(out):
|
||||
stripped = _DATE_SUFFIX_RE.sub("", k)
|
||||
if stripped != k:
|
||||
out.append(stripped)
|
||||
out.extend([v.lower() for v in list(out)])
|
||||
seen: set[str] = set()
|
||||
uniq = []
|
||||
for key in out:
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
uniq.append(key)
|
||||
return uniq
|
||||
|
||||
|
||||
def _expand_aliases(model_id: str) -> set[str]:
|
||||
"""Catalog-side variants of a DO model id.
|
||||
|
||||
DO publishes Anthropic models under ids like ``anthropic-claude-opus-4.7``
|
||||
or ``anthropic-claude-4.6-sonnet`` while Plano emits ``claude-opus-4-7`` /
|
||||
``claude-sonnet-4-6``. Generate a set covering provider-prefix stripping,
|
||||
dash↔dot in version segments, and family↔version word order so a single
|
||||
catalog entry matches every name shape we'll see at lookup.
|
||||
"""
|
||||
aliases: set[str] = set()
|
||||
|
||||
def add(name: str) -> None:
|
||||
if not name:
|
||||
return
|
||||
aliases.add(name)
|
||||
aliases.add(name.lower())
|
||||
|
||||
add(model_id)
|
||||
|
||||
base = model_id
|
||||
head, _, rest = base.partition("-")
|
||||
if head.lower() in _PROVIDER_PREFIXES and rest:
|
||||
add(rest)
|
||||
base = rest
|
||||
|
||||
for key in list(aliases):
|
||||
if "." in key:
|
||||
add(key.replace(".", "-"))
|
||||
|
||||
parts = base.split("-")
|
||||
if len(parts) >= 3 and parts[0].lower() == "claude":
|
||||
rest_parts = parts[1:]
|
||||
for i, p in enumerate(rest_parts):
|
||||
if p.lower() in _ANTHROPIC_FAMILIES:
|
||||
others = rest_parts[:i] + rest_parts[i + 1 :]
|
||||
if not others:
|
||||
break
|
||||
family_last = "claude-" + "-".join(others) + "-" + p
|
||||
family_first = "claude-" + p + "-" + "-".join(others)
|
||||
add(family_last)
|
||||
add(family_first)
|
||||
add(family_last.replace(".", "-"))
|
||||
add(family_first.replace(".", "-"))
|
||||
break
|
||||
|
||||
return aliases
|
||||
|
||||
|
||||
def _parse_do_pricing(data: Any) -> dict[str, ModelPrice]:
|
||||
"""Parse DO catalog response into a ModelPrice map keyed by model id.
|
||||
|
||||
DO's shape (as of 2026-04):
|
||||
{
|
||||
"data": [
|
||||
{"model_id": "openai-gpt-5.4",
|
||||
"pricing": {"input_price_per_million": 5.0,
|
||||
"output_price_per_million": 15.0}},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
Older/alternate shapes are also accepted (flat top-level fields, or the
|
||||
``id``/``model``/``name`` key).
|
||||
"""
|
||||
prices: dict[str, ModelPrice] = {}
|
||||
items = _coerce_items(data)
|
||||
for item in items:
|
||||
model_id = (
|
||||
item.get("model_id")
|
||||
or item.get("id")
|
||||
or item.get("model")
|
||||
or item.get("name")
|
||||
)
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# DO nests rates under `pricing`; try that first, then fall back to
|
||||
# top-level fields for alternate response shapes.
|
||||
sources = [item]
|
||||
if isinstance(item.get("pricing"), dict):
|
||||
sources.insert(0, item["pricing"])
|
||||
|
||||
input_rate = _extract_rate_from_sources(
|
||||
sources,
|
||||
["input_per_token", "input_token_price", "price_input"],
|
||||
["input_price_per_million", "input_per_million", "input_per_mtok"],
|
||||
)
|
||||
output_rate = _extract_rate_from_sources(
|
||||
sources,
|
||||
["output_per_token", "output_token_price", "price_output"],
|
||||
["output_price_per_million", "output_per_million", "output_per_mtok"],
|
||||
)
|
||||
cached_rate = _extract_rate_from_sources(
|
||||
sources,
|
||||
[
|
||||
"cached_input_per_token",
|
||||
"cached_input_token_price",
|
||||
"prompt_cache_read_per_token",
|
||||
],
|
||||
[
|
||||
"cached_input_price_per_million",
|
||||
"cached_input_per_million",
|
||||
"cached_input_per_mtok",
|
||||
],
|
||||
)
|
||||
|
||||
if input_rate is None or output_rate is None:
|
||||
continue
|
||||
# Treat 0-rate entries as "unknown" so cost falls back to `—` rather
|
||||
# than showing a misleading $0.0000. DO's catalog sometimes omits
|
||||
# rates for promo/open-weight models.
|
||||
if input_rate == 0 and output_rate == 0:
|
||||
continue
|
||||
price = ModelPrice(
|
||||
input_per_token_usd=input_rate,
|
||||
output_per_token_usd=output_rate,
|
||||
cached_input_per_token_usd=cached_rate,
|
||||
)
|
||||
for alias in _expand_aliases(str(model_id)):
|
||||
prices.setdefault(alias, price)
|
||||
return prices
|
||||
|
||||
|
||||
def _coerce_items(data: Any) -> list[dict]:
|
||||
if isinstance(data, list):
|
||||
return [x for x in data if isinstance(x, dict)]
|
||||
if isinstance(data, dict):
|
||||
for key in ("data", "models", "pricing", "items"):
|
||||
val = data.get(key)
|
||||
if isinstance(val, list):
|
||||
return [x for x in val if isinstance(x, dict)]
|
||||
return []
|
||||
|
||||
|
||||
def _extract_rate_from_sources(
|
||||
sources: list[dict],
|
||||
per_token_keys: list[str],
|
||||
per_million_keys: list[str],
|
||||
) -> float | None:
|
||||
"""Return a per-token rate in USD, or None if unknown.
|
||||
|
||||
Some DO catalog responses put per-token values under a field whose name
|
||||
says ``_per_million`` (e.g. ``input_price_per_million: 5E-8`` — that's
|
||||
$5e-8 per token, not per million). Heuristic: values < 1 are already
|
||||
per-token (real per-million rates are ~0.1 to ~100); values >= 1 are
|
||||
treated as per-million and divided by 1,000,000.
|
||||
"""
|
||||
for src in sources:
|
||||
for key in per_token_keys:
|
||||
if key in src and src[key] is not None:
|
||||
try:
|
||||
return float(src[key])
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
for key in per_million_keys:
|
||||
if key in src and src[key] is not None:
|
||||
try:
|
||||
v = float(src[key])
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if v >= 1:
|
||||
return v / 1_000_000
|
||||
return v
|
||||
return None
|
||||
634
cli/planoai/obs/render.py
Normal file
634
cli/planoai/obs/render.py
Normal file
|
|
@ -0,0 +1,634 @@
|
|||
"""Rich TUI renderer for the observability console."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
|
||||
from rich.align import Align
|
||||
from rich.box import SIMPLE, SIMPLE_HEAVY
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
MAX_WIDTH = 160
|
||||
|
||||
from planoai.obs.collector import LLMCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregateStats:
|
||||
count: int
|
||||
total_cost_usd: float
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
distinct_sessions: int
|
||||
current_session: str | None
|
||||
p50_latency_ms: float | None = None
|
||||
p95_latency_ms: float | None = None
|
||||
p99_latency_ms: float | None = None
|
||||
p50_ttft_ms: float | None = None
|
||||
p95_ttft_ms: float | None = None
|
||||
p99_ttft_ms: float | None = None
|
||||
error_count: int = 0
|
||||
errors_4xx: int = 0
|
||||
errors_5xx: int = 0
|
||||
has_cost: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelRollup:
|
||||
model: str
|
||||
requests: int
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
cache_write: int
|
||||
cache_read: int
|
||||
cost_usd: float
|
||||
has_cost: bool = False
|
||||
avg_tokens_per_sec: float | None = None
|
||||
|
||||
|
||||
def _percentile(values: list[float], pct: float) -> float | None:
|
||||
if not values:
|
||||
return None
|
||||
s = sorted(values)
|
||||
k = max(0, min(len(s) - 1, int(round((pct / 100.0) * (len(s) - 1)))))
|
||||
return s[k]
|
||||
|
||||
|
||||
def aggregates(calls: list[LLMCall]) -> AggregateStats:
|
||||
total_cost = sum((c.cost_usd or 0.0) for c in calls)
|
||||
total_input = sum(int(c.prompt_tokens or 0) for c in calls)
|
||||
total_output = sum(int(c.completion_tokens or 0) for c in calls)
|
||||
session_ids = {c.session_id for c in calls if c.session_id}
|
||||
current = next(
|
||||
(c.session_id for c in reversed(calls) if c.session_id is not None), None
|
||||
)
|
||||
durations = [c.duration_ms for c in calls if c.duration_ms is not None]
|
||||
ttfts = [c.ttft_ms for c in calls if c.ttft_ms is not None]
|
||||
errors_4xx = sum(
|
||||
1 for c in calls if c.status_code is not None and 400 <= c.status_code < 500
|
||||
)
|
||||
errors_5xx = sum(
|
||||
1 for c in calls if c.status_code is not None and c.status_code >= 500
|
||||
)
|
||||
has_cost = any(c.cost_usd is not None for c in calls)
|
||||
return AggregateStats(
|
||||
count=len(calls),
|
||||
total_cost_usd=total_cost,
|
||||
total_input_tokens=total_input,
|
||||
total_output_tokens=total_output,
|
||||
distinct_sessions=len(session_ids),
|
||||
current_session=current,
|
||||
p50_latency_ms=_percentile(durations, 50),
|
||||
p95_latency_ms=_percentile(durations, 95),
|
||||
p99_latency_ms=_percentile(durations, 99),
|
||||
p50_ttft_ms=_percentile(ttfts, 50),
|
||||
p95_ttft_ms=_percentile(ttfts, 95),
|
||||
p99_ttft_ms=_percentile(ttfts, 99),
|
||||
error_count=errors_4xx + errors_5xx,
|
||||
errors_4xx=errors_4xx,
|
||||
errors_5xx=errors_5xx,
|
||||
has_cost=has_cost,
|
||||
)
|
||||
|
||||
|
||||
def model_rollups(calls: list[LLMCall]) -> list[ModelRollup]:
|
||||
buckets: dict[str, dict[str, float | int | bool]] = {}
|
||||
tps_samples: dict[str, list[float]] = {}
|
||||
for c in calls:
|
||||
key = c.model
|
||||
b = buckets.setdefault(
|
||||
key,
|
||||
{
|
||||
"requests": 0,
|
||||
"input": 0,
|
||||
"output": 0,
|
||||
"cache_write": 0,
|
||||
"cache_read": 0,
|
||||
"cost": 0.0,
|
||||
"has_cost": False,
|
||||
},
|
||||
)
|
||||
b["requests"] = int(b["requests"]) + 1
|
||||
b["input"] = int(b["input"]) + int(c.prompt_tokens or 0)
|
||||
b["output"] = int(b["output"]) + int(c.completion_tokens or 0)
|
||||
b["cache_write"] = int(b["cache_write"]) + int(c.cache_creation_tokens or 0)
|
||||
b["cache_read"] = int(b["cache_read"]) + int(c.cached_input_tokens or 0)
|
||||
b["cost"] = float(b["cost"]) + (c.cost_usd or 0.0)
|
||||
if c.cost_usd is not None:
|
||||
b["has_cost"] = True
|
||||
tps = c.tokens_per_sec
|
||||
if tps is not None:
|
||||
tps_samples.setdefault(key, []).append(tps)
|
||||
|
||||
rollups: list[ModelRollup] = []
|
||||
for model, b in buckets.items():
|
||||
samples = tps_samples.get(model)
|
||||
avg_tps = (sum(samples) / len(samples)) if samples else None
|
||||
rollups.append(
|
||||
ModelRollup(
|
||||
model=model,
|
||||
requests=int(b["requests"]),
|
||||
input_tokens=int(b["input"]),
|
||||
output_tokens=int(b["output"]),
|
||||
cache_write=int(b["cache_write"]),
|
||||
cache_read=int(b["cache_read"]),
|
||||
cost_usd=float(b["cost"]),
|
||||
has_cost=bool(b["has_cost"]),
|
||||
avg_tokens_per_sec=avg_tps,
|
||||
)
|
||||
)
|
||||
rollups.sort(key=lambda r: (r.cost_usd, r.requests), reverse=True)
|
||||
return rollups
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouteHit:
|
||||
route: str
|
||||
hits: int
|
||||
pct: float
|
||||
p95_latency_ms: float | None
|
||||
error_count: int
|
||||
|
||||
|
||||
def route_hits(calls: list[LLMCall]) -> list[RouteHit]:
|
||||
counts: Counter[str] = Counter()
|
||||
per_route_latency: dict[str, list[float]] = {}
|
||||
per_route_errors: dict[str, int] = {}
|
||||
for c in calls:
|
||||
if not c.route_name:
|
||||
continue
|
||||
counts[c.route_name] += 1
|
||||
if c.duration_ms is not None:
|
||||
per_route_latency.setdefault(c.route_name, []).append(c.duration_ms)
|
||||
if c.status_code is not None and c.status_code >= 400:
|
||||
per_route_errors[c.route_name] = per_route_errors.get(c.route_name, 0) + 1
|
||||
total = sum(counts.values())
|
||||
if total == 0:
|
||||
return []
|
||||
return [
|
||||
RouteHit(
|
||||
route=r,
|
||||
hits=n,
|
||||
pct=(n / total) * 100.0,
|
||||
p95_latency_ms=_percentile(per_route_latency.get(r, []), 95),
|
||||
error_count=per_route_errors.get(r, 0),
|
||||
)
|
||||
for r, n in counts.most_common()
|
||||
]
|
||||
|
||||
|
||||
def _fmt_cost(v: float | None, *, zero: str = "—") -> str:
|
||||
if v is None:
|
||||
return "—"
|
||||
if v == 0:
|
||||
return zero
|
||||
if abs(v) < 0.0001:
|
||||
return f"${v:.8f}".rstrip("0").rstrip(".")
|
||||
if abs(v) < 0.01:
|
||||
return f"${v:.6f}".rstrip("0").rstrip(".")
|
||||
if abs(v) < 1:
|
||||
return f"${v:.4f}"
|
||||
return f"${v:,.2f}"
|
||||
|
||||
|
||||
def _fmt_ms(v: float | None) -> str:
|
||||
if v is None:
|
||||
return "—"
|
||||
if v >= 1000:
|
||||
return f"{v / 1000:.1f}s"
|
||||
return f"{v:.0f}ms"
|
||||
|
||||
|
||||
def _fmt_int(v: int | None) -> str:
|
||||
if v is None or v == 0:
|
||||
return "—"
|
||||
return f"{v:,}"
|
||||
|
||||
|
||||
def _fmt_tokens(v: int | None) -> str:
|
||||
if v is None:
|
||||
return "—"
|
||||
return f"{v:,}"
|
||||
|
||||
|
||||
def _fmt_tps(v: float | None) -> str:
|
||||
if v is None or v <= 0:
|
||||
return "—"
|
||||
if v >= 100:
|
||||
return f"{v:.0f}/s"
|
||||
return f"{v:.1f}/s"
|
||||
|
||||
|
||||
def _latency_style(v: float | None) -> str:
|
||||
if v is None:
|
||||
return "dim"
|
||||
if v < 500:
|
||||
return "green"
|
||||
if v < 2000:
|
||||
return "yellow"
|
||||
return "red"
|
||||
|
||||
|
||||
def _ttft_style(v: float | None) -> str:
|
||||
if v is None:
|
||||
return "dim"
|
||||
if v < 300:
|
||||
return "green"
|
||||
if v < 1000:
|
||||
return "yellow"
|
||||
return "red"
|
||||
|
||||
|
||||
def _truncate_model(name: str, limit: int = 32) -> str:
|
||||
if len(name) <= limit:
|
||||
return name
|
||||
return name[: limit - 1] + "…"
|
||||
|
||||
|
||||
def _status_text(code: int | None) -> Text:
|
||||
if code is None:
|
||||
return Text("—", style="dim")
|
||||
if 200 <= code < 300:
|
||||
return Text("● ok", style="green")
|
||||
if 300 <= code < 400:
|
||||
return Text(f"● {code}", style="yellow")
|
||||
if 400 <= code < 500:
|
||||
return Text(f"● {code}", style="yellow bold")
|
||||
return Text(f"● {code}", style="red bold")
|
||||
|
||||
|
||||
def _summary_panel(last: LLMCall | None, stats: AggregateStats) -> Panel:
|
||||
# Content-sized columns with a fixed gutter keep the two blocks close
|
||||
# together instead of stretching across the full terminal on wide screens.
|
||||
grid = Table.grid(padding=(0, 4))
|
||||
grid.add_column(no_wrap=True)
|
||||
grid.add_column(no_wrap=True)
|
||||
|
||||
# Left: latest request snapshot.
|
||||
left = Table.grid(padding=(0, 1))
|
||||
left.add_column(style="dim", no_wrap=True)
|
||||
left.add_column(no_wrap=True)
|
||||
if last is None:
|
||||
left.add_row("latest", Text("waiting for spans…", style="dim italic"))
|
||||
else:
|
||||
model_text = Text(_truncate_model(last.model, 48), style="bold cyan")
|
||||
if last.is_streaming:
|
||||
model_text.append(" ⟳ stream", style="dim")
|
||||
left.add_row("model", model_text)
|
||||
if last.request_model and last.request_model != last.model:
|
||||
left.add_row(
|
||||
"requested", Text(_truncate_model(last.request_model, 48), style="cyan")
|
||||
)
|
||||
if last.route_name:
|
||||
left.add_row("route", Text(last.route_name, style="yellow"))
|
||||
left.add_row("status", _status_text(last.status_code))
|
||||
tokens = Text()
|
||||
tokens.append(_fmt_tokens(last.prompt_tokens))
|
||||
tokens.append(" in", style="dim")
|
||||
tokens.append(" · ", style="dim")
|
||||
tokens.append(_fmt_tokens(last.completion_tokens), style="green")
|
||||
tokens.append(" out", style="dim")
|
||||
if last.cached_input_tokens:
|
||||
tokens.append(" · ", style="dim")
|
||||
tokens.append(_fmt_tokens(last.cached_input_tokens), style="yellow")
|
||||
tokens.append(" cached", style="dim")
|
||||
left.add_row("tokens", tokens)
|
||||
timing = Text()
|
||||
timing.append("TTFT ", style="dim")
|
||||
timing.append(_fmt_ms(last.ttft_ms), style=_ttft_style(last.ttft_ms))
|
||||
timing.append(" · ", style="dim")
|
||||
timing.append("lat ", style="dim")
|
||||
timing.append(_fmt_ms(last.duration_ms), style=_latency_style(last.duration_ms))
|
||||
tps = last.tokens_per_sec
|
||||
if tps:
|
||||
timing.append(" · ", style="dim")
|
||||
timing.append(_fmt_tps(tps), style="green")
|
||||
left.add_row("timing", timing)
|
||||
left.add_row("cost", Text(_fmt_cost(last.cost_usd), style="green bold"))
|
||||
|
||||
# Right: lifetime totals.
|
||||
right = Table.grid(padding=(0, 1))
|
||||
right.add_column(style="dim", no_wrap=True)
|
||||
right.add_column(no_wrap=True)
|
||||
right.add_row(
|
||||
"requests",
|
||||
Text(f"{stats.count:,}", style="bold"),
|
||||
)
|
||||
if stats.error_count:
|
||||
err_text = Text()
|
||||
err_text.append(f"{stats.error_count:,}", style="red bold")
|
||||
parts: list[str] = []
|
||||
if stats.errors_4xx:
|
||||
parts.append(f"{stats.errors_4xx} 4xx")
|
||||
if stats.errors_5xx:
|
||||
parts.append(f"{stats.errors_5xx} 5xx")
|
||||
if parts:
|
||||
err_text.append(f" ({' · '.join(parts)})", style="dim")
|
||||
right.add_row("errors", err_text)
|
||||
cost_str = _fmt_cost(stats.total_cost_usd) if stats.has_cost else "—"
|
||||
right.add_row("total cost", Text(cost_str, style="green bold"))
|
||||
tokens_total = Text()
|
||||
tokens_total.append(_fmt_tokens(stats.total_input_tokens))
|
||||
tokens_total.append(" in", style="dim")
|
||||
tokens_total.append(" · ", style="dim")
|
||||
tokens_total.append(_fmt_tokens(stats.total_output_tokens), style="green")
|
||||
tokens_total.append(" out", style="dim")
|
||||
right.add_row("tokens", tokens_total)
|
||||
lat_text = Text()
|
||||
lat_text.append("p50 ", style="dim")
|
||||
lat_text.append(
|
||||
_fmt_ms(stats.p50_latency_ms), style=_latency_style(stats.p50_latency_ms)
|
||||
)
|
||||
lat_text.append(" · ", style="dim")
|
||||
lat_text.append("p95 ", style="dim")
|
||||
lat_text.append(
|
||||
_fmt_ms(stats.p95_latency_ms), style=_latency_style(stats.p95_latency_ms)
|
||||
)
|
||||
lat_text.append(" · ", style="dim")
|
||||
lat_text.append("p99 ", style="dim")
|
||||
lat_text.append(
|
||||
_fmt_ms(stats.p99_latency_ms), style=_latency_style(stats.p99_latency_ms)
|
||||
)
|
||||
right.add_row("latency", lat_text)
|
||||
ttft_text = Text()
|
||||
ttft_text.append("p50 ", style="dim")
|
||||
ttft_text.append(_fmt_ms(stats.p50_ttft_ms), style=_ttft_style(stats.p50_ttft_ms))
|
||||
ttft_text.append(" · ", style="dim")
|
||||
ttft_text.append("p95 ", style="dim")
|
||||
ttft_text.append(_fmt_ms(stats.p95_ttft_ms), style=_ttft_style(stats.p95_ttft_ms))
|
||||
ttft_text.append(" · ", style="dim")
|
||||
ttft_text.append("p99 ", style="dim")
|
||||
ttft_text.append(_fmt_ms(stats.p99_ttft_ms), style=_ttft_style(stats.p99_ttft_ms))
|
||||
right.add_row("TTFT", ttft_text)
|
||||
sess = Text()
|
||||
sess.append(f"{stats.distinct_sessions}")
|
||||
if stats.current_session:
|
||||
sess.append(" · current ", style="dim")
|
||||
sess.append(stats.current_session, style="magenta")
|
||||
right.add_row("sessions", sess)
|
||||
|
||||
grid.add_row(left, right)
|
||||
return Panel(
|
||||
grid,
|
||||
title="[bold]live LLM traffic[/]",
|
||||
border_style="cyan",
|
||||
box=SIMPLE_HEAVY,
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
def _model_rollup_table(rollups: list[ModelRollup]) -> Table:
|
||||
table = Table(
|
||||
title="by model",
|
||||
title_justify="left",
|
||||
title_style="bold dim",
|
||||
caption="cost via DigitalOcean Gradient catalog",
|
||||
caption_justify="left",
|
||||
caption_style="dim italic",
|
||||
box=SIMPLE,
|
||||
header_style="bold",
|
||||
pad_edge=False,
|
||||
padding=(0, 1),
|
||||
)
|
||||
table.add_column("model", style="cyan", no_wrap=True)
|
||||
table.add_column("req", justify="right")
|
||||
table.add_column("input", justify="right")
|
||||
table.add_column("output", justify="right", style="green")
|
||||
table.add_column("cache wr", justify="right", style="yellow")
|
||||
table.add_column("cache rd", justify="right", style="yellow")
|
||||
table.add_column("tok/s", justify="right")
|
||||
table.add_column("cost", justify="right", style="green")
|
||||
if not rollups:
|
||||
table.add_row(
|
||||
Text("no requests yet", style="dim italic"),
|
||||
*(["—"] * 7),
|
||||
)
|
||||
return table
|
||||
for r in rollups:
|
||||
cost_cell = _fmt_cost(r.cost_usd) if r.has_cost else "—"
|
||||
table.add_row(
|
||||
_truncate_model(r.model),
|
||||
f"{r.requests:,}",
|
||||
_fmt_tokens(r.input_tokens),
|
||||
_fmt_tokens(r.output_tokens),
|
||||
_fmt_int(r.cache_write),
|
||||
_fmt_int(r.cache_read),
|
||||
_fmt_tps(r.avg_tokens_per_sec),
|
||||
cost_cell,
|
||||
)
|
||||
return table
|
||||
|
||||
|
||||
def _route_hit_table(hits: list[RouteHit]) -> Table:
|
||||
table = Table(
|
||||
title="route share",
|
||||
title_justify="left",
|
||||
title_style="bold dim",
|
||||
box=SIMPLE,
|
||||
header_style="bold",
|
||||
pad_edge=False,
|
||||
padding=(0, 1),
|
||||
)
|
||||
table.add_column("route", style="cyan")
|
||||
table.add_column("hits", justify="right")
|
||||
table.add_column("%", justify="right")
|
||||
table.add_column("p95", justify="right")
|
||||
table.add_column("err", justify="right")
|
||||
for h in hits:
|
||||
err_cell = (
|
||||
Text(f"{h.error_count:,}", style="red bold") if h.error_count else "—"
|
||||
)
|
||||
table.add_row(
|
||||
h.route,
|
||||
f"{h.hits:,}",
|
||||
f"{h.pct:5.1f}%",
|
||||
Text(_fmt_ms(h.p95_latency_ms), style=_latency_style(h.p95_latency_ms)),
|
||||
err_cell,
|
||||
)
|
||||
return table
|
||||
|
||||
|
||||
def _recent_table(calls: list[LLMCall], limit: int = 15) -> Table:
|
||||
show_route = any(c.route_name for c in calls)
|
||||
show_cache = any((c.cached_input_tokens or 0) > 0 for c in calls)
|
||||
show_rsn = any((c.reasoning_tokens or 0) > 0 for c in calls)
|
||||
|
||||
caption_parts = ["in·new = fresh prompt tokens"]
|
||||
if show_cache:
|
||||
caption_parts.append("in·cache = cached read")
|
||||
if show_rsn:
|
||||
caption_parts.append("rsn = reasoning")
|
||||
caption_parts.append("lat = total latency")
|
||||
|
||||
table = Table(
|
||||
title=f"recent · last {min(limit, len(calls)) if calls else 0}",
|
||||
title_justify="left",
|
||||
title_style="bold dim",
|
||||
caption=" · ".join(caption_parts),
|
||||
caption_justify="left",
|
||||
caption_style="dim italic",
|
||||
box=SIMPLE,
|
||||
header_style="bold",
|
||||
pad_edge=False,
|
||||
padding=(0, 1),
|
||||
)
|
||||
table.add_column("time", no_wrap=True)
|
||||
table.add_column("model", style="cyan", no_wrap=True)
|
||||
if show_route:
|
||||
table.add_column("route", style="yellow", no_wrap=True)
|
||||
table.add_column("in·new", justify="right")
|
||||
if show_cache:
|
||||
table.add_column("in·cache", justify="right", style="yellow")
|
||||
table.add_column("out", justify="right", style="green")
|
||||
if show_rsn:
|
||||
table.add_column("rsn", justify="right")
|
||||
table.add_column("tok/s", justify="right")
|
||||
table.add_column("TTFT", justify="right")
|
||||
table.add_column("lat", justify="right")
|
||||
table.add_column("cost", justify="right", style="green")
|
||||
table.add_column("status")
|
||||
|
||||
if not calls:
|
||||
cols = len(table.columns)
|
||||
table.add_row(
|
||||
Text("waiting for spans…", style="dim italic"),
|
||||
*(["—"] * (cols - 1)),
|
||||
)
|
||||
return table
|
||||
|
||||
recent = list(reversed(calls))[:limit]
|
||||
for idx, c in enumerate(recent):
|
||||
is_newest = idx == 0
|
||||
time_style = "bold white" if is_newest else None
|
||||
model_style = "bold cyan" if is_newest else "cyan"
|
||||
row: list[object] = [
|
||||
(
|
||||
Text(c.timestamp.strftime("%H:%M:%S"), style=time_style)
|
||||
if time_style
|
||||
else c.timestamp.strftime("%H:%M:%S")
|
||||
),
|
||||
Text(_truncate_model(c.model), style=model_style),
|
||||
]
|
||||
if show_route:
|
||||
row.append(c.route_name or "—")
|
||||
row.append(_fmt_tokens(c.prompt_tokens))
|
||||
if show_cache:
|
||||
row.append(_fmt_int(c.cached_input_tokens))
|
||||
row.append(_fmt_tokens(c.completion_tokens))
|
||||
if show_rsn:
|
||||
row.append(_fmt_int(c.reasoning_tokens))
|
||||
row.extend(
|
||||
[
|
||||
_fmt_tps(c.tokens_per_sec),
|
||||
Text(_fmt_ms(c.ttft_ms), style=_ttft_style(c.ttft_ms)),
|
||||
Text(_fmt_ms(c.duration_ms), style=_latency_style(c.duration_ms)),
|
||||
_fmt_cost(c.cost_usd),
|
||||
_status_text(c.status_code),
|
||||
]
|
||||
)
|
||||
table.add_row(*row)
|
||||
return table
|
||||
|
||||
|
||||
def _last_error(calls: list[LLMCall]) -> LLMCall | None:
|
||||
for c in reversed(calls):
|
||||
if c.status_code is not None and c.status_code >= 400:
|
||||
return c
|
||||
return None
|
||||
|
||||
|
||||
def _http_reason(code: int) -> str:
|
||||
try:
|
||||
return HTTPStatus(code).phrase
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
|
||||
def _fmt_ago(ts: datetime) -> str:
|
||||
# `ts` is produced in collector.py via datetime.now(tz=...), but fall back
|
||||
# gracefully if a naive timestamp ever sneaks in.
|
||||
now = datetime.now(tz=ts.tzinfo) if ts.tzinfo else datetime.now()
|
||||
delta = (now - ts).total_seconds()
|
||||
if delta < 0:
|
||||
delta = 0
|
||||
if delta < 60:
|
||||
return f"{int(delta)}s ago"
|
||||
if delta < 3600:
|
||||
return f"{int(delta // 60)}m ago"
|
||||
return f"{int(delta // 3600)}h ago"
|
||||
|
||||
|
||||
def _error_banner(call: LLMCall) -> Panel:
|
||||
code = call.status_code or 0
|
||||
border = "red" if code >= 500 else "yellow"
|
||||
header = Text()
|
||||
header.append(f"● {code}", style=f"{border} bold")
|
||||
reason = _http_reason(code)
|
||||
if reason:
|
||||
header.append(f" {reason}", style=border)
|
||||
header.append(" · ", style="dim")
|
||||
header.append(_truncate_model(call.model, 48), style="cyan")
|
||||
if call.route_name:
|
||||
header.append(" · ", style="dim")
|
||||
header.append(call.route_name, style="yellow")
|
||||
header.append(" · ", style="dim")
|
||||
header.append(_fmt_ago(call.timestamp), style="dim")
|
||||
if call.request_id:
|
||||
header.append(" · req ", style="dim")
|
||||
header.append(call.request_id, style="magenta")
|
||||
return Panel(
|
||||
header,
|
||||
title="[bold]last error[/]",
|
||||
title_align="left",
|
||||
border_style=border,
|
||||
box=SIMPLE,
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
def _footer(stats: AggregateStats) -> Text:
|
||||
waiting = stats.count == 0
|
||||
text = Text()
|
||||
text.append("Ctrl-C ", style="bold")
|
||||
text.append("exit", style="dim")
|
||||
text.append(" · OTLP :4317", style="dim")
|
||||
text.append(" · pricing: DigitalOcean ", style="dim")
|
||||
if waiting:
|
||||
text.append("waiting for spans", style="yellow")
|
||||
text.append(
|
||||
" — set tracing.opentracing_grpc_endpoint=localhost:4317", style="dim"
|
||||
)
|
||||
else:
|
||||
text.append(f"receiving · {stats.count:,} call(s) buffered", style="green")
|
||||
return text
|
||||
|
||||
|
||||
def render(calls: list[LLMCall]) -> Align:
|
||||
last = calls[-1] if calls else None
|
||||
stats = aggregates(calls)
|
||||
rollups = model_rollups(calls)
|
||||
hits = route_hits(calls)
|
||||
|
||||
parts: list[object] = [_summary_panel(last, stats)]
|
||||
err = _last_error(calls)
|
||||
if err is not None:
|
||||
parts.append(_error_banner(err))
|
||||
if hits:
|
||||
split = Table.grid(padding=(0, 2))
|
||||
split.add_column(no_wrap=False)
|
||||
split.add_column(no_wrap=False)
|
||||
split.add_row(_model_rollup_table(rollups), _route_hit_table(hits))
|
||||
parts.append(split)
|
||||
else:
|
||||
parts.append(_model_rollup_table(rollups))
|
||||
parts.append(_recent_table(calls))
|
||||
parts.append(_footer(stats))
|
||||
# Cap overall width so wide terminals don't stretch the layout into a
|
||||
# mostly-whitespace gap between columns.
|
||||
return Align.left(Group(*parts), width=MAX_WIDTH)
|
||||
99
cli/planoai/obs_cmd.py
Normal file
99
cli/planoai/obs_cmd.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""`planoai obs` — live observability TUI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import rich_click as click
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
|
||||
from planoai.consts import PLANO_COLOR
|
||||
from planoai.obs.collector import (
|
||||
DEFAULT_CAPACITY,
|
||||
DEFAULT_GRPC_PORT,
|
||||
LLMCallStore,
|
||||
ObsCollector,
|
||||
)
|
||||
from planoai.obs.pricing import PricingCatalog
|
||||
from planoai.obs.render import render
|
||||
|
||||
|
||||
@click.command(name="obs", help="Live observability console for Plano LLM traffic.")
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
default=DEFAULT_GRPC_PORT,
|
||||
show_default=True,
|
||||
help="OTLP/gRPC port to listen on. Must match the brightstaff tracing endpoint.",
|
||||
)
|
||||
@click.option(
|
||||
"--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
show_default=True,
|
||||
help="Host to bind the OTLP listener.",
|
||||
)
|
||||
@click.option(
|
||||
"--capacity",
|
||||
type=int,
|
||||
default=DEFAULT_CAPACITY,
|
||||
show_default=True,
|
||||
help="Max LLM calls kept in memory; older calls evicted FIFO.",
|
||||
)
|
||||
@click.option(
|
||||
"--refresh-ms",
|
||||
type=int,
|
||||
default=500,
|
||||
show_default=True,
|
||||
help="TUI refresh interval.",
|
||||
)
|
||||
def obs(port: int, host: str, capacity: int, refresh_ms: int) -> None:
|
||||
console = Console()
|
||||
console.print(
|
||||
f"[bold {PLANO_COLOR}]planoai obs[/] — loading DO pricing catalog...",
|
||||
end="",
|
||||
)
|
||||
pricing = PricingCatalog.fetch()
|
||||
if len(pricing):
|
||||
sample = ", ".join(pricing.sample_models(3))
|
||||
console.print(
|
||||
f" [green]{len(pricing)} models loaded[/] [dim]({sample}, ...)[/]"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
" [yellow]no pricing loaded[/] — "
|
||||
"[dim]cost column will be blank (DO catalog unreachable)[/]"
|
||||
)
|
||||
|
||||
store = LLMCallStore(capacity=capacity)
|
||||
collector = ObsCollector(store=store, pricing=pricing, host=host, port=port)
|
||||
try:
|
||||
collector.start()
|
||||
except OSError as exc:
|
||||
console.print(f"[red]{exc}[/]")
|
||||
raise SystemExit(1)
|
||||
|
||||
console.print(
|
||||
f"Listening for OTLP spans on [bold]{host}:{port}[/]. "
|
||||
"Ensure plano config has [cyan]tracing.opentracing_grpc_endpoint: http://localhost:4317[/] "
|
||||
"and [cyan]tracing.random_sampling: 100[/] (or run [bold]planoai up[/] "
|
||||
"with no config — it wires this automatically)."
|
||||
)
|
||||
console.print("Press [bold]Ctrl-C[/] to exit.\n")
|
||||
|
||||
refresh = max(0.05, refresh_ms / 1000.0)
|
||||
try:
|
||||
with Live(
|
||||
render(store.snapshot()),
|
||||
console=console,
|
||||
refresh_per_second=1.0 / refresh,
|
||||
screen=False,
|
||||
) as live:
|
||||
while True:
|
||||
time.sleep(refresh)
|
||||
live.update(render(store.snapshot()))
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[dim]obs stopped[/]")
|
||||
finally:
|
||||
collector.stop()
|
||||
|
|
@ -61,7 +61,7 @@ def configure_rich_click(plano_color: str) -> None:
|
|||
},
|
||||
{
|
||||
"name": "Observability",
|
||||
"commands": ["trace"],
|
||||
"commands": ["trace", "obs"],
|
||||
},
|
||||
{
|
||||
"name": "Utilities",
|
||||
|
|
|
|||
|
|
@ -189,26 +189,26 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
|
|||
if isinstance(
|
||||
arg.annotation.value, ast.Name
|
||||
) and arg.annotation.value.id in ["list", "tuple", "set", "dict"]:
|
||||
param_info[
|
||||
"type"
|
||||
] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
|
||||
param_info["type"] = (
|
||||
f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
|
||||
)
|
||||
else:
|
||||
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
|
||||
|
||||
# Default for unknown types
|
||||
else:
|
||||
param_info[
|
||||
"type"
|
||||
] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
|
||||
param_info["type"] = (
|
||||
"[UNKNOWN - PLEASE FIX]" # If unable to detect type
|
||||
)
|
||||
|
||||
# Handle default values
|
||||
if default is not None:
|
||||
if isinstance(default, ast.Constant) or isinstance(
|
||||
default, ast.NameConstant
|
||||
):
|
||||
param_info[
|
||||
"default"
|
||||
] = default.value # Use the default value directly
|
||||
param_info["default"] = (
|
||||
default.value
|
||||
) # Use the default value directly
|
||||
else:
|
||||
param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type
|
||||
param_info["required"] = False # Optional since it has a default value
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import yaml
|
|||
import logging
|
||||
from planoai.consts import PLANO_DOCKER_NAME
|
||||
|
||||
|
||||
# Standard env var for log level across all Plano components
|
||||
LOG_LEVEL_ENV = "LOG_LEVEL"
|
||||
|
||||
|
|
@ -89,19 +88,24 @@ def convert_legacy_listeners(
|
|||
) -> tuple[list, dict | None, dict | None]:
|
||||
llm_gateway_listener = {
|
||||
"name": "egress_traffic",
|
||||
"type": "model_listener",
|
||||
"type": "model",
|
||||
"port": 12000,
|
||||
"address": "0.0.0.0",
|
||||
"timeout": "30s",
|
||||
# LLM streaming responses routinely exceed 30s (extended thinking,
|
||||
# long tool reasoning, large completions). Match the 300s ceiling
|
||||
# used by the direct upstream-provider routes so Envoy doesn't
|
||||
# abort streams with UT mid-response. Users can override via their
|
||||
# plano_config.yaml `listeners.timeout` field.
|
||||
"timeout": "300s",
|
||||
"model_providers": model_providers or [],
|
||||
}
|
||||
|
||||
prompt_gateway_listener = {
|
||||
"name": "ingress_traffic",
|
||||
"type": "prompt_listener",
|
||||
"type": "prompt",
|
||||
"port": 10000,
|
||||
"address": "0.0.0.0",
|
||||
"timeout": "30s",
|
||||
"timeout": "300s",
|
||||
}
|
||||
|
||||
# Handle None case
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "planoai"
|
||||
version = "0.4.10"
|
||||
version = "0.4.21"
|
||||
description = "Python-based CLI tool to manage Plano."
|
||||
authors = [{name = "Katanemo Labs, Inc."}]
|
||||
readme = "README.md"
|
||||
|
|
@ -13,7 +13,7 @@ dependencies = [
|
|||
"opentelemetry-proto>=1.20.0",
|
||||
"questionary>=2.1.1,<3.0.0",
|
||||
"pyyaml>=6.0.2,<7.0.0",
|
||||
"requests>=2.31.0,<3.0.0",
|
||||
"requests>=2.33.0,<3.0.0",
|
||||
"urllib3>=2.6.3",
|
||||
"rich>=14.2.0",
|
||||
"rich-click>=1.9.5",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from unittest import mock
|
||||
from planoai.config_generator import validate_and_render_schema
|
||||
from planoai.config_generator import (
|
||||
validate_and_render_schema,
|
||||
migrate_inline_routing_preferences,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
@ -104,13 +108,13 @@ listeners:
|
|||
agents:
|
||||
- id: simple_tmobile_rag_agent
|
||||
description: t-mobile virtual assistant for device contracts.
|
||||
filter_chain:
|
||||
input_filters:
|
||||
- query_rewriter
|
||||
- context_builder
|
||||
- response_generator
|
||||
- id: research_agent
|
||||
description: agent to research and gather information from various sources.
|
||||
filter_chain:
|
||||
input_filters:
|
||||
- research_agent
|
||||
- response_generator
|
||||
port: 8000
|
||||
|
|
@ -253,38 +257,72 @@ llm_providers:
|
|||
base_url: "http://custom.com/api/v2"
|
||||
provider_interface: openai
|
||||
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "vercel_is_supported_provider",
|
||||
"expected_error": None,
|
||||
"plano_config": """
|
||||
version: v0.4.0
|
||||
|
||||
listeners:
|
||||
- name: llm
|
||||
type: model
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: vercel/*
|
||||
base_url: https://ai-gateway.vercel.sh/v1
|
||||
passthrough_auth: true
|
||||
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "openrouter_is_supported_provider",
|
||||
"expected_error": None,
|
||||
"plano_config": """
|
||||
version: v0.4.0
|
||||
|
||||
listeners:
|
||||
- name: llm
|
||||
type: model
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: openrouter/*
|
||||
base_url: https://openrouter.ai/api/v1
|
||||
passthrough_auth: true
|
||||
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "duplicate_routeing_preference_name",
|
||||
"expected_error": "Duplicate routing preference name",
|
||||
"plano_config": """
|
||||
version: v0.1.0
|
||||
version: v0.4.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
- name: llm
|
||||
type: model
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code understanding
|
||||
description: understand and explain existing code snippets, functions, or libraries
|
||||
|
||||
- model: openai/gpt-4.1
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code understanding
|
||||
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
routing_preferences:
|
||||
- name: code understanding
|
||||
description: understand and explain existing code snippets, functions, or libraries
|
||||
models:
|
||||
- openai/gpt-4o
|
||||
- name: code understanding
|
||||
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
models:
|
||||
- openai/gpt-4o-mini
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
|
|
@ -376,7 +414,7 @@ def test_convert_legacy_llm_providers():
|
|||
assert updated_providers == [
|
||||
{
|
||||
"name": "egress_traffic",
|
||||
"type": "model_listener",
|
||||
"type": "model",
|
||||
"port": 12000,
|
||||
"address": "0.0.0.0",
|
||||
"timeout": "30s",
|
||||
|
|
@ -384,7 +422,7 @@ def test_convert_legacy_llm_providers():
|
|||
},
|
||||
{
|
||||
"name": "ingress_traffic",
|
||||
"type": "prompt_listener",
|
||||
"type": "prompt",
|
||||
"port": 10000,
|
||||
"address": "0.0.0.0",
|
||||
"timeout": "30s",
|
||||
|
|
@ -400,7 +438,7 @@ def test_convert_legacy_llm_providers():
|
|||
},
|
||||
],
|
||||
"name": "egress_traffic",
|
||||
"type": "model_listener",
|
||||
"type": "model",
|
||||
"port": 12000,
|
||||
"timeout": "30s",
|
||||
}
|
||||
|
|
@ -410,7 +448,7 @@ def test_convert_legacy_llm_providers():
|
|||
"name": "ingress_traffic",
|
||||
"port": 10000,
|
||||
"timeout": "30s",
|
||||
"type": "prompt_listener",
|
||||
"type": "prompt",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -449,7 +487,7 @@ def test_convert_legacy_llm_providers_no_prompt_gateway():
|
|||
"name": "egress_traffic",
|
||||
"port": 12000,
|
||||
"timeout": "30s",
|
||||
"type": "model_listener",
|
||||
"type": "model",
|
||||
}
|
||||
]
|
||||
assert llm_gateway == {
|
||||
|
|
@ -461,7 +499,242 @@ def test_convert_legacy_llm_providers_no_prompt_gateway():
|
|||
},
|
||||
],
|
||||
"name": "egress_traffic",
|
||||
"type": "model_listener",
|
||||
"type": "model",
|
||||
"port": 12000,
|
||||
"timeout": "30s",
|
||||
}
|
||||
|
||||
|
||||
def test_inline_routing_preferences_migrated_to_top_level():
|
||||
plano_config = """
|
||||
version: v0.3.0
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
name: model_listener
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code understanding
|
||||
description: understand and explain existing code snippets, functions, or libraries
|
||||
|
||||
- model: anthropic/claude-sonnet-4-20250514
|
||||
access_key: $ANTHROPIC_API_KEY
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
"""
|
||||
config_yaml = yaml.safe_load(plano_config)
|
||||
migrate_inline_routing_preferences(config_yaml)
|
||||
|
||||
assert config_yaml["version"] == "v0.4.0"
|
||||
for provider in config_yaml["model_providers"]:
|
||||
assert "routing_preferences" not in provider
|
||||
|
||||
top_level = config_yaml["routing_preferences"]
|
||||
by_name = {entry["name"]: entry for entry in top_level}
|
||||
assert set(by_name) == {"code understanding", "code generation"}
|
||||
assert by_name["code understanding"]["models"] == ["openai/gpt-4o"]
|
||||
assert by_name["code generation"]["models"] == [
|
||||
"anthropic/claude-sonnet-4-20250514"
|
||||
]
|
||||
assert (
|
||||
by_name["code understanding"]["description"]
|
||||
== "understand and explain existing code snippets, functions, or libraries"
|
||||
)
|
||||
|
||||
|
||||
def test_inline_same_name_across_providers_merges_models():
|
||||
plano_config = """
|
||||
version: v0.3.0
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
name: model_listener
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
|
||||
- model: anthropic/claude-sonnet-4-20250514
|
||||
access_key: $ANTHROPIC_API_KEY
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
"""
|
||||
config_yaml = yaml.safe_load(plano_config)
|
||||
migrate_inline_routing_preferences(config_yaml)
|
||||
|
||||
top_level = config_yaml["routing_preferences"]
|
||||
assert len(top_level) == 1
|
||||
entry = top_level[0]
|
||||
assert entry["name"] == "code generation"
|
||||
assert entry["models"] == [
|
||||
"openai/gpt-4o",
|
||||
"anthropic/claude-sonnet-4-20250514",
|
||||
]
|
||||
assert config_yaml["version"] == "v0.4.0"
|
||||
|
||||
|
||||
def test_existing_top_level_routing_preferences_preserved():
|
||||
plano_config = """
|
||||
version: v0.4.0
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
name: model_listener
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
- model: anthropic/claude-sonnet-4-20250514
|
||||
access_key: $ANTHROPIC_API_KEY
|
||||
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: generating new code snippets or boilerplate
|
||||
models:
|
||||
- openai/gpt-4o
|
||||
- anthropic/claude-sonnet-4-20250514
|
||||
"""
|
||||
config_yaml = yaml.safe_load(plano_config)
|
||||
before = yaml.safe_dump(config_yaml, sort_keys=True)
|
||||
migrate_inline_routing_preferences(config_yaml)
|
||||
after = yaml.safe_dump(config_yaml, sort_keys=True)
|
||||
|
||||
assert before == after
|
||||
|
||||
|
||||
def test_existing_top_level_wins_over_inline_migration():
|
||||
plano_config = """
|
||||
version: v0.3.0
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
name: model_listener
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: inline description should lose
|
||||
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: user-defined top-level description wins
|
||||
models:
|
||||
- openai/gpt-4o
|
||||
"""
|
||||
config_yaml = yaml.safe_load(plano_config)
|
||||
migrate_inline_routing_preferences(config_yaml)
|
||||
|
||||
top_level = config_yaml["routing_preferences"]
|
||||
assert len(top_level) == 1
|
||||
entry = top_level[0]
|
||||
assert entry["description"] == "user-defined top-level description wins"
|
||||
assert entry["models"] == ["openai/gpt-4o"]
|
||||
|
||||
|
||||
def test_wildcard_with_inline_routing_preferences_errors():
|
||||
plano_config = """
|
||||
version: v0.3.0
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
name: model_listener
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: openrouter/*
|
||||
base_url: https://openrouter.ai/api/v1
|
||||
passthrough_auth: true
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: generating code
|
||||
"""
|
||||
config_yaml = yaml.safe_load(plano_config)
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
migrate_inline_routing_preferences(config_yaml)
|
||||
assert "wildcard" in str(excinfo.value).lower()
|
||||
|
||||
|
||||
def test_migration_bumps_version_even_without_inline_preferences():
|
||||
plano_config = """
|
||||
version: v0.3.0
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
name: model_listener
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
"""
|
||||
config_yaml = yaml.safe_load(plano_config)
|
||||
migrate_inline_routing_preferences(config_yaml)
|
||||
|
||||
assert "routing_preferences" not in config_yaml
|
||||
assert config_yaml["version"] == "v0.4.0"
|
||||
|
||||
|
||||
def test_migration_is_noop_on_v040_config_with_stray_inline_preferences():
|
||||
# v0.4.0 configs are assumed to be on the canonical top-level shape.
|
||||
# The migration intentionally does not rescue stray inline preferences
|
||||
# at v0.4.0+ so that the deprecation boundary is a clean version gate.
|
||||
plano_config = """
|
||||
version: v0.4.0
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
name: model_listener
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: generating new code
|
||||
"""
|
||||
config_yaml = yaml.safe_load(plano_config)
|
||||
migrate_inline_routing_preferences(config_yaml)
|
||||
|
||||
assert config_yaml["version"] == "v0.4.0"
|
||||
assert "routing_preferences" not in config_yaml
|
||||
assert config_yaml["model_providers"][0]["routing_preferences"] == [
|
||||
{"name": "code generation", "description": "generating new code"}
|
||||
]
|
||||
|
||||
|
||||
def test_migration_does_not_downgrade_newer_versions():
|
||||
plano_config = """
|
||||
version: v0.5.0
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
name: model_listener
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
"""
|
||||
config_yaml = yaml.safe_load(plano_config)
|
||||
migrate_inline_routing_preferences(config_yaml)
|
||||
|
||||
assert config_yaml["version"] == "v0.5.0"
|
||||
|
|
|
|||
111
cli/test/test_defaults.py
Normal file
111
cli/test/test_defaults.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
from pathlib import Path
|
||||
|
||||
import jsonschema
|
||||
import yaml
|
||||
|
||||
from planoai.defaults import (
|
||||
PROVIDER_DEFAULTS,
|
||||
detect_providers,
|
||||
synthesize_default_config,
|
||||
)
|
||||
|
||||
_SCHEMA_PATH = Path(__file__).parents[2] / "config" / "plano_config_schema.yaml"
|
||||
|
||||
|
||||
def _schema() -> dict:
|
||||
return yaml.safe_load(_SCHEMA_PATH.read_text())
|
||||
|
||||
|
||||
def test_zero_env_vars_produces_pure_passthrough():
|
||||
cfg = synthesize_default_config(env={})
|
||||
assert cfg["version"] == "v0.4.0"
|
||||
assert cfg["listeners"][0]["port"] == 12000
|
||||
for provider in cfg["model_providers"]:
|
||||
assert provider.get("passthrough_auth") is True
|
||||
assert "access_key" not in provider
|
||||
# No provider should be marked default in pure pass-through mode.
|
||||
assert provider.get("default") is not True
|
||||
# All known providers should be listed.
|
||||
names = {p["name"] for p in cfg["model_providers"]}
|
||||
assert "digitalocean" in names
|
||||
assert "vercel" in names
|
||||
assert "openrouter" in names
|
||||
assert "openai" in names
|
||||
assert "anthropic" in names
|
||||
|
||||
|
||||
def test_env_keys_promote_providers_to_env_keyed():
|
||||
cfg = synthesize_default_config(
|
||||
env={"OPENAI_API_KEY": "sk-1", "DO_API_KEY": "do-1"}
|
||||
)
|
||||
by_name = {p["name"]: p for p in cfg["model_providers"]}
|
||||
assert by_name["openai"].get("access_key") == "$OPENAI_API_KEY"
|
||||
assert by_name["openai"].get("passthrough_auth") is None
|
||||
assert by_name["digitalocean"].get("access_key") == "$DO_API_KEY"
|
||||
# Unset env keys remain pass-through.
|
||||
assert by_name["anthropic"].get("passthrough_auth") is True
|
||||
|
||||
|
||||
def test_no_default_is_synthesized():
|
||||
# Bare model names resolve via brightstaff's wildcard expansion registering
|
||||
# bare keys, so the synthesizer intentionally never sets `default: true`.
|
||||
cfg = synthesize_default_config(
|
||||
env={"OPENAI_API_KEY": "sk-1", "ANTHROPIC_API_KEY": "a-1"}
|
||||
)
|
||||
assert not any(p.get("default") is True for p in cfg["model_providers"])
|
||||
|
||||
|
||||
def test_listener_port_is_configurable():
|
||||
cfg = synthesize_default_config(env={}, listener_port=11000)
|
||||
assert cfg["listeners"][0]["port"] == 11000
|
||||
|
||||
|
||||
def test_detection_summary_strings():
|
||||
det = detect_providers(env={"OPENAI_API_KEY": "sk", "DO_API_KEY": "d"})
|
||||
summary = det.summary
|
||||
assert "env-keyed" in summary and "openai" in summary and "digitalocean" in summary
|
||||
assert "pass-through" in summary
|
||||
|
||||
|
||||
def test_tracing_block_points_at_local_console():
|
||||
cfg = synthesize_default_config(env={})
|
||||
tracing = cfg["tracing"]
|
||||
assert tracing["opentracing_grpc_endpoint"] == "http://localhost:4317"
|
||||
# random_sampling is a percentage in the plano config — 100 = every span.
|
||||
assert tracing["random_sampling"] == 100
|
||||
|
||||
|
||||
def test_synthesized_config_validates_against_schema():
|
||||
cfg = synthesize_default_config(env={"OPENAI_API_KEY": "sk"})
|
||||
jsonschema.validate(cfg, _schema())
|
||||
|
||||
|
||||
def test_provider_defaults_digitalocean_is_configured():
|
||||
by_name = {p.name: p for p in PROVIDER_DEFAULTS}
|
||||
assert "digitalocean" in by_name
|
||||
assert by_name["digitalocean"].env_var == "DO_API_KEY"
|
||||
assert by_name["digitalocean"].base_url == "https://inference.do-ai.run/v1"
|
||||
assert by_name["digitalocean"].model_pattern == "digitalocean/*"
|
||||
|
||||
|
||||
def test_provider_defaults_vercel_is_configured():
|
||||
by_name = {p.name: p for p in PROVIDER_DEFAULTS}
|
||||
assert "vercel" in by_name
|
||||
assert by_name["vercel"].env_var == "AI_GATEWAY_API_KEY"
|
||||
assert by_name["vercel"].base_url == "https://ai-gateway.vercel.sh/v1"
|
||||
assert by_name["vercel"].model_pattern == "vercel/*"
|
||||
|
||||
|
||||
def test_provider_defaults_openrouter_is_configured():
|
||||
by_name = {p.name: p for p in PROVIDER_DEFAULTS}
|
||||
assert "openrouter" in by_name
|
||||
assert by_name["openrouter"].env_var == "OPENROUTER_API_KEY"
|
||||
assert by_name["openrouter"].base_url == "https://openrouter.ai/api/v1"
|
||||
assert by_name["openrouter"].model_pattern == "openrouter/*"
|
||||
|
||||
|
||||
def test_openrouter_env_key_promotes_to_env_keyed():
|
||||
cfg = synthesize_default_config(env={"OPENROUTER_API_KEY": "or-1"})
|
||||
by_name = {p["name"]: p for p in cfg["model_providers"]}
|
||||
assert by_name["openrouter"].get("access_key") == "$OPENROUTER_API_KEY"
|
||||
assert by_name["openrouter"].get("passthrough_auth") is None
|
||||
145
cli/test/test_obs_collector.py
Normal file
145
cli/test/test_obs_collector.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
import time
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from planoai.obs.collector import LLMCall, LLMCallStore, span_to_llm_call
|
||||
|
||||
|
||||
def _mk_attr(key: str, value):
|
||||
v = MagicMock()
|
||||
if isinstance(value, bool):
|
||||
v.WhichOneof.return_value = "bool_value"
|
||||
v.bool_value = value
|
||||
elif isinstance(value, int):
|
||||
v.WhichOneof.return_value = "int_value"
|
||||
v.int_value = value
|
||||
elif isinstance(value, float):
|
||||
v.WhichOneof.return_value = "double_value"
|
||||
v.double_value = value
|
||||
else:
|
||||
v.WhichOneof.return_value = "string_value"
|
||||
v.string_value = str(value)
|
||||
kv = MagicMock()
|
||||
kv.key = key
|
||||
kv.value = v
|
||||
return kv
|
||||
|
||||
|
||||
def _mk_span(
|
||||
attrs: dict, start_ns: int | None = None, span_id_hex: str = "ab"
|
||||
) -> MagicMock:
|
||||
span = MagicMock()
|
||||
span.attributes = [_mk_attr(k, v) for k, v in attrs.items()]
|
||||
span.start_time_unix_nano = start_ns or int(time.time() * 1_000_000_000)
|
||||
span.span_id.hex.return_value = span_id_hex
|
||||
return span
|
||||
|
||||
|
||||
def test_span_without_llm_model_is_ignored():
|
||||
span = _mk_span({"http.method": "POST"})
|
||||
assert span_to_llm_call(span, "plano(llm)") is None
|
||||
|
||||
|
||||
def test_span_with_full_llm_attrs_produces_call():
|
||||
span = _mk_span(
|
||||
{
|
||||
"llm.model": "openai-gpt-5.4",
|
||||
"model.requested": "router:software-engineering",
|
||||
"plano.session_id": "sess-abc",
|
||||
"plano.route.name": "software-engineering",
|
||||
"llm.is_streaming": False,
|
||||
"llm.duration_ms": 1234,
|
||||
"llm.time_to_first_token": 210,
|
||||
"llm.usage.prompt_tokens": 100,
|
||||
"llm.usage.completion_tokens": 50,
|
||||
"llm.usage.total_tokens": 150,
|
||||
"llm.usage.cached_input_tokens": 30,
|
||||
"llm.usage.cache_creation_tokens": 5,
|
||||
"llm.usage.reasoning_tokens": 200,
|
||||
"http.status_code": 200,
|
||||
"request_id": "req-42",
|
||||
}
|
||||
)
|
||||
call = span_to_llm_call(span, "plano(llm)")
|
||||
assert call is not None
|
||||
assert call.request_id == "req-42"
|
||||
assert call.model == "openai-gpt-5.4"
|
||||
assert call.request_model == "router:software-engineering"
|
||||
assert call.session_id == "sess-abc"
|
||||
assert call.route_name == "software-engineering"
|
||||
assert call.is_streaming is False
|
||||
assert call.duration_ms == 1234.0
|
||||
assert call.ttft_ms == 210.0
|
||||
assert call.prompt_tokens == 100
|
||||
assert call.completion_tokens == 50
|
||||
assert call.total_tokens == 150
|
||||
assert call.cached_input_tokens == 30
|
||||
assert call.cache_creation_tokens == 5
|
||||
assert call.reasoning_tokens == 200
|
||||
assert call.status_code == 200
|
||||
|
||||
|
||||
def test_pricing_lookup_attaches_cost():
|
||||
class StubPricing:
|
||||
def cost_for_call(self, call):
|
||||
# Simple: 2 * prompt + 3 * completion, in cents
|
||||
return 0.02 * (call.prompt_tokens or 0) + 0.03 * (
|
||||
call.completion_tokens or 0
|
||||
)
|
||||
|
||||
span = _mk_span(
|
||||
{
|
||||
"llm.model": "do/openai-gpt-5.4",
|
||||
"llm.usage.prompt_tokens": 10,
|
||||
"llm.usage.completion_tokens": 2,
|
||||
}
|
||||
)
|
||||
call = span_to_llm_call(span, "plano(llm)", pricing=StubPricing())
|
||||
assert call is not None
|
||||
assert call.cost_usd == pytest.approx(0.26)
|
||||
|
||||
|
||||
def test_tpt_and_tokens_per_sec_derived():
|
||||
call = LLMCall(
|
||||
request_id="x",
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
model="m",
|
||||
duration_ms=1000,
|
||||
ttft_ms=200,
|
||||
completion_tokens=80,
|
||||
)
|
||||
# (1000 - 200) / 80 = 10ms per token => 100 tokens/sec
|
||||
assert call.tpt_ms == 10.0
|
||||
assert call.tokens_per_sec == 100.0
|
||||
|
||||
|
||||
def test_tpt_returns_none_when_no_completion_tokens():
|
||||
call = LLMCall(
|
||||
request_id="x",
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
model="m",
|
||||
duration_ms=1000,
|
||||
ttft_ms=200,
|
||||
completion_tokens=0,
|
||||
)
|
||||
assert call.tpt_ms is None
|
||||
assert call.tokens_per_sec is None
|
||||
|
||||
|
||||
def test_store_evicts_fifo_at_capacity():
|
||||
store = LLMCallStore(capacity=3)
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
for i in range(5):
|
||||
store.add(
|
||||
LLMCall(
|
||||
request_id=f"r{i}",
|
||||
timestamp=now,
|
||||
model="m",
|
||||
)
|
||||
)
|
||||
snap = store.snapshot()
|
||||
assert len(snap) == 3
|
||||
assert [c.request_id for c in snap] == ["r2", "r3", "r4"]
|
||||
146
cli/test/test_obs_pricing.py
Normal file
146
cli/test/test_obs_pricing.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from planoai.obs.collector import LLMCall
|
||||
from planoai.obs.pricing import ModelPrice, PricingCatalog
|
||||
|
||||
|
||||
def _call(model: str, prompt: int, completion: int, cached: int = 0) -> LLMCall:
|
||||
return LLMCall(
|
||||
request_id="r",
|
||||
timestamp=datetime.now(tz=timezone.utc),
|
||||
model=model,
|
||||
prompt_tokens=prompt,
|
||||
completion_tokens=completion,
|
||||
cached_input_tokens=cached,
|
||||
)
|
||||
|
||||
|
||||
def test_lookup_matches_bare_and_prefixed():
|
||||
prices = {
|
||||
"openai-gpt-5.4": ModelPrice(
|
||||
input_per_token_usd=0.000001, output_per_token_usd=0.000002
|
||||
)
|
||||
}
|
||||
catalog = PricingCatalog(prices)
|
||||
assert catalog.price_for("openai-gpt-5.4") is not None
|
||||
# do/openai-gpt-5.4 should resolve after stripping the provider prefix.
|
||||
assert catalog.price_for("do/openai-gpt-5.4") is not None
|
||||
assert catalog.price_for("unknown-model") is None
|
||||
|
||||
|
||||
def test_cost_computation_without_cache():
|
||||
prices = {
|
||||
"m": ModelPrice(input_per_token_usd=0.000001, output_per_token_usd=0.000002)
|
||||
}
|
||||
cost = PricingCatalog(prices).cost_for_call(_call("m", 1000, 500))
|
||||
assert cost == 0.002 # 1000 * 1e-6 + 500 * 2e-6
|
||||
|
||||
|
||||
def test_cost_computation_with_cached_discount():
|
||||
prices = {
|
||||
"m": ModelPrice(
|
||||
input_per_token_usd=0.000001,
|
||||
output_per_token_usd=0.000002,
|
||||
cached_input_per_token_usd=0.0000001,
|
||||
)
|
||||
}
|
||||
# 800 fresh @ 1e-6 = 8e-4; 200 cached @ 1e-7 = 2e-5; 500 out @ 2e-6 = 1e-3
|
||||
cost = PricingCatalog(prices).cost_for_call(_call("m", 1000, 500, cached=200))
|
||||
assert cost == round(0.0008 + 0.00002 + 0.001, 6)
|
||||
|
||||
|
||||
def test_empty_catalog_returns_none():
|
||||
assert PricingCatalog().cost_for_call(_call("m", 100, 50)) is None
|
||||
|
||||
|
||||
def test_parse_do_catalog_treats_small_values_as_per_token():
|
||||
"""DO's real catalog uses per-token values under the `_per_million` key
|
||||
(e.g. 5E-8 for GPT-oss-20b). We treat values < 1 as already per-token."""
|
||||
from planoai.obs.pricing import _parse_do_pricing
|
||||
|
||||
sample = {
|
||||
"data": [
|
||||
{
|
||||
"model_id": "openai-gpt-oss-20b",
|
||||
"pricing": {
|
||||
"input_price_per_million": 5e-8,
|
||||
"output_price_per_million": 4.5e-7,
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_id": "openai-gpt-oss-120b",
|
||||
"pricing": {
|
||||
"input_price_per_million": 1e-7,
|
||||
"output_price_per_million": 7e-7,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
prices = _parse_do_pricing(sample)
|
||||
# Values < 1 are assumed to already be per-token — no extra division.
|
||||
assert prices["openai-gpt-oss-20b"].input_per_token_usd == 5e-8
|
||||
assert prices["openai-gpt-oss-20b"].output_per_token_usd == 4.5e-7
|
||||
assert prices["openai-gpt-oss-120b"].input_per_token_usd == 1e-7
|
||||
|
||||
|
||||
def test_anthropic_aliases_match_plano_emitted_names():
|
||||
"""DO publishes 'anthropic-claude-opus-4.7' and 'anthropic-claude-haiku-4.5';
|
||||
Plano emits 'claude-opus-4-7' and 'claude-haiku-4-5-20251001'. Aliases
|
||||
registered at parse time should bridge the gap."""
|
||||
from planoai.obs.pricing import _parse_do_pricing
|
||||
|
||||
sample = {
|
||||
"data": [
|
||||
{
|
||||
"model_id": "anthropic-claude-opus-4.7",
|
||||
"pricing": {
|
||||
"input_price_per_million": 15.0,
|
||||
"output_price_per_million": 75.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_id": "anthropic-claude-haiku-4.5",
|
||||
"pricing": {
|
||||
"input_price_per_million": 1.0,
|
||||
"output_price_per_million": 5.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_id": "anthropic-claude-4.6-sonnet",
|
||||
"pricing": {
|
||||
"input_price_per_million": 3.0,
|
||||
"output_price_per_million": 15.0,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
catalog = PricingCatalog(_parse_do_pricing(sample))
|
||||
# Family-last shapes Plano emits.
|
||||
assert catalog.price_for("claude-opus-4-7") is not None
|
||||
assert catalog.price_for("claude-haiku-4-5") is not None
|
||||
# Date-suffixed name (Anthropic API style).
|
||||
assert catalog.price_for("claude-haiku-4-5-20251001") is not None
|
||||
# Word-order swap: DO has 'claude-4.6-sonnet', Plano emits 'claude-sonnet-4-6'.
|
||||
assert catalog.price_for("claude-sonnet-4-6") is not None
|
||||
# Original DO ids still resolve.
|
||||
assert catalog.price_for("anthropic-claude-opus-4.7") is not None
|
||||
|
||||
|
||||
def test_parse_do_catalog_divides_large_values_as_per_million():
|
||||
"""A provider that genuinely reports $5-per-million in that field gets divided."""
|
||||
from planoai.obs.pricing import _parse_do_pricing
|
||||
|
||||
sample = {
|
||||
"data": [
|
||||
{
|
||||
"model_id": "mystery-model",
|
||||
"pricing": {
|
||||
"input_price_per_million": 5.0, # > 1 → treated as per-million
|
||||
"output_price_per_million": 15.0,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
prices = _parse_do_pricing(sample)
|
||||
assert prices["mystery-model"].input_per_token_usd == 5.0 / 1_000_000
|
||||
assert prices["mystery-model"].output_per_token_usd == 15.0 / 1_000_000
|
||||
106
cli/test/test_obs_render.py
Normal file
106
cli/test/test_obs_render.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from planoai.obs.collector import LLMCall
|
||||
from planoai.obs.render import aggregates, model_rollups, route_hits
|
||||
|
||||
|
||||
def _call(
|
||||
model: str,
|
||||
ts: datetime,
|
||||
prompt=0,
|
||||
completion=0,
|
||||
cost=None,
|
||||
route=None,
|
||||
session=None,
|
||||
cache_read=0,
|
||||
cache_write=0,
|
||||
):
|
||||
return LLMCall(
|
||||
request_id="r",
|
||||
timestamp=ts,
|
||||
model=model,
|
||||
prompt_tokens=prompt,
|
||||
completion_tokens=completion,
|
||||
cached_input_tokens=cache_read,
|
||||
cache_creation_tokens=cache_write,
|
||||
cost_usd=cost,
|
||||
route_name=route,
|
||||
session_id=session,
|
||||
)
|
||||
|
||||
|
||||
def test_aggregates_sum_and_session_counts():
|
||||
now = datetime.now(tz=timezone.utc).astimezone()
|
||||
calls = [
|
||||
_call(
|
||||
"m1",
|
||||
now - timedelta(seconds=50),
|
||||
prompt=10,
|
||||
completion=5,
|
||||
cost=0.001,
|
||||
session="s1",
|
||||
),
|
||||
_call(
|
||||
"m2",
|
||||
now - timedelta(seconds=40),
|
||||
prompt=20,
|
||||
completion=10,
|
||||
cost=0.002,
|
||||
session="s1",
|
||||
),
|
||||
_call(
|
||||
"m1",
|
||||
now - timedelta(seconds=30),
|
||||
prompt=30,
|
||||
completion=15,
|
||||
cost=0.003,
|
||||
session="s2",
|
||||
),
|
||||
]
|
||||
stats = aggregates(calls)
|
||||
assert stats.count == 3
|
||||
assert stats.total_cost_usd == 0.006
|
||||
assert stats.total_input_tokens == 60
|
||||
assert stats.total_output_tokens == 30
|
||||
assert stats.distinct_sessions == 2
|
||||
assert stats.current_session == "s2"
|
||||
|
||||
|
||||
def test_rollups_split_by_model_and_cache():
|
||||
now = datetime.now(tz=timezone.utc).astimezone()
|
||||
calls = [
|
||||
_call(
|
||||
"m1", now, prompt=10, completion=5, cost=0.001, cache_write=3, cache_read=7
|
||||
),
|
||||
_call("m1", now, prompt=20, completion=10, cost=0.002, cache_read=1),
|
||||
_call("m2", now, prompt=30, completion=15, cost=0.004),
|
||||
]
|
||||
rollups = model_rollups(calls)
|
||||
by_model = {r.model: r for r in rollups}
|
||||
assert by_model["m1"].requests == 2
|
||||
assert by_model["m1"].input_tokens == 30
|
||||
assert by_model["m1"].cache_write == 3
|
||||
assert by_model["m1"].cache_read == 8
|
||||
assert by_model["m2"].input_tokens == 30
|
||||
|
||||
|
||||
def test_route_hits_only_for_routed_calls():
|
||||
now = datetime.now(tz=timezone.utc).astimezone()
|
||||
calls = [
|
||||
_call("m", now, route="code"),
|
||||
_call("m", now, route="code"),
|
||||
_call("m", now, route="summarization"),
|
||||
_call("m", now), # no route
|
||||
]
|
||||
hits = route_hits(calls)
|
||||
# Only calls with route names are counted.
|
||||
assert sum(h.hits for h in hits) == 3
|
||||
hits_by_name = {h.route: h for h in hits}
|
||||
assert hits_by_name["code"].hits == 2
|
||||
assert hits_by_name["summarization"].hits == 1
|
||||
|
||||
|
||||
def test_route_hits_empty_when_no_routes():
|
||||
now = datetime.now(tz=timezone.utc).astimezone()
|
||||
calls = [_call("m", now), _call("m", now)]
|
||||
assert route_hits(calls) == []
|
||||
16
cli/uv.lock
generated
16
cli/uv.lock
generated
|
|
@ -337,7 +337,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "planoai"
|
||||
version = "0.4.7"
|
||||
version = "0.4.21"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
|
|
@ -373,7 +373,7 @@ requires-dist = [
|
|||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.4.1,<9.0.0" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.2,<7.0.0" },
|
||||
{ name = "questionary", specifier = ">=2.1.1,<3.0.0" },
|
||||
{ name = "requests", specifier = ">=2.31.0,<3.0.0" },
|
||||
{ name = "requests", specifier = ">=2.33.0,<3.0.0" },
|
||||
{ name = "rich", specifier = ">=14.2.0" },
|
||||
{ name = "rich-click", specifier = ">=1.9.5" },
|
||||
{ name = "urllib3", specifier = ">=2.6.3" },
|
||||
|
|
@ -421,11 +421,11 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.2"
|
||||
version = "2.20.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c3/b2/bc9c9196916376152d655522fdcebac55e66de6603a76a02bca1b6414f6c/pygments-2.20.0.tar.gz", hash = "sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f", size = 4955991, upload-time = "2026-03-29T13:29:33.898Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/7e/a72dd26f3b0f4f2bf1dd8923c85f7ceb43172af56d63c7383eb62b332364/pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176", size = 1231151, upload-time = "2026-03-29T13:29:30.038Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -518,7 +518,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.5"
|
||||
version = "2.33.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
|
|
@ -526,9 +526,9 @@ dependencies = [
|
|||
{ name = "idna" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5f/a4/98b9c7c6428a668bf7e42ebb7c79d576a1c3c1e3ae2d47e674b468388871/requests-2.33.1.tar.gz", hash = "sha256:18817f8c57c6263968bc123d237e3b8b08ac046f5456bd1e307ee8f4250d3517", size = 134120, upload-time = "2026-03-30T16:09:15.531Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/8e/7540e8a2036f79a125c1d2ebadf69ed7901608859186c856fa0388ef4197/requests-2.33.1-py3-none-any.whl", hash = "sha256:4e6d1ef462f3626a1f0a0a9c42dd93c63bad33f9f1c1937509b8c5c8718ab56a", size = 64947, upload-time = "2026-03-30T16:09:13.83Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -594,13 +594,13 @@ static_resources:
|
|||
|
||||
clusters:
|
||||
|
||||
- name: arch
|
||||
- name: plano
|
||||
connect_timeout: {{ upstream_connect_timeout | default('5s') }}
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: arch
|
||||
cluster_name: plano
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
|
|
@ -901,6 +901,60 @@ static_resources:
|
|||
validation_context:
|
||||
trusted_ca:
|
||||
filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }}
|
||||
- name: digitalocean
|
||||
connect_timeout: {{ upstream_connect_timeout | default('5s') }}
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: digitalocean
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: inference.do-ai.run
|
||||
port_value: 443
|
||||
hostname: "inference.do-ai.run"
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
sni: inference.do-ai.run
|
||||
common_tls_context:
|
||||
tls_params:
|
||||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
validation_context:
|
||||
trusted_ca:
|
||||
filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }}
|
||||
- name: xiaomi
|
||||
connect_timeout: {{ upstream_connect_timeout | default('5s') }}
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: xiaomi
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: api.xiaomimimo.com
|
||||
port_value: 443
|
||||
hostname: "api.xiaomimimo.com"
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
sni: api.xiaomimimo.com
|
||||
common_tls_context:
|
||||
tls_params:
|
||||
tls_minimum_protocol_version: TLSv1_2
|
||||
tls_maximum_protocol_version: TLSv1_3
|
||||
validation_context:
|
||||
trusted_ca:
|
||||
filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }}
|
||||
- name: mistral_7b_instruct
|
||||
connect_timeout: 0.5s
|
||||
type: STRICT_DNS
|
||||
|
|
|
|||
541
config/grafana/brightstaff_dashboard.json
Normal file
541
config/grafana/brightstaff_dashboard.json
Normal file
|
|
@ -0,0 +1,541 @@
|
|||
{
|
||||
"annotations": {
|
||||
"list": [
|
||||
{
|
||||
"builtIn": 1,
|
||||
"datasource": "-- Grafana --",
|
||||
"enable": true,
|
||||
"hide": true,
|
||||
"iconColor": "rgba(0, 211, 255, 1)",
|
||||
"name": "Annotations & Alerts",
|
||||
"type": "dashboard"
|
||||
}
|
||||
]
|
||||
},
|
||||
"description": "RED, LLM upstream, routing service, and process metrics for brightstaff. Pair with Envoy admin metrics from cluster=bright_staff.",
|
||||
"editable": true,
|
||||
"fiscalYearStartMonth": 0,
|
||||
"graphTooltip": 1,
|
||||
"id": null,
|
||||
"links": [],
|
||||
"liveNow": false,
|
||||
"panels": [
|
||||
{
|
||||
"collapsed": false,
|
||||
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 0 },
|
||||
"id": 100,
|
||||
"panels": [],
|
||||
"title": "HTTP RED",
|
||||
"type": "row"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": {
|
||||
"axisLabel": "req/s",
|
||||
"drawStyle": "line",
|
||||
"fillOpacity": 10,
|
||||
"lineWidth": 1,
|
||||
"showPoints": "never"
|
||||
},
|
||||
"unit": "reqps"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 8, "w": 12, "x": 0, "y": 1 },
|
||||
"id": 1,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (handler) (rate(brightstaff_http_requests_total[1m]))",
|
||||
"legendFormat": "{{handler}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Rate — brightstaff RPS by handler",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "5xx fraction over 5m. Page-worthy when sustained above ~1%.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "thresholds" },
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{ "color": "green", "value": null },
|
||||
{ "color": "yellow", "value": 0.01 },
|
||||
{ "color": "red", "value": 0.05 }
|
||||
]
|
||||
},
|
||||
"unit": "percentunit"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 8, "w": 12, "x": 12, "y": 1 },
|
||||
"id": 2,
|
||||
"options": {
|
||||
"colorMode": "background",
|
||||
"graphMode": "area",
|
||||
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum(rate(brightstaff_http_requests_total{status_class=\"5xx\"}[5m])) / clamp_min(sum(rate(brightstaff_http_requests_total[5m])), 1)",
|
||||
"legendFormat": "5xx rate",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Errors — brightstaff 5xx rate",
|
||||
"type": "stat"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "p50/p95/p99 by handler, computed from histogram buckets over 5m.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" },
|
||||
"unit": "s"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 9, "w": 24, "x": 0, "y": 9 },
|
||||
"id": 3,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.50, sum by (le, handler) (rate(brightstaff_http_request_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "p50 {{handler}}",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (le, handler) (rate(brightstaff_http_request_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "p95 {{handler}}",
|
||||
"refId": "B"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.99, sum by (le, handler) (rate(brightstaff_http_request_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "p99 {{handler}}",
|
||||
"refId": "C"
|
||||
}
|
||||
],
|
||||
"title": "Duration — p50 / p95 / p99 by handler",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "In-flight requests by handler. Climbs before latency does when brightstaff is saturated.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" },
|
||||
"unit": "short"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 8, "w": 24, "x": 0, "y": 18 },
|
||||
"id": 4,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (handler) (brightstaff_http_in_flight_requests)",
|
||||
"legendFormat": "{{handler}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "In-flight requests by handler",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"collapsed": false,
|
||||
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 26 },
|
||||
"id": 200,
|
||||
"panels": [],
|
||||
"title": "LLM upstream",
|
||||
"type": "row"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" },
|
||||
"unit": "s"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 9, "w": 12, "x": 0, "y": 27 },
|
||||
"id": 5,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (le, provider, model) (rate(brightstaff_llm_upstream_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "p95 {{provider}}/{{model}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "LLM upstream p95 by provider/model",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "All non-success error classes. timeout/connect = network, 5xx/429 = provider, parse = body shape mismatch, stream = mid-stream disconnect.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 30, "lineWidth": 1, "showPoints": "never", "stacking": { "mode": "normal" } },
|
||||
"unit": "reqps"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 9, "w": 12, "x": 12, "y": 27 },
|
||||
"id": 6,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (provider, error_class) (rate(brightstaff_llm_upstream_requests_total{error_class!=\"none\"}[5m]))",
|
||||
"legendFormat": "{{provider}} / {{error_class}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "LLM upstream errors by provider / class",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "Streaming only. Empty if the route never streams.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" },
|
||||
"unit": "s"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 9, "w": 12, "x": 0, "y": 36 },
|
||||
"id": 7,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (le, provider, model) (rate(brightstaff_llm_time_to_first_token_seconds_bucket[5m])))",
|
||||
"legendFormat": "p95 {{provider}}/{{model}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Time-to-first-token p95 (streaming)",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "Tokens/sec by provider/model/kind — proxy for cost. Stacked.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 30, "lineWidth": 1, "showPoints": "never", "stacking": { "mode": "normal" } },
|
||||
"unit": "tokens/s"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 9, "w": 12, "x": 12, "y": 36 },
|
||||
"id": 8,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (provider, model, kind) (rate(brightstaff_llm_tokens_total[5m]))",
|
||||
"legendFormat": "{{provider}}/{{model}} {{kind}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Token throughput by provider / model / kind",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"collapsed": false,
|
||||
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 45 },
|
||||
"id": 300,
|
||||
"panels": [],
|
||||
"title": "Routing service",
|
||||
"type": "row"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "Which models the orchestrator picked over the last 15 minutes.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"unit": "short"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 9, "w": 12, "x": 0, "y": 46 },
|
||||
"id": 9,
|
||||
"options": {
|
||||
"displayMode": "gradient",
|
||||
"orientation": "horizontal",
|
||||
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (selected_model) (increase(brightstaff_router_decisions_total[15m]))",
|
||||
"legendFormat": "{{selected_model}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Model selection distribution (last 15m)",
|
||||
"type": "bargauge"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "Fraction of decisions that fell back (orchestrator returned `none` or errored). High = router can't classify intent or no candidates configured.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" },
|
||||
"unit": "percentunit"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 9, "w": 12, "x": 12, "y": 46 },
|
||||
"id": 10,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (route) (rate(brightstaff_router_decisions_total{fallback=\"true\"}[5m])) / clamp_min(sum by (route) (rate(brightstaff_router_decisions_total[5m])), 1)",
|
||||
"legendFormat": "{{route}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Fallback rate by route",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 5, "lineWidth": 1, "showPoints": "never" },
|
||||
"unit": "s"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 8, "w": 12, "x": 0, "y": 55 },
|
||||
"id": 11,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "histogram_quantile(0.95, sum by (le, route) (rate(brightstaff_router_decision_duration_seconds_bucket[5m])))",
|
||||
"legendFormat": "p95 {{route}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Router decision p95 latency",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "Hit / (hit + miss). Low ratio = sessions aren't being reused or TTL too short.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "thresholds" },
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{ "color": "red", "value": null },
|
||||
{ "color": "yellow", "value": 0.5 },
|
||||
{ "color": "green", "value": 0.8 }
|
||||
]
|
||||
},
|
||||
"unit": "percentunit",
|
||||
"min": 0,
|
||||
"max": 1
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 8, "w": 6, "x": 12, "y": 55 },
|
||||
"id": 12,
|
||||
"options": {
|
||||
"colorMode": "background",
|
||||
"graphMode": "area",
|
||||
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum(rate(brightstaff_session_cache_events_total{outcome=\"hit\"}[5m])) / clamp_min(sum(rate(brightstaff_session_cache_events_total{outcome=~\"hit|miss\"}[5m])), 1)",
|
||||
"legendFormat": "hit rate",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "Session cache hit rate",
|
||||
"type": "stat"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "decision_served = a real model picked. no_candidates = sentinel `none` returned. policy_error = orchestrator failed.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 30, "lineWidth": 1, "showPoints": "never", "stacking": { "mode": "normal" } },
|
||||
"unit": "reqps"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 8, "w": 6, "x": 18, "y": 55 },
|
||||
"id": 13,
|
||||
"options": {
|
||||
"legend": { "displayMode": "list", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum by (outcome) (rate(brightstaff_routing_service_requests_total[5m]))",
|
||||
"legendFormat": "{{outcome}}",
|
||||
"refId": "A"
|
||||
}
|
||||
],
|
||||
"title": "/routing/* outcomes",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"collapsed": false,
|
||||
"gridPos": { "h": 1, "w": 24, "x": 0, "y": 63 },
|
||||
"id": 400,
|
||||
"panels": [],
|
||||
"title": "Process & Envoy link",
|
||||
"type": "row"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"description": "Compare to brightstaff RPS (panel 1) — sustained gap = network or Envoy queueing.",
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" },
|
||||
"unit": "reqps"
|
||||
}
|
||||
},
|
||||
"gridPos": { "h": 8, "w": 12, "x": 0, "y": 64 },
|
||||
"id": 14,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum(rate(envoy_cluster_upstream_rq_total{envoy_cluster_name=\"bright_staff\"}[1m]))",
|
||||
"legendFormat": "envoy → bright_staff",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "sum(rate(brightstaff_http_requests_total[1m]))",
|
||||
"legendFormat": "brightstaff served",
|
||||
"refId": "B"
|
||||
}
|
||||
],
|
||||
"title": "Envoy → brightstaff link health",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"color": { "mode": "palette-classic" },
|
||||
"custom": { "drawStyle": "line", "fillOpacity": 10, "lineWidth": 1, "showPoints": "never" }
|
||||
},
|
||||
"overrides": [
|
||||
{
|
||||
"matcher": { "id": "byName", "options": "RSS" },
|
||||
"properties": [{ "id": "unit", "value": "bytes" }]
|
||||
},
|
||||
{
|
||||
"matcher": { "id": "byName", "options": "CPU" },
|
||||
"properties": [{ "id": "unit", "value": "percentunit" }]
|
||||
}
|
||||
]
|
||||
},
|
||||
"gridPos": { "h": 8, "w": 12, "x": 12, "y": 64 },
|
||||
"id": 15,
|
||||
"options": {
|
||||
"legend": { "displayMode": "table", "placement": "bottom", "showLegend": true },
|
||||
"tooltip": { "mode": "multi" }
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "process_resident_memory_bytes{job=\"brightstaff\"}",
|
||||
"legendFormat": "RSS",
|
||||
"refId": "A"
|
||||
},
|
||||
{
|
||||
"datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
|
||||
"expr": "rate(process_cpu_seconds_total{job=\"brightstaff\"}[1m])",
|
||||
"legendFormat": "CPU",
|
||||
"refId": "B"
|
||||
}
|
||||
],
|
||||
"title": "Brightstaff process RSS / CPU",
|
||||
"type": "timeseries"
|
||||
}
|
||||
],
|
||||
"refresh": "30s",
|
||||
"schemaVersion": 39,
|
||||
"tags": ["plano", "brightstaff", "llm"],
|
||||
"templating": {
|
||||
"list": [
|
||||
{
|
||||
"name": "DS_PROMETHEUS",
|
||||
"label": "Prometheus",
|
||||
"type": "datasource",
|
||||
"query": "prometheus",
|
||||
"current": { "selected": false, "text": "Prometheus", "value": "DS_PROMETHEUS" },
|
||||
"hide": 0,
|
||||
"refresh": 1,
|
||||
"regex": "",
|
||||
"skipUrlSync": false,
|
||||
"includeAll": false,
|
||||
"multi": false
|
||||
}
|
||||
]
|
||||
},
|
||||
"time": { "from": "now-1h", "to": "now" },
|
||||
"timepicker": {},
|
||||
"timezone": "browser",
|
||||
"title": "Brightstaff (Plano dataplane)",
|
||||
"uid": "brightstaff",
|
||||
"version": 1,
|
||||
"weekStart": ""
|
||||
}
|
||||
43
config/grafana/docker-compose.yaml
Normal file
43
config/grafana/docker-compose.yaml
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# One-command Prometheus + Grafana stack for observing a locally-running
|
||||
# Plano (Envoy admin :9901 + brightstaff :9092 on the host).
|
||||
#
|
||||
# cd config/grafana
|
||||
# docker compose up -d
|
||||
# open http://localhost:3000 (admin / admin)
|
||||
#
|
||||
# Grafana is preloaded with:
|
||||
# - Prometheus datasource (uid=DS_PROMETHEUS) → http://prometheus:9090
|
||||
# - Brightstaff dashboard (auto-imported from brightstaff_dashboard.json)
|
||||
#
|
||||
# Prometheus scrapes the host's :9092 and :9901 via host.docker.internal.
|
||||
# On Linux this works because of the `extra_hosts: host-gateway` mapping
|
||||
# below. On Mac/Win it works natively.
|
||||
|
||||
services:
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
container_name: plano-prometheus
|
||||
ports:
|
||||
- "9090:9090"
|
||||
volumes:
|
||||
- ./prometheus_scrape.yaml:/etc/prometheus/prometheus.yml:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
restart: unless-stopped
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
container_name: plano-grafana
|
||||
ports:
|
||||
- "3000:3000"
|
||||
environment:
|
||||
GF_SECURITY_ADMIN_USER: admin
|
||||
GF_SECURITY_ADMIN_PASSWORD: admin
|
||||
GF_AUTH_ANONYMOUS_ENABLED: "true"
|
||||
GF_AUTH_ANONYMOUS_ORG_ROLE: Viewer
|
||||
volumes:
|
||||
- ./provisioning:/etc/grafana/provisioning:ro
|
||||
- ./brightstaff_dashboard.json:/var/lib/grafana/dashboards/brightstaff_dashboard.json:ro
|
||||
depends_on:
|
||||
- prometheus
|
||||
restart: unless-stopped
|
||||
44
config/grafana/prometheus_scrape.yaml
Normal file
44
config/grafana/prometheus_scrape.yaml
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# Prometheus config that scrapes Plano (Envoy admin + brightstaff). This is
|
||||
# a complete Prometheus config — mount it directly at
|
||||
# /etc/prometheus/prometheus.yml. The included docker-compose.yaml does this
|
||||
# for you.
|
||||
#
|
||||
# Targets:
|
||||
# - envoy:9901 Envoy admin → envoy_cluster_*, envoy_http_*, envoy_server_*.
|
||||
# - brightstaff:9092 Native dataplane → brightstaff_http_*, brightstaff_llm_*,
|
||||
# brightstaff_router_*, process_*.
|
||||
#
|
||||
# Hostname `host.docker.internal` works on Docker Desktop (Mac/Win) and on
|
||||
# Linux when the container is started with `--add-host=host.docker.internal:
|
||||
# host-gateway` (the included compose does this). If Plano runs *inside*
|
||||
# Docker on the same network as Prometheus, replace it with the container
|
||||
# name (e.g. `plano:9092`).
|
||||
#
|
||||
# This file is unrelated to demos/llm_routing/model_routing_service/prometheus.yaml,
|
||||
# which scrapes a fake metrics service to feed the routing engine.
|
||||
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
scrape_timeout: 10s
|
||||
evaluation_interval: 15s
|
||||
|
||||
scrape_configs:
|
||||
- job_name: envoy
|
||||
honor_timestamps: true
|
||||
metrics_path: /stats
|
||||
params:
|
||||
format: ["prometheus"]
|
||||
static_configs:
|
||||
- targets:
|
||||
- host.docker.internal:9901
|
||||
labels:
|
||||
service: plano
|
||||
|
||||
- job_name: brightstaff
|
||||
honor_timestamps: true
|
||||
metrics_path: /metrics
|
||||
static_configs:
|
||||
- targets:
|
||||
- host.docker.internal:9092
|
||||
labels:
|
||||
service: plano
|
||||
15
config/grafana/provisioning/dashboards/brightstaff.yaml
Normal file
15
config/grafana/provisioning/dashboards/brightstaff.yaml
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# Auto-load the brightstaff dashboard JSON on Grafana startup.
|
||||
|
||||
apiVersion: 1
|
||||
|
||||
providers:
|
||||
- name: brightstaff
|
||||
orgId: 1
|
||||
folder: Plano
|
||||
type: file
|
||||
disableDeletion: false
|
||||
updateIntervalSeconds: 30
|
||||
allowUiUpdates: true
|
||||
options:
|
||||
path: /var/lib/grafana/dashboards
|
||||
foldersFromFilesStructure: false
|
||||
14
config/grafana/provisioning/datasources/prometheus.yaml
Normal file
14
config/grafana/provisioning/datasources/prometheus.yaml
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# Auto-provision the Prometheus datasource so the bundled dashboard wires up
|
||||
# without any clicks. The `uid: DS_PROMETHEUS` matches the templated input in
|
||||
# brightstaff_dashboard.json.
|
||||
|
||||
apiVersion: 1
|
||||
|
||||
datasources:
|
||||
- name: Prometheus
|
||||
uid: DS_PROMETHEUS
|
||||
type: prometheus
|
||||
access: proxy
|
||||
url: http://prometheus:9090
|
||||
isDefault: true
|
||||
editable: true
|
||||
|
|
@ -9,6 +9,7 @@ properties:
|
|||
- 0.1-beta
|
||||
- 0.2.0
|
||||
- v0.3.0
|
||||
- v0.4.0
|
||||
|
||||
agents:
|
||||
type: array
|
||||
|
|
@ -85,7 +86,7 @@ properties:
|
|||
type: string
|
||||
default:
|
||||
type: boolean
|
||||
filter_chain:
|
||||
input_filters:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
|
|
@ -93,6 +94,14 @@ properties:
|
|||
required:
|
||||
- id
|
||||
- description
|
||||
input_filters:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
output_filters:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- type
|
||||
|
|
@ -173,15 +182,26 @@ properties:
|
|||
provider_interface:
|
||||
type: string
|
||||
enum:
|
||||
- arch
|
||||
- plano
|
||||
- claude
|
||||
- deepseek
|
||||
- groq
|
||||
- mistral
|
||||
- openai
|
||||
- xiaomi
|
||||
- gemini
|
||||
- chatgpt
|
||||
- digitalocean
|
||||
- vercel
|
||||
- openrouter
|
||||
headers:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: "Additional headers to send with upstream requests (e.g., ChatGPT-Account-Id, originator)."
|
||||
routing_preferences:
|
||||
type: array
|
||||
description: "[DEPRECATED] Inline routing_preferences under a model_provider are auto-migrated to the top-level routing_preferences list by the config generator. New configs should declare routing_preferences at the top level with an explicit models: [...] list. See docs/routing-api.md."
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -220,15 +240,26 @@ properties:
|
|||
provider_interface:
|
||||
type: string
|
||||
enum:
|
||||
- arch
|
||||
- plano
|
||||
- claude
|
||||
- deepseek
|
||||
- groq
|
||||
- mistral
|
||||
- openai
|
||||
- xiaomi
|
||||
- gemini
|
||||
- chatgpt
|
||||
- digitalocean
|
||||
- vercel
|
||||
- openrouter
|
||||
headers:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: "Additional headers to send with upstream requests (e.g., ChatGPT-Account-Id, originator)."
|
||||
routing_preferences:
|
||||
type: array
|
||||
description: "[DEPRECATED] Inline routing_preferences under an llm_provider are auto-migrated to the top-level routing_preferences list by the config generator. New configs should declare routing_preferences at the top level with an explicit models: [...] list. See docs/routing-api.md."
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -265,12 +296,24 @@ properties:
|
|||
type: boolean
|
||||
use_agent_orchestrator:
|
||||
type: boolean
|
||||
disable_signals:
|
||||
type: boolean
|
||||
description: "Disable agentic signal analysis (frustration, repetition, escalation, etc.) on LLM responses to save CPU. Default false."
|
||||
upstream_connect_timeout:
|
||||
type: string
|
||||
description: "Connect timeout for upstream provider clusters (e.g., '5s', '10s'). Default is '5s'."
|
||||
upstream_tls_ca_path:
|
||||
type: string
|
||||
description: "Path to the trusted CA bundle for upstream TLS verification. Default is '/etc/ssl/certs/ca-certificates.crt'."
|
||||
llm_routing_model:
|
||||
type: string
|
||||
description: "Model name for the LLM router (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
|
||||
agent_orchestration_model:
|
||||
type: string
|
||||
description: "Model name for the agent orchestrator (e.g., 'Plano-Orchestrator'). Must match a model in model_providers."
|
||||
orchestrator_model_context_length:
|
||||
type: integer
|
||||
description: "Maximum token length for the orchestrator/routing model context window. Default is 8192."
|
||||
system_prompt:
|
||||
type: string
|
||||
prompt_targets:
|
||||
|
|
@ -415,6 +458,34 @@ properties:
|
|||
type: string
|
||||
model:
|
||||
type: string
|
||||
session_ttl_seconds:
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: TTL in seconds for session-pinned routing cache entries. Default 600 (10 minutes).
|
||||
session_max_entries:
|
||||
type: integer
|
||||
minimum: 1
|
||||
maximum: 10000
|
||||
description: Maximum number of session-pinned routing cache entries. Default 10000.
|
||||
session_cache:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum:
|
||||
- memory
|
||||
- redis
|
||||
default: memory
|
||||
description: Session cache backend. "memory" (default) is in-process; "redis" is shared across replicas.
|
||||
url:
|
||||
type: string
|
||||
description: Redis URL, e.g. redis://localhost:6379. Required when type is redis.
|
||||
tenant_header:
|
||||
type: string
|
||||
description: >
|
||||
Optional HTTP header name whose value is used as a tenant prefix in the cache key.
|
||||
When set, keys are scoped as plano:affinity:{tenant_id}:{session_id}.
|
||||
additionalProperties: false
|
||||
additionalProperties: false
|
||||
state_storage:
|
||||
type: object
|
||||
|
|
@ -464,6 +535,88 @@ properties:
|
|||
additionalProperties: false
|
||||
required:
|
||||
- jailbreak
|
||||
routing_preferences:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description:
|
||||
type: string
|
||||
models:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
minItems: 1
|
||||
selection_policy:
|
||||
type: object
|
||||
properties:
|
||||
prefer:
|
||||
type: string
|
||||
enum:
|
||||
- cheapest
|
||||
- fastest
|
||||
- none
|
||||
additionalProperties: false
|
||||
required:
|
||||
- prefer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- name
|
||||
- description
|
||||
- models
|
||||
|
||||
model_metrics_sources:
|
||||
type: array
|
||||
items:
|
||||
oneOf:
|
||||
- type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: cost
|
||||
provider:
|
||||
type: string
|
||||
enum:
|
||||
- digitalocean
|
||||
refresh_interval:
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "Refresh interval in seconds"
|
||||
model_aliases:
|
||||
type: object
|
||||
description: "Map DO catalog keys (lowercase(creator)/model_id) to Plano model names used in routing_preferences. Example: 'openai/openai-gpt-oss-120b: openai/gpt-4o'"
|
||||
additionalProperties:
|
||||
type: string
|
||||
required:
|
||||
- type
|
||||
- provider
|
||||
additionalProperties: false
|
||||
- type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
const: latency
|
||||
provider:
|
||||
type: string
|
||||
enum:
|
||||
- prometheus
|
||||
url:
|
||||
type: string
|
||||
query:
|
||||
type: string
|
||||
refresh_interval:
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "Refresh interval in seconds"
|
||||
required:
|
||||
- type
|
||||
- provider
|
||||
- url
|
||||
- query
|
||||
additionalProperties: false
|
||||
|
||||
additionalProperties: false
|
||||
required:
|
||||
- version
|
||||
|
|
|
|||
|
|
@ -1,14 +1,33 @@
|
|||
[supervisord]
|
||||
nodaemon=true
|
||||
pidfile=/var/run/supervisord.pid
|
||||
|
||||
[program:config_generator]
|
||||
command=/bin/sh -c "\
|
||||
uv run python -m planoai.config_generator && \
|
||||
envsubst < /app/plano_config_rendered.yaml > /app/plano_config_rendered.env_sub.yaml && \
|
||||
envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && \
|
||||
touch /tmp/config_ready || \
|
||||
(echo 'Config generation failed, shutting down'; kill -15 $(cat /var/run/supervisord.pid))"
|
||||
priority=10
|
||||
autorestart=false
|
||||
startsecs=0
|
||||
stdout_logfile=/dev/stdout
|
||||
redirect_stderr=true
|
||||
stdout_logfile_maxbytes=0
|
||||
stderr_logfile_maxbytes=0
|
||||
|
||||
[program:brightstaff]
|
||||
command=sh -c "\
|
||||
envsubst < /app/plano_config_rendered.yaml > /app/plano_config_rendered.env_sub.yaml && \
|
||||
while [ ! -f /tmp/config_ready ]; do echo '[brightstaff] Waiting for config generation...'; sleep 0.5; done && \
|
||||
RUST_LOG=${LOG_LEVEL:-info} \
|
||||
PLANO_CONFIG_PATH_RENDERED=/app/plano_config_rendered.env_sub.yaml \
|
||||
/app/brightstaff 2>&1 | \
|
||||
tee /var/log/brightstaff.log | \
|
||||
while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
|
||||
while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done; \
|
||||
echo '[brightstaff] Process exited, shutting down'; kill -15 $(cat /var/run/supervisord.pid)"
|
||||
priority=20
|
||||
autorestart=false
|
||||
stdout_logfile=/dev/stdout
|
||||
redirect_stderr=true
|
||||
stdout_logfile_maxbytes=0
|
||||
|
|
@ -16,13 +35,15 @@ stderr_logfile_maxbytes=0
|
|||
|
||||
[program:envoy]
|
||||
command=/bin/sh -c "\
|
||||
uv run python -m planoai.config_generator && \
|
||||
envsubst < /etc/envoy/envoy.yaml > /etc/envoy.env_sub.yaml && \
|
||||
envoy -c /etc/envoy.env_sub.yaml \
|
||||
--component-log-level wasm:${LOG_LEVEL:-info} \
|
||||
--log-format '[%%Y-%%m-%%d %%T.%%e][%%l] %%v' 2>&1 | \
|
||||
while [ ! -f /tmp/config_ready ]; do echo '[plano_logs] Waiting for config generation...'; sleep 0.5; done && \
|
||||
envoy -c /etc/envoy.env_sub.yaml \
|
||||
--component-log-level wasm:${LOG_LEVEL:-info} \
|
||||
--log-format '[%%Y-%%m-%%d %%T.%%e][%%l] %%v' 2>&1 | \
|
||||
tee /var/log/envoy.log | \
|
||||
while IFS= read -r line; do echo '[plano_logs]' \"$line\"; done"
|
||||
while IFS= read -r line; do echo '[plano_logs]' \"$line\"; done; \
|
||||
echo '[plano_logs] Process exited, shutting down'; kill -15 $(cat /var/run/supervisord.pid)"
|
||||
priority=20
|
||||
autorestart=false
|
||||
stdout_logfile=/dev/stdout
|
||||
redirect_stderr=true
|
||||
stdout_logfile_maxbytes=0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,16 @@
|
|||
#!/bin/bash
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
CLI_DIR="$REPO_ROOT/cli"
|
||||
|
||||
# Use uv run if available and cli/ has a pyproject.toml, otherwise fall back to bare python
|
||||
if command -v uv &> /dev/null && [ -f "$CLI_DIR/pyproject.toml" ]; then
|
||||
PYTHON_CMD="uv run --directory $CLI_DIR python"
|
||||
else
|
||||
PYTHON_CMD="python"
|
||||
fi
|
||||
|
||||
failed_files=()
|
||||
|
||||
for file in $(find . -name config.yaml -o -name plano_config_full_reference.yaml); do
|
||||
|
|
@ -14,7 +24,7 @@ for file in $(find . -name config.yaml -o -name plano_config_full_reference.yaml
|
|||
ENVOY_CONFIG_TEMPLATE_FILE="envoy.template.yaml" \
|
||||
PLANO_CONFIG_FILE_RENDERED="$rendered_file" \
|
||||
ENVOY_CONFIG_FILE_RENDERED="/dev/null" \
|
||||
python -m planoai.config_generator 2>&1 > /dev/null
|
||||
$PYTHON_CMD -m planoai.config_generator 2>&1 > /dev/null
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Validation failed for $file"
|
||||
|
|
|
|||
2135
crates/Cargo.lock
generated
2135
crates/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -3,6 +3,18 @@ name = "brightstaff"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["jemalloc"]
|
||||
jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"]
|
||||
|
||||
[[bin]]
|
||||
name = "brightstaff"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "signals_replay"
|
||||
path = "src/bin/signals_replay.rs"
|
||||
|
||||
[dependencies]
|
||||
async-openai = "0.30.1"
|
||||
async-trait = "0.1"
|
||||
|
|
@ -26,6 +38,12 @@ opentelemetry-stdout = "0.31"
|
|||
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
|
||||
pretty_assertions = "1.4.1"
|
||||
rand = "0.9.2"
|
||||
regex = "1.10"
|
||||
lru = "0.12"
|
||||
metrics = "0.23"
|
||||
metrics-exporter-prometheus = { version = "0.15", default-features = false, features = ["http-listener"] }
|
||||
metrics-process = "2.1"
|
||||
redis = { version = "0.27", features = ["tokio-comp"] }
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
|
|
@ -33,6 +51,8 @@ serde_with = "3.13.0"
|
|||
strsim = "0.11"
|
||||
serde_yaml = "0.9.34"
|
||||
thiserror = "2.0.12"
|
||||
tikv-jemallocator = { version = "0.6", optional = true }
|
||||
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true }
|
||||
tokio = { version = "1.44.2", features = ["full"] }
|
||||
tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] }
|
||||
tokio-stream = "0.1"
|
||||
|
|
|
|||
30
crates/brightstaff/src/app_state.rs
Normal file
30
crates/brightstaff/src/app_state.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common::configuration::{Agent, FilterPipeline, Listener, ModelAlias, SpanAttributes};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::state::StateStorage;
|
||||
|
||||
/// Shared application state bundled into a single Arc-wrapped struct.
|
||||
///
|
||||
/// Instead of cloning 8+ individual `Arc`s per connection, a single
|
||||
/// `Arc<AppState>` is cloned once and passed to the request handler.
|
||||
pub struct AppState {
|
||||
pub orchestrator_service: Arc<OrchestratorService>,
|
||||
pub model_aliases: Option<HashMap<String, ModelAlias>>,
|
||||
pub llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
pub agents_list: Option<Vec<Agent>>,
|
||||
pub listeners: Vec<Listener>,
|
||||
pub state_storage: Option<Arc<dyn StateStorage>>,
|
||||
pub llm_provider_url: String,
|
||||
pub span_attributes: Option<SpanAttributes>,
|
||||
/// Shared HTTP client for upstream LLM requests (connection pooling / keep-alive).
|
||||
pub http_client: reqwest::Client,
|
||||
pub filter_pipeline: Arc<FilterPipeline>,
|
||||
/// When false, agentic signal analysis is skipped on LLM responses to save CPU.
|
||||
/// Controlled by `overrides.disable_signals` in plano config.
|
||||
pub signals_enabled: bool,
|
||||
}
|
||||
175
crates/brightstaff/src/bin/signals_replay.rs
Normal file
175
crates/brightstaff/src/bin/signals_replay.rs
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
//! `signals-replay` — batch driver for the `brightstaff` signal analyzer.
|
||||
//!
|
||||
//! Reads JSONL conversations from stdin (one per line) and emits matching
|
||||
//! JSONL reports on stdout, one per input conversation, in the same order.
|
||||
//!
|
||||
//! Input shape (per line):
|
||||
//! ```json
|
||||
//! {"id": "convo-42", "messages": [{"from": "human", "value": "..."}, ...]}
|
||||
//! ```
|
||||
//!
|
||||
//! Output shape (per line, success):
|
||||
//! ```json
|
||||
//! {"id": "convo-42", "report": { ...python-compatible SignalReport dict... }}
|
||||
//! ```
|
||||
//!
|
||||
//! On per-line failure (parse / analyzer error), emits:
|
||||
//! ```json
|
||||
//! {"id": "convo-42", "error": "..."}
|
||||
//! ```
|
||||
//!
|
||||
//! The output report dict is shaped to match the Python reference's
|
||||
//! `SignalReport.to_dict()` byte-for-byte so the parity comparator can do a
|
||||
//! direct structural diff.
|
||||
|
||||
use std::io::{self, BufRead, BufWriter, Write};
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Map, Value};
|
||||
|
||||
use brightstaff::signals::{SignalAnalyzer, SignalGroup, SignalReport};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct InputLine {
|
||||
id: Value,
|
||||
messages: Vec<MessageRow>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MessageRow {
|
||||
#[serde(default)]
|
||||
from: String,
|
||||
#[serde(default)]
|
||||
value: String,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let stdin = io::stdin();
|
||||
let stdout = io::stdout();
|
||||
let mut out = BufWriter::new(stdout.lock());
|
||||
let analyzer = SignalAnalyzer::default();
|
||||
|
||||
for line in stdin.lock().lines() {
|
||||
let line = match line {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
eprintln!("read error: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let result = process_line(&analyzer, trimmed);
|
||||
// Always emit one line per input line so id ordering stays aligned.
|
||||
if let Err(e) = writeln!(out, "{result}") {
|
||||
eprintln!("write error: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
// Flush periodically isn't strictly needed — BufWriter handles it,
|
||||
// and the parent process reads the whole stream when we're done.
|
||||
}
|
||||
let _ = out.flush();
|
||||
}
|
||||
|
||||
fn process_line(analyzer: &SignalAnalyzer, line: &str) -> Value {
|
||||
let parsed: InputLine = match serde_json::from_str(line) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return json!({
|
||||
"id": Value::Null,
|
||||
"error": format!("input parse: {e}"),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let id = parsed.id.clone();
|
||||
|
||||
let view: Vec<brightstaff::signals::analyzer::ShareGptMessage<'_>> = parsed
|
||||
.messages
|
||||
.iter()
|
||||
.map(|m| brightstaff::signals::analyzer::ShareGptMessage {
|
||||
from: m.from.as_str(),
|
||||
value: m.value.as_str(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let report = analyzer.analyze_sharegpt(&view);
|
||||
let report_dict = report_to_python_dict(&report);
|
||||
json!({
|
||||
"id": id,
|
||||
"report": report_dict,
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert a `SignalReport` into the Python reference's `to_dict()` shape.
|
||||
///
|
||||
/// Ordering of category keys in each layer dict follows the Python source
|
||||
/// exactly so even string-equality comparisons behave deterministically.
|
||||
fn report_to_python_dict(r: &SignalReport) -> Value {
|
||||
let mut interaction = Map::new();
|
||||
interaction.insert(
|
||||
"misalignment".to_string(),
|
||||
signal_group_to_python(&r.interaction.misalignment),
|
||||
);
|
||||
interaction.insert(
|
||||
"stagnation".to_string(),
|
||||
signal_group_to_python(&r.interaction.stagnation),
|
||||
);
|
||||
interaction.insert(
|
||||
"disengagement".to_string(),
|
||||
signal_group_to_python(&r.interaction.disengagement),
|
||||
);
|
||||
interaction.insert(
|
||||
"satisfaction".to_string(),
|
||||
signal_group_to_python(&r.interaction.satisfaction),
|
||||
);
|
||||
|
||||
let mut execution = Map::new();
|
||||
execution.insert(
|
||||
"failure".to_string(),
|
||||
signal_group_to_python(&r.execution.failure),
|
||||
);
|
||||
execution.insert(
|
||||
"loops".to_string(),
|
||||
signal_group_to_python(&r.execution.loops),
|
||||
);
|
||||
|
||||
let mut environment = Map::new();
|
||||
environment.insert(
|
||||
"exhaustion".to_string(),
|
||||
signal_group_to_python(&r.environment.exhaustion),
|
||||
);
|
||||
|
||||
json!({
|
||||
"interaction_signals": Value::Object(interaction),
|
||||
"execution_signals": Value::Object(execution),
|
||||
"environment_signals": Value::Object(environment),
|
||||
"overall_quality": r.overall_quality.as_str(),
|
||||
"summary": r.summary,
|
||||
})
|
||||
}
|
||||
|
||||
fn signal_group_to_python(g: &SignalGroup) -> Value {
|
||||
let signals: Vec<Value> = g
|
||||
.signals
|
||||
.iter()
|
||||
.map(|s| {
|
||||
json!({
|
||||
"signal_type": s.signal_type.as_str(),
|
||||
"message_index": s.message_index,
|
||||
"snippet": s.snippet,
|
||||
"confidence": s.confidence,
|
||||
"metadata": s.metadata,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({
|
||||
"category": g.category,
|
||||
"count": g.count,
|
||||
"severity": g.severity,
|
||||
"signals": signals,
|
||||
})
|
||||
}
|
||||
41
crates/brightstaff/src/handlers/agents/errors.rs
Normal file
41
crates/brightstaff/src/handlers/agents/errors.rs
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use hyper::Response;
|
||||
use serde_json::json;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::handlers::response::ResponseHandler;
|
||||
|
||||
/// Build a JSON error response from an `AgentFilterChainError`, logging the
|
||||
/// full error chain along the way.
|
||||
///
|
||||
/// Returns `Ok(Response)` so it can be used directly as a handler return value.
|
||||
pub fn build_error_chain_response<E: std::error::Error>(
|
||||
err: &E,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let mut error_chain = Vec::new();
|
||||
let mut current: &dyn std::error::Error = err;
|
||||
loop {
|
||||
error_chain.push(current.to_string());
|
||||
match current.source() {
|
||||
Some(source) => current = source,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
warn!(error_chain = ?error_chain, "agent chat error chain");
|
||||
warn!(root_error = ?err, "root error");
|
||||
|
||||
let error_json = json!({
|
||||
"error": {
|
||||
"type": "AgentFilterChainError",
|
||||
"message": err.to_string(),
|
||||
"error_chain": error_chain,
|
||||
"debug_info": format!("{:?}", err)
|
||||
}
|
||||
});
|
||||
|
||||
info!(error = %error_json, "structured error info");
|
||||
|
||||
Ok(ResponseHandler::create_json_error_response(&error_json))
|
||||
}
|
||||
5
crates/brightstaff/src/handlers/agents/mod.rs
Normal file
5
crates/brightstaff/src/handlers/agents/mod.rs
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
pub mod errors;
|
||||
pub mod jsonrpc;
|
||||
pub mod orchestrator;
|
||||
pub mod pipeline;
|
||||
pub mod selector;
|
||||
|
|
@ -2,63 +2,56 @@ use std::sync::Arc;
|
|||
use std::time::Instant;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::configuration::SpanAttributes;
|
||||
use common::errors::BrightStaffError;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use hermesllm::apis::OpenAIMessage;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::providers::request::ProviderRequest;
|
||||
use hermesllm::ProviderRequestType;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper::{Request, Response};
|
||||
use opentelemetry::trace::get_active_span;
|
||||
use serde::ser::Error as SerError;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use super::response_handler::ResponseHandler;
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
use super::errors::build_error_chain_response;
|
||||
use super::pipeline::{PipelineError, PipelineProcessor};
|
||||
use super::selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::app_state::AppState;
|
||||
use crate::handlers::extract_request_id;
|
||||
use crate::handlers::response::ResponseHandler;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AgentFilterChainError {
|
||||
#[error("Forwarded error: {0}")]
|
||||
Brightstaff(#[from] BrightStaffError),
|
||||
#[error("Agent selection error: {0}")]
|
||||
Selection(#[from] AgentSelectionError),
|
||||
#[error("Pipeline processing error: {0}")]
|
||||
Pipeline(#[from] PipelineError),
|
||||
#[error("Response handling error: {0}")]
|
||||
Response(#[from] common::errors::BrightStaffError),
|
||||
#[error("Request parsing error: {0}")]
|
||||
RequestParsing(#[from] serde_json::Error),
|
||||
RequestParsing(String),
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(#[from] hyper::Error),
|
||||
#[error("Unsupported endpoint: {0}")]
|
||||
UnsupportedEndpoint(String),
|
||||
#[error("No agents configured")]
|
||||
NoAgentsConfigured,
|
||||
#[error("Agent '{0}' not found in configuration")]
|
||||
AgentNotFound(String),
|
||||
#[error("No messages in conversation history")]
|
||||
EmptyHistory,
|
||||
#[error("Agent chain completed without producing a response")]
|
||||
IncompleteChain,
|
||||
}
|
||||
|
||||
pub async fn agent_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
_: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
span_attributes: Arc<Option<SpanAttributes>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_id = extract_request_id(&request);
|
||||
let custom_attrs =
|
||||
collect_custom_trace_attributes(request.headers(), span_attributes.as_ref().as_ref());
|
||||
// Extract request_id from headers or generate a new one
|
||||
let request_id: String = match request
|
||||
.headers()
|
||||
.get(common::consts::REQUEST_ID_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
{
|
||||
Some(id) => id,
|
||||
None => uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
collect_custom_trace_attributes(request.headers(), state.span_attributes.as_ref());
|
||||
|
||||
// Create a span with request_id that will be included in all log lines
|
||||
let request_span = info_span!(
|
||||
|
|
@ -74,17 +67,7 @@ pub async fn agent_chat(
|
|||
// Set service name for orchestrator operations
|
||||
set_service_name(operation_component::ORCHESTRATOR);
|
||||
|
||||
match handle_agent_chat_inner(
|
||||
request,
|
||||
orchestrator_service,
|
||||
agents_list,
|
||||
listeners,
|
||||
llm_providers,
|
||||
request_id,
|
||||
custom_attrs,
|
||||
)
|
||||
.await
|
||||
{
|
||||
match handle_agent_chat_inner(request, state, request_id, custom_attrs).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
// Check if this is a client error from the pipeline that should be cascaded
|
||||
|
|
@ -101,7 +84,6 @@ pub async fn agent_chat(
|
|||
"client error from agent"
|
||||
);
|
||||
|
||||
// Create error response with the original status code and body
|
||||
let error_json = serde_json::json!({
|
||||
"error": "ClientError",
|
||||
"agent": agent,
|
||||
|
|
@ -109,52 +91,19 @@ pub async fn agent_chat(
|
|||
"agent_response": body
|
||||
});
|
||||
|
||||
let status_code = hyper::StatusCode::from_u16(*status)
|
||||
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
let json_string = error_json.to_string();
|
||||
return Ok(BrightStaffError::ForwardedError {
|
||||
status_code,
|
||||
message: json_string,
|
||||
}
|
||||
.into_response());
|
||||
let mut response =
|
||||
Response::new(ResponseHandler::create_full_body(json_string));
|
||||
*response.status_mut() = hyper::StatusCode::from_u16(*status)
|
||||
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
// Print detailed error information with full error chain for other errors
|
||||
let mut error_chain = Vec::new();
|
||||
let mut current_error: &dyn std::error::Error = &err;
|
||||
|
||||
// Collect the full error chain
|
||||
loop {
|
||||
error_chain.push(current_error.to_string());
|
||||
match current_error.source() {
|
||||
Some(source) => current_error = source,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
// Log the complete error chain
|
||||
warn!(error_chain = ?error_chain, "agent chat error chain");
|
||||
warn!(root_error = ?err, "root error");
|
||||
|
||||
// Create structured error response as JSON
|
||||
let error_json = serde_json::json!({
|
||||
"error": {
|
||||
"type": "AgentFilterChainError",
|
||||
"message": err.to_string(),
|
||||
"error_chain": error_chain,
|
||||
"debug_info": format!("{:?}", err)
|
||||
}
|
||||
});
|
||||
|
||||
// Log the error for debugging
|
||||
info!(error = %error_json, "structured error info");
|
||||
|
||||
Ok(BrightStaffError::ForwardedError {
|
||||
status_code: StatusCode::BAD_REQUEST,
|
||||
message: error_json.to_string(),
|
||||
}
|
||||
.into_response())
|
||||
build_error_chain_response(&err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -162,19 +111,22 @@ pub async fn agent_chat(
|
|||
.await
|
||||
}
|
||||
|
||||
async fn handle_agent_chat_inner(
|
||||
/// Parsed and validated agent request data.
|
||||
struct AgentRequest {
|
||||
client_request: ProviderRequestType,
|
||||
messages: Vec<OpenAIMessage>,
|
||||
request_headers: hyper::HeaderMap,
|
||||
request_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Parse the incoming HTTP request, resolve the listener, and extract messages.
|
||||
async fn parse_agent_request(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
request_id: String,
|
||||
custom_attrs: std::collections::HashMap<String, String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
// Initialize services
|
||||
let agent_selector = AgentSelector::new(orchestrator_service);
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
state: &AppState,
|
||||
request_id: &str,
|
||||
custom_attrs: &std::collections::HashMap<String, String>,
|
||||
) -> Result<(AgentRequest, common::configuration::Listener, AgentSelector), AgentFilterChainError> {
|
||||
let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service));
|
||||
|
||||
// Extract listener name from headers
|
||||
let listener_name = request
|
||||
|
|
@ -183,16 +135,11 @@ async fn handle_agent_chat_inner(
|
|||
.and_then(|name| name.to_str().ok());
|
||||
|
||||
// Find the appropriate listener
|
||||
let listener: common::configuration::Listener = {
|
||||
let listeners = listeners.read().await;
|
||||
agent_selector
|
||||
.find_listener(listener_name, &listeners)
|
||||
.await?
|
||||
};
|
||||
let listener = agent_selector.find_listener(listener_name, &state.listeners)?;
|
||||
|
||||
get_active_span(|span| {
|
||||
span.update_name(listener.name.to_string());
|
||||
for (key, value) in &custom_attrs {
|
||||
for (key, value) in custom_attrs {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||
}
|
||||
});
|
||||
|
|
@ -200,24 +147,20 @@ async fn handle_agent_chat_inner(
|
|||
info!(listener = %listener.name, "handling request");
|
||||
|
||||
// Parse request body
|
||||
let request_path = request
|
||||
.uri()
|
||||
.path()
|
||||
.to_string()
|
||||
let full_path = request.uri().path().to_string();
|
||||
let request_path = full_path
|
||||
.strip_prefix("/agents")
|
||||
.unwrap()
|
||||
.unwrap_or(&full_path)
|
||||
.to_string();
|
||||
|
||||
let request_headers = {
|
||||
let mut headers = request.headers().clone();
|
||||
headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER);
|
||||
|
||||
// Set the request_id in headers if not already present
|
||||
if !headers.contains_key(common::consts::REQUEST_ID_HEADER) {
|
||||
headers.insert(
|
||||
common::consts::REQUEST_ID_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&request_id).unwrap(),
|
||||
);
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
|
||||
headers.insert(common::consts::REQUEST_ID_HEADER, val);
|
||||
}
|
||||
}
|
||||
|
||||
headers
|
||||
|
|
@ -230,63 +173,62 @@ async fn handle_agent_chat_inner(
|
|||
"received request body"
|
||||
);
|
||||
|
||||
// Determine the API type from the endpoint
|
||||
let api_type =
|
||||
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
|
||||
let err_msg = format!("Unsupported endpoint: {}", request_path);
|
||||
warn!("{}", err_msg);
|
||||
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
|
||||
warn!(path = %request_path, "unsupported endpoint");
|
||||
AgentFilterChainError::UnsupportedEndpoint(request_path.clone())
|
||||
})?;
|
||||
|
||||
let mut client_request =
|
||||
match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
return Err(AgentFilterChainError::RequestParsing(
|
||||
serde_json::Error::custom(err_msg),
|
||||
));
|
||||
}
|
||||
};
|
||||
let client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type))
|
||||
.map_err(|err| {
|
||||
warn!(error = %err, "failed to parse request as ProviderRequestType");
|
||||
AgentFilterChainError::RequestParsing(format!("Failed to parse request: {}", err))
|
||||
})?;
|
||||
|
||||
// If model is not specified in the request, resolve from default provider
|
||||
if client_request.model().is_empty() {
|
||||
match llm_providers.read().await.default() {
|
||||
Some(default_provider) => {
|
||||
let default_model = default_provider.name.clone();
|
||||
info!(default_model = %default_model, "no model specified in request, using default provider");
|
||||
client_request.set_model(default_model);
|
||||
}
|
||||
None => {
|
||||
let err_msg = "No model specified in request and no default provider configured";
|
||||
warn!("{}", err_msg);
|
||||
return Ok(BrightStaffError::NoModelSpecified.into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let message: Vec<OpenAIMessage> = client_request.get_messages();
|
||||
let messages: Vec<OpenAIMessage> = client_request.get_messages();
|
||||
|
||||
let request_id = request_headers
|
||||
.get(common::consts::REQUEST_ID_HEADER)
|
||||
.and_then(|val| val.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Create agent map for pipeline processing and agent selection
|
||||
let agent_map = {
|
||||
let agents = agents_list.read().await;
|
||||
let agents = agents.as_ref().unwrap();
|
||||
agent_selector.create_agent_map(agents)
|
||||
};
|
||||
Ok((
|
||||
AgentRequest {
|
||||
client_request,
|
||||
messages,
|
||||
request_headers,
|
||||
request_id,
|
||||
},
|
||||
listener,
|
||||
agent_selector,
|
||||
))
|
||||
}
|
||||
|
||||
/// Select agents via the orchestrator model and record selection metrics.
|
||||
async fn select_and_build_agent_map(
|
||||
agent_selector: &AgentSelector,
|
||||
state: &AppState,
|
||||
messages: &[OpenAIMessage],
|
||||
listener: &common::configuration::Listener,
|
||||
request_id: Option<String>,
|
||||
) -> Result<
|
||||
(
|
||||
Vec<common::configuration::AgentFilterChain>,
|
||||
std::collections::HashMap<String, common::configuration::Agent>,
|
||||
),
|
||||
AgentFilterChainError,
|
||||
> {
|
||||
let agents = state
|
||||
.agents_list
|
||||
.as_ref()
|
||||
.ok_or(AgentFilterChainError::NoAgentsConfigured)?;
|
||||
let agent_map = agent_selector.create_agent_map(agents);
|
||||
|
||||
// Select appropriate agents using arch orchestrator llm model
|
||||
let selection_start = Instant::now();
|
||||
let selected_agents = agent_selector
|
||||
.select_agents(&message, &listener, request_id.clone())
|
||||
.select_agents(messages, listener, request_id)
|
||||
.await?;
|
||||
|
||||
// Record selection attributes on the current orchestrator span
|
||||
let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0;
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
|
|
@ -316,12 +258,25 @@ async fn handle_agent_chat_inner(
|
|||
"selected agents for execution"
|
||||
);
|
||||
|
||||
// Execute agents sequentially, passing output from one to the next
|
||||
let mut current_messages = message.clone();
|
||||
Ok((selected_agents, agent_map))
|
||||
}
|
||||
|
||||
/// Execute the agent chain: run each selected agent sequentially, streaming
|
||||
/// the final agent's response back to the client.
|
||||
async fn execute_agent_chain(
|
||||
selected_agents: &[common::configuration::AgentFilterChain],
|
||||
agent_map: &std::collections::HashMap<String, common::configuration::Agent>,
|
||||
client_request: ProviderRequestType,
|
||||
messages: Vec<OpenAIMessage>,
|
||||
request_headers: &hyper::HeaderMap,
|
||||
custom_attrs: &std::collections::HashMap<String, String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let response_handler = ResponseHandler::new();
|
||||
let mut current_messages = messages;
|
||||
let agent_count = selected_agents.len();
|
||||
|
||||
for (agent_index, selected_agent) in selected_agents.iter().enumerate() {
|
||||
// Get agent name
|
||||
let agent_name = selected_agent.id.clone();
|
||||
let is_last_agent = agent_index == agent_count - 1;
|
||||
|
||||
|
|
@ -332,18 +287,40 @@ async fn handle_agent_chat_inner(
|
|||
"processing agent"
|
||||
);
|
||||
|
||||
// Process the filter chain
|
||||
let chat_history = pipeline_processor
|
||||
.process_filter_chain(
|
||||
¤t_messages,
|
||||
selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
)
|
||||
.await?;
|
||||
let chat_history = if selected_agent
|
||||
.input_filters
|
||||
.as_ref()
|
||||
.map(|f| !f.is_empty())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let filter_body = serde_json::json!({
|
||||
"model": client_request.model(),
|
||||
"messages": current_messages,
|
||||
});
|
||||
let filter_bytes =
|
||||
serde_json::to_vec(&filter_body).map_err(PipelineError::ParseError)?;
|
||||
|
||||
// Get agent details and invoke
|
||||
let agent = agent_map.get(&agent_name).unwrap();
|
||||
let filtered_bytes = pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&filter_bytes,
|
||||
selected_agent,
|
||||
agent_map,
|
||||
request_headers,
|
||||
"/v1/chat/completions",
|
||||
)
|
||||
.await?;
|
||||
|
||||
let filtered_body: serde_json::Value =
|
||||
serde_json::from_slice(&filtered_bytes).map_err(PipelineError::ParseError)?;
|
||||
serde_json::from_value(filtered_body["messages"].clone())
|
||||
.map_err(PipelineError::ParseError)?
|
||||
} else {
|
||||
current_messages.clone()
|
||||
};
|
||||
|
||||
let agent = agent_map
|
||||
.get(&agent_name)
|
||||
.ok_or_else(|| AgentFilterChainError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
debug!(agent = %agent_name, "invoking agent");
|
||||
|
||||
|
|
@ -357,7 +334,7 @@ async fn handle_agent_chat_inner(
|
|||
set_service_name(operation_component::AGENT);
|
||||
get_active_span(|span| {
|
||||
span.update_name(format!("{} /v1/chat/completions", agent_name));
|
||||
for (key, value) in &custom_attrs {
|
||||
for (key, value) in custom_attrs {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||
}
|
||||
});
|
||||
|
|
@ -367,28 +344,25 @@ async fn handle_agent_chat_inner(
|
|||
&chat_history,
|
||||
client_request.clone(),
|
||||
agent,
|
||||
&request_headers,
|
||||
request_headers,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.instrument(agent_span.clone())
|
||||
.await?;
|
||||
|
||||
// If this is the last agent, return the streaming response
|
||||
if is_last_agent {
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
"completed agent chain, returning response"
|
||||
);
|
||||
// Capture the orchestrator span (parent of the agent span) so it
|
||||
// stays open for the full streaming duration alongside the agent span.
|
||||
let orchestrator_span = tracing::Span::current();
|
||||
return async {
|
||||
response_handler
|
||||
.create_streaming_response(
|
||||
llm_response,
|
||||
tracing::Span::current(), // agent span (inner)
|
||||
orchestrator_span, // orchestrator span (outer)
|
||||
tracing::Span::current(),
|
||||
orchestrator_span,
|
||||
)
|
||||
.await
|
||||
.map_err(AgentFilterChainError::from)
|
||||
|
|
@ -397,7 +371,6 @@ async fn handle_agent_chat_inner(
|
|||
.await;
|
||||
}
|
||||
|
||||
// For intermediate agents, collect the full response and pass to next agent
|
||||
debug!(agent = %agent_name, "collecting response from intermediate agent");
|
||||
let response_text = async { response_handler.collect_full_response(llm_response).await }
|
||||
.instrument(agent_span)
|
||||
|
|
@ -409,11 +382,11 @@ async fn handle_agent_chat_inner(
|
|||
"agent completed, passing response to next agent"
|
||||
);
|
||||
|
||||
// remove last message and add new one at the end
|
||||
let last_message = current_messages.pop().unwrap();
|
||||
let Some(last_message) = current_messages.pop() else {
|
||||
warn!(agent = %agent_name, "no messages in conversation history");
|
||||
return Err(AgentFilterChainError::EmptyHistory);
|
||||
};
|
||||
|
||||
// Create a new message with the agent's response as assistant message
|
||||
// and add it to the conversation history
|
||||
current_messages.push(OpenAIMessage {
|
||||
role: hermesllm::apis::openai::Role::Assistant,
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)),
|
||||
|
|
@ -425,6 +398,34 @@ async fn handle_agent_chat_inner(
|
|||
current_messages.push(last_message);
|
||||
}
|
||||
|
||||
// This should never be reached since we return in the last agent iteration
|
||||
unreachable!("Agent execution loop should have returned a response")
|
||||
Err(AgentFilterChainError::IncompleteChain)
|
||||
}
|
||||
|
||||
async fn handle_agent_chat_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
state: Arc<AppState>,
|
||||
request_id: String,
|
||||
custom_attrs: std::collections::HashMap<String, String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
let (agent_req, listener, agent_selector) =
|
||||
parse_agent_request(request, &state, &request_id, &custom_attrs).await?;
|
||||
|
||||
let (selected_agents, agent_map) = select_and_build_agent_map(
|
||||
&agent_selector,
|
||||
&state,
|
||||
&agent_req.messages,
|
||||
&listener,
|
||||
agent_req.request_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
execute_agent_chain(
|
||||
&selected_agents,
|
||||
&agent_map,
|
||||
agent_req.client_request,
|
||||
agent_req.messages,
|
||||
&agent_req.request_headers,
|
||||
&custom_attrs,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, AgentFilterChain};
|
||||
use common::consts::{
|
||||
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
||||
|
|
@ -11,7 +12,7 @@ use opentelemetry::global;
|
|||
use opentelemetry_http::HeaderInjector;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
use crate::handlers::jsonrpc::{
|
||||
use super::jsonrpc::{
|
||||
JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JSON_RPC_VERSION,
|
||||
MCP_INITIALIZE, MCP_INITIALIZE_NOTIFICATION, TOOL_CALL_METHOD,
|
||||
};
|
||||
|
|
@ -35,8 +36,6 @@ pub enum PipelineError {
|
|||
NoResultInResponse(String),
|
||||
#[error("No structured content in response from agent '{0}'")]
|
||||
NoStructuredContentInResponse(String),
|
||||
#[error("No messages in response from agent '{0}'")]
|
||||
NoMessagesInResponse(String),
|
||||
#[error("Client error from agent '{agent}' (HTTP {status}): {body}")]
|
||||
ClientError {
|
||||
agent: String,
|
||||
|
|
@ -79,74 +78,11 @@ impl PipelineProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
// /// Process the filter chain of agents (all except the terminal agent)
|
||||
// #[instrument(
|
||||
// skip(self, chat_history, agent_filter_chain, agent_map, request_headers),
|
||||
// fields(
|
||||
// filter_count = agent_filter_chain.filter_chain.as_ref().map(|fc| fc.len()).unwrap_or(0),
|
||||
// message_count = chat_history.len()
|
||||
// )
|
||||
// )]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn process_filter_chain(
|
||||
&mut self,
|
||||
chat_history: &[Message],
|
||||
agent_filter_chain: &AgentFilterChain,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
let mut chat_history_updated = chat_history.to_vec();
|
||||
|
||||
// If filter_chain is None or empty, proceed without filtering
|
||||
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
|
||||
Some(fc) if !fc.is_empty() => fc,
|
||||
_ => return Ok(chat_history_updated),
|
||||
};
|
||||
|
||||
for agent_name in filter_chain {
|
||||
debug!(agent = %agent_name, "processing filter agent");
|
||||
|
||||
let agent = agent_map
|
||||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
tool = %tool_name,
|
||||
url = %agent.url,
|
||||
agent_type = %agent.agent_type.as_deref().unwrap_or("mcp"),
|
||||
conversation_len = chat_history.len(),
|
||||
"executing filter"
|
||||
);
|
||||
|
||||
if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" {
|
||||
chat_history_updated = self
|
||||
.execute_mcp_filter(&chat_history_updated, agent, request_headers)
|
||||
.await?;
|
||||
} else {
|
||||
chat_history_updated = self
|
||||
.execute_http_filter(&chat_history_updated, agent, request_headers)
|
||||
.await?;
|
||||
}
|
||||
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
updated_len = chat_history_updated.len(),
|
||||
"filter completed"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(chat_history_updated)
|
||||
}
|
||||
|
||||
/// Build common MCP headers for requests
|
||||
fn build_mcp_headers(
|
||||
&self,
|
||||
/// Prepare headers shared by all agent/filter requests: removes
|
||||
/// content-length, injects trace context, sets upstream host and retry.
|
||||
fn build_agent_headers(
|
||||
request_headers: &HeaderMap,
|
||||
agent_id: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<HeaderMap, PipelineError> {
|
||||
let mut headers = request_headers.clone();
|
||||
headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
|
@ -167,24 +103,34 @@ impl PipelineProcessor {
|
|||
|
||||
headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
hyper::header::HeaderValue::from_static("3"),
|
||||
);
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Build headers for MCP requests (adds Accept, Content-Type, optional session id).
|
||||
fn build_mcp_headers(
|
||||
&self,
|
||||
request_headers: &HeaderMap,
|
||||
agent_id: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<HeaderMap, PipelineError> {
|
||||
let mut headers = Self::build_agent_headers(request_headers, agent_id)?;
|
||||
|
||||
headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
|
||||
);
|
||||
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
if let Some(sid) = session_id {
|
||||
headers.insert(
|
||||
"mcp-session-id",
|
||||
hyper::header::HeaderValue::from_str(sid).unwrap(),
|
||||
);
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(sid) {
|
||||
headers.insert("mcp-session-id", val);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
|
|
@ -261,14 +207,17 @@ impl PipelineProcessor {
|
|||
Ok(response)
|
||||
}
|
||||
|
||||
/// Build a tools/call JSON-RPC request
|
||||
fn build_tool_call_request(
|
||||
/// Build a tools/call JSON-RPC request with a full body dict and path hint.
|
||||
/// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters.
|
||||
fn build_tool_call_request_with_body(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
messages: &[Message],
|
||||
body: &serde_json::Value,
|
||||
path: &str,
|
||||
) -> Result<JsonRpcRequest, PipelineError> {
|
||||
let mut arguments = HashMap::new();
|
||||
arguments.insert("messages".to_string(), serde_json::to_value(messages)?);
|
||||
arguments.insert("body".to_string(), serde_json::to_value(body)?);
|
||||
arguments.insert("path".to_string(), serde_json::to_value(path)?);
|
||||
|
||||
let mut params = HashMap::new();
|
||||
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
|
||||
|
|
@ -282,35 +231,28 @@ impl PipelineProcessor {
|
|||
})
|
||||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
#[instrument(
|
||||
skip(self, messages, agent, request_headers),
|
||||
fields(
|
||||
agent_id = %agent.id,
|
||||
filter_name = %agent.id,
|
||||
message_count = messages.len()
|
||||
)
|
||||
)]
|
||||
async fn execute_mcp_filter(
|
||||
/// Like execute_mcp_filter_raw but passes the full raw body dict + path hint as MCP tool arguments.
|
||||
/// The MCP tool receives (body: dict, path: str) and returns the modified body dict.
|
||||
async fn execute_mcp_filter_raw(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
raw_bytes: &[u8],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
// Set service name for this filter span
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
set_service_name(operation_component::AGENT_FILTER);
|
||||
|
||||
// Update current span name to include filter name
|
||||
use opentelemetry::trace::get_active_span;
|
||||
get_active_span(|span| {
|
||||
span.update_name(format!("execute_mcp_filter ({})", agent.id));
|
||||
span.update_name(format!("execute_mcp_filter_raw ({})", agent.id));
|
||||
});
|
||||
|
||||
// Get or create MCP session
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(raw_bytes).map_err(PipelineError::ParseError)?;
|
||||
|
||||
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
|
||||
session_id.clone()
|
||||
} else {
|
||||
let session_id = self.get_new_session_id(&agent.id, request_headers).await;
|
||||
let session_id = self.get_new_session_id(&agent.id, request_headers).await?;
|
||||
self.agent_id_session_map
|
||||
.insert(agent.id.clone(), session_id.clone());
|
||||
session_id
|
||||
|
|
@ -321,11 +263,10 @@ impl PipelineProcessor {
|
|||
mcp_session_id, agent.id
|
||||
);
|
||||
|
||||
// Build JSON-RPC request
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
let json_rpc_request = self.build_tool_call_request(tool_name, messages)?;
|
||||
let json_rpc_request =
|
||||
self.build_tool_call_request_with_body(tool_name, &body, request_path)?;
|
||||
|
||||
// Build headers
|
||||
let agent_headers =
|
||||
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?;
|
||||
|
||||
|
|
@ -335,7 +276,6 @@ impl PipelineProcessor {
|
|||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
// Handle HTTP errors
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
|
|
@ -353,20 +293,12 @@ impl PipelineProcessor {
|
|||
});
|
||||
}
|
||||
|
||||
info!(
|
||||
"Response from agent {}: {}",
|
||||
agent.id,
|
||||
String::from_utf8_lossy(&response_bytes)
|
||||
);
|
||||
|
||||
// Parse SSE response
|
||||
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
|
||||
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
|
||||
let response_result = response
|
||||
.result
|
||||
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
|
||||
|
||||
// Check if error field is set in response result
|
||||
if response_result
|
||||
.get("isError")
|
||||
.and_then(|v| v.as_bool())
|
||||
|
|
@ -388,21 +320,28 @@ impl PipelineProcessor {
|
|||
});
|
||||
}
|
||||
|
||||
// Extract structured content and parse messages
|
||||
let response_json = response_result
|
||||
// FastMCP puts structured Pydantic return values in structuredContent.result,
|
||||
// but plain dicts land in content[0].text as a JSON string. Try both.
|
||||
let result = if let Some(structured) = response_result
|
||||
.get("structuredContent")
|
||||
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
|
||||
.and_then(|v| v.get("result"))
|
||||
.cloned()
|
||||
{
|
||||
structured
|
||||
} else {
|
||||
let text = response_result
|
||||
.get("content")
|
||||
.and_then(|v| v.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|v| v.get("text"))
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
|
||||
serde_json::from_str(text).map_err(PipelineError::ParseError)?
|
||||
};
|
||||
|
||||
let messages: Vec<Message> = response_json
|
||||
.get("result")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| PipelineError::NoMessagesInResponse(agent.id.clone()))?
|
||||
.iter()
|
||||
.map(|msg_value| serde_json::from_value(msg_value.clone()))
|
||||
.collect::<Result<Vec<Message>, _>>()
|
||||
.map_err(PipelineError::ParseError)?;
|
||||
|
||||
Ok(messages)
|
||||
Ok(Bytes::from(
|
||||
serde_json::to_vec(&result).map_err(PipelineError::ParseError)?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Build an initialize JSON-RPC request
|
||||
|
|
@ -464,18 +403,19 @@ impl PipelineProcessor {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_new_session_id(&self, agent_id: &str, request_headers: &HeaderMap) -> String {
|
||||
async fn get_new_session_id(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<String, PipelineError> {
|
||||
info!("initializing MCP session for agent {}", agent_id);
|
||||
|
||||
let initialize_request = self.build_initialize_request();
|
||||
let headers = self
|
||||
.build_mcp_headers(request_headers, agent_id, None)
|
||||
.expect("Failed to build headers for initialization");
|
||||
let headers = self.build_mcp_headers(request_headers, agent_id, None)?;
|
||||
|
||||
let response = self
|
||||
.send_mcp_request(&initialize_request, &headers, agent_id)
|
||||
.await
|
||||
.expect("Failed to initialize MCP session");
|
||||
.await?;
|
||||
|
||||
info!("initialize response status: {}", response.status());
|
||||
|
||||
|
|
@ -483,8 +423,13 @@ impl PipelineProcessor {
|
|||
.headers()
|
||||
.get("mcp-session-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.expect("No mcp-session-id in response")
|
||||
.to_string();
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| {
|
||||
PipelineError::NoContentInResponse(format!(
|
||||
"No mcp-session-id header in initialize response from agent {}",
|
||||
agent_id
|
||||
))
|
||||
})?;
|
||||
|
||||
info!(
|
||||
"created new MCP session for agent {}: {}",
|
||||
|
|
@ -493,88 +438,63 @@ impl PipelineProcessor {
|
|||
|
||||
// Send initialized notification
|
||||
self.send_initialized_notification(agent_id, &session_id, &headers)
|
||||
.await
|
||||
.expect("Failed to send initialized notification");
|
||||
.await?;
|
||||
|
||||
session_id
|
||||
Ok(session_id)
|
||||
}
|
||||
|
||||
/// Execute a HTTP-based filter agent
|
||||
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
|
||||
/// Used for input and output filters where the full raw request/response is passed through.
|
||||
/// No MCP protocol wrapping; agent_type is ignored.
|
||||
#[instrument(
|
||||
skip(self, messages, agent, request_headers),
|
||||
skip(self, raw_bytes, agent, request_headers),
|
||||
fields(
|
||||
agent_id = %agent.id,
|
||||
agent_url = %agent.url,
|
||||
filter_name = %agent.id,
|
||||
message_count = messages.len()
|
||||
bytes_len = raw_bytes.len()
|
||||
)
|
||||
)]
|
||||
async fn execute_http_filter(
|
||||
async fn execute_raw_filter(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
raw_bytes: &[u8],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
// Set service name for this filter span
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
set_service_name(operation_component::AGENT_FILTER);
|
||||
|
||||
// Update current span name to include filter name
|
||||
use opentelemetry::trace::get_active_span;
|
||||
get_active_span(|span| {
|
||||
span.update_name(format!("execute_http_filter ({})", agent.id));
|
||||
span.update_name(format!("execute_raw_filter ({})", agent.id));
|
||||
});
|
||||
|
||||
// Build headers
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
// Inject OpenTelemetry trace context automatically
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
||||
});
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
|
||||
let mut agent_headers = Self::build_agent_headers(request_headers, &agent.id)?;
|
||||
agent_headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
debug!(
|
||||
"Sending HTTP request to agent {} at URL: {}",
|
||||
agent.id, agent.url
|
||||
);
|
||||
// Append the original request path so the filter endpoint encodes the API format.
|
||||
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
|
||||
// -> POST http://host/anonymize/v1/chat/completions
|
||||
let url = format!("{}{}", agent.url, request_path);
|
||||
debug!(agent = %agent.id, url = %url, "sending raw filter request");
|
||||
|
||||
// Send messages array directly as request body
|
||||
let response = self
|
||||
.client
|
||||
.post(&agent.url)
|
||||
.post(&url)
|
||||
.headers(agent_headers)
|
||||
.json(&messages)
|
||||
.body(raw_bytes.to_vec())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
// Handle HTTP errors
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
|
|
@ -592,17 +512,56 @@ impl PipelineProcessor {
|
|||
});
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Response from HTTP agent {}: {}",
|
||||
agent.id,
|
||||
String::from_utf8_lossy(&response_bytes)
|
||||
);
|
||||
debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received");
|
||||
Ok(response_bytes)
|
||||
}
|
||||
|
||||
// Parse response - expecting array of messages directly
|
||||
let messages: Vec<Message> =
|
||||
serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?;
|
||||
/// Process a chain of raw-bytes filters sequentially.
|
||||
/// Input: raw request or response bytes. Output: filtered bytes.
|
||||
/// Each agent receives the output of the previous one.
|
||||
pub async fn process_raw_filter_chain(
|
||||
&mut self,
|
||||
raw_bytes: &[u8],
|
||||
agent_filter_chain: &AgentFilterChain,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
let filter_chain = match agent_filter_chain.input_filters.as_ref() {
|
||||
Some(fc) if !fc.is_empty() => fc,
|
||||
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
|
||||
};
|
||||
|
||||
Ok(messages)
|
||||
let mut current_bytes = Bytes::copy_from_slice(raw_bytes);
|
||||
|
||||
for agent_name in filter_chain {
|
||||
debug!(agent = %agent_name, "processing raw filter agent");
|
||||
|
||||
let agent = agent_map
|
||||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
let agent_type = agent.agent_type.as_deref().unwrap_or("mcp");
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
url = %agent.url,
|
||||
agent_type = %agent_type,
|
||||
bytes_len = current_bytes.len(),
|
||||
"executing raw filter"
|
||||
);
|
||||
|
||||
current_bytes = if agent_type == "mcp" {
|
||||
self.execute_mcp_filter_raw(¤t_bytes, agent, request_headers, request_path)
|
||||
.await?
|
||||
} else {
|
||||
self.execute_raw_filter(¤t_bytes, agent, request_headers, request_path)
|
||||
.await?
|
||||
};
|
||||
|
||||
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
|
||||
}
|
||||
|
||||
Ok(current_bytes)
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
|
|
@ -615,36 +574,15 @@ impl PipelineProcessor {
|
|||
terminal_agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<reqwest::Response, PipelineError> {
|
||||
// let mut request = original_request.clone();
|
||||
original_request.set_messages(messages);
|
||||
|
||||
let request_url = "/v1/chat/completions";
|
||||
|
||||
let request_body = ProviderRequestType::to_bytes(&original_request).unwrap();
|
||||
// let request_body = serde_json::to_string(&request)?;
|
||||
let request_body = ProviderRequestType::to_bytes(&original_request)
|
||||
.map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?;
|
||||
debug!("sending request to terminal agent {}", terminal_agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
// Inject OpenTelemetry trace context automatically
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
||||
});
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&terminal_agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?,
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
let agent_headers = Self::build_agent_headers(request_headers, &terminal_agent.id)?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
|
|
@ -661,24 +599,13 @@ impl PipelineProcessor {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use mockito::Server;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
|
||||
AgentFilterChain {
|
||||
id: "test-agent".to_string(),
|
||||
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()),
|
||||
input_filters: Some(agents.iter().map(|s| s.to_string()).collect()),
|
||||
description: None,
|
||||
default: None,
|
||||
}
|
||||
|
|
@ -690,12 +617,19 @@ mod tests {
|
|||
let agent_map = HashMap::new();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hello")];
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
|
||||
let raw_bytes = serde_json::to_vec(&body).unwrap();
|
||||
|
||||
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
|
||||
|
||||
let result = processor
|
||||
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers)
|
||||
.process_raw_filter_chain(
|
||||
&raw_bytes,
|
||||
&pipeline,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
"/v1/chat/completions",
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
|
|
@ -725,11 +659,12 @@ mod tests {
|
|||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hello")];
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
|
||||
let raw_bytes = serde_json::to_vec(&body).unwrap();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_mcp_filter(&messages, &agent, &request_headers)
|
||||
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
@ -764,11 +699,12 @@ mod tests {
|
|||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Ping")];
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]});
|
||||
let raw_bytes = serde_json::to_vec(&body).unwrap();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_mcp_filter(&messages, &agent, &request_headers)
|
||||
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
@ -816,11 +752,12 @@ mod tests {
|
|||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hi")];
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]});
|
||||
let raw_bytes = serde_json::to_vec(&body).unwrap();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_mcp_filter(&messages, &agent, &request_headers)
|
||||
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
@ -7,7 +7,7 @@ use common::configuration::{
|
|||
use hermesllm::apis::openai::Message;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
|
||||
/// Errors that can occur during agent selection
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -37,7 +37,7 @@ impl AgentSelector {
|
|||
}
|
||||
|
||||
/// Find listener by name from the request headers
|
||||
pub async fn find_listener(
|
||||
pub fn find_listener(
|
||||
&self,
|
||||
listener_name: Option<&str>,
|
||||
listeners: &[common::configuration::Listener],
|
||||
|
|
@ -84,7 +84,7 @@ impl AgentSelector {
|
|||
}
|
||||
|
||||
/// Convert agent descriptions to orchestration preferences
|
||||
async fn convert_agent_description_to_orchestration_preferences(
|
||||
fn convert_agent_description_to_orchestration_preferences(
|
||||
&self,
|
||||
agents: &[AgentFilterChain],
|
||||
) -> Vec<AgentUsagePreference> {
|
||||
|
|
@ -121,9 +121,7 @@ impl AgentSelector {
|
|||
return Ok(vec![agents[0].clone()]);
|
||||
}
|
||||
|
||||
let usage_preferences = self
|
||||
.convert_agent_description_to_orchestration_preferences(agents)
|
||||
.await;
|
||||
let usage_preferences = self.convert_agent_description_to_orchestration_preferences(agents);
|
||||
debug!(
|
||||
"Agents usage preferences for orchestration: {}",
|
||||
serde_json::to_string(&usage_preferences).unwrap_or_default()
|
||||
|
|
@ -172,12 +170,14 @@ impl AgentSelector {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{AgentFilterChain, Listener};
|
||||
use common::configuration::{AgentFilterChain, Listener, ListenerType};
|
||||
|
||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||
Arc::new(OrchestratorService::new(
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
crate::router::orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
))
|
||||
}
|
||||
|
||||
|
|
@ -186,14 +186,17 @@ mod tests {
|
|||
id: name.to_string(),
|
||||
description: Some(description.to_string()),
|
||||
default: Some(is_default),
|
||||
filter_chain: Some(vec![name.to_string()]),
|
||||
input_filters: Some(vec![name.to_string()]),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
|
||||
Listener {
|
||||
listener_type: ListenerType::Agent,
|
||||
name: name.to_string(),
|
||||
agents: Some(agents),
|
||||
input_filters: None,
|
||||
output_filters: None,
|
||||
port: 8080,
|
||||
router: None,
|
||||
}
|
||||
|
|
@ -218,9 +221,7 @@ mod tests {
|
|||
let listener2 = create_test_listener("other-listener", vec![]);
|
||||
let listeners = vec![listener1.clone(), listener2];
|
||||
|
||||
let result = selector
|
||||
.find_listener(Some("test-listener"), &listeners)
|
||||
.await;
|
||||
let result = selector.find_listener(Some("test-listener"), &listeners);
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().name, "test-listener");
|
||||
|
|
@ -233,9 +234,7 @@ mod tests {
|
|||
|
||||
let listeners = vec![create_test_listener("other-listener", vec![])];
|
||||
|
||||
let result = selector
|
||||
.find_listener(Some("nonexistent"), &listeners)
|
||||
.await;
|
||||
let result = selector.find_listener(Some("nonexistent"), &listeners);
|
||||
|
||||
assert!(result.is_err());
|
||||
matches!(
|
||||
53
crates/brightstaff/src/handlers/debug.rs
Normal file
53
crates/brightstaff/src/handlers/debug.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use hyper::{Response, StatusCode};
|
||||
|
||||
use super::full;
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct MemStats {
|
||||
allocated_bytes: usize,
|
||||
resident_bytes: usize,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
/// Returns jemalloc memory statistics as JSON.
|
||||
/// Falls back to a stub when the jemalloc feature is disabled.
|
||||
pub async fn memstats() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let stats = get_jemalloc_stats();
|
||||
let json = serde_json::to_string(&stats).unwrap();
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(full(json))
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
#[cfg(feature = "jemalloc")]
|
||||
fn get_jemalloc_stats() -> MemStats {
|
||||
use tikv_jemalloc_ctl::{epoch, stats};
|
||||
|
||||
if let Err(e) = epoch::advance() {
|
||||
return MemStats {
|
||||
allocated_bytes: 0,
|
||||
resident_bytes: 0,
|
||||
error: Some(format!("failed to advance jemalloc epoch: {e}")),
|
||||
};
|
||||
}
|
||||
|
||||
MemStats {
|
||||
allocated_bytes: stats::allocated::read().unwrap_or(0),
|
||||
resident_bytes: stats::resident::read().unwrap_or(0),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "jemalloc"))]
|
||||
fn get_jemalloc_stats() -> MemStats {
|
||||
MemStats {
|
||||
allocated_bytes: 0,
|
||||
resident_bytes: 0,
|
||||
error: Some("jemalloc feature not enabled".to_string()),
|
||||
}
|
||||
}
|
||||
|
|
@ -441,10 +441,8 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
}
|
||||
// Handle str/string conversions
|
||||
"str" | "string" => {
|
||||
if !value.is_string() {
|
||||
return Ok(json!(value.to_string()));
|
||||
}
|
||||
"str" | "string" if !value.is_string() => {
|
||||
return Ok(json!(value.to_string()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
|
@ -762,7 +760,7 @@ impl ArchFunctionHandler {
|
|||
|
||||
// Keep system message if present
|
||||
if let Some(first) = messages.first() {
|
||||
if first.role == Role::System {
|
||||
if first.role == Role::System || first.role == Role::Developer {
|
||||
if let Some(MessageContent::Text(content)) = &first.content {
|
||||
num_tokens += content.len() / 4; // Approximate 4 chars per token
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,12 +3,11 @@ use std::sync::Arc;
|
|||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hyper::header::HeaderMap;
|
||||
|
||||
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::handlers::pipeline_processor::PipelineProcessor;
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
use common::errors::BrightStaffError;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::StatusCode;
|
||||
use crate::handlers::agents::pipeline::PipelineProcessor;
|
||||
use crate::handlers::agents::selector::{AgentSelectionError, AgentSelector};
|
||||
use crate::handlers::response::ResponseHandler;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
|
||||
/// Integration test that demonstrates the modular agent chat flow
|
||||
/// This test shows how the three main components work together:
|
||||
/// 1. AgentSelector - selects the appropriate agents based on orchestration
|
||||
|
|
@ -17,12 +16,14 @@ use hyper::StatusCode;
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener};
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};
|
||||
|
||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||
Arc::new(OrchestratorService::new(
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
crate::router::orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
))
|
||||
}
|
||||
|
||||
|
|
@ -63,7 +64,7 @@ mod tests {
|
|||
|
||||
let agent_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: Some(vec![
|
||||
input_filters: Some(vec![
|
||||
"filter-agent".to_string(),
|
||||
"terminal-agent".to_string(),
|
||||
]),
|
||||
|
|
@ -72,8 +73,11 @@ mod tests {
|
|||
};
|
||||
|
||||
let listener = Listener {
|
||||
listener_type: ListenerType::Agent,
|
||||
name: "test-listener".to_string(),
|
||||
agents: Some(vec![agent_pipeline.clone()]),
|
||||
input_filters: None,
|
||||
output_filters: None,
|
||||
port: 8080,
|
||||
router: None,
|
||||
};
|
||||
|
|
@ -82,9 +86,7 @@ mod tests {
|
|||
let messages = vec![create_test_message(Role::User, "Hello world!")];
|
||||
|
||||
// Test 1: Agent Selection
|
||||
let selected_listener = agent_selector
|
||||
.find_listener(Some("test-listener"), &listeners)
|
||||
.await;
|
||||
let selected_listener = agent_selector.find_listener(Some("test-listener"), &listeners);
|
||||
|
||||
assert!(selected_listener.is_ok());
|
||||
let listener = selected_listener.unwrap();
|
||||
|
|
@ -106,58 +108,51 @@ mod tests {
|
|||
// Create a pipeline with empty filter chain to avoid network calls
|
||||
let test_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: Some(vec![]), // Empty filter chain - no network calls needed
|
||||
input_filters: Some(vec![]), // Empty filter chain - no network calls needed
|
||||
description: None,
|
||||
default: None,
|
||||
};
|
||||
|
||||
let headers = HeaderMap::new();
|
||||
let request_bytes = serde_json::to_vec(&request).expect("failed to serialize request");
|
||||
let result = pipeline_processor
|
||||
.process_filter_chain(&request.messages, &test_pipeline, &agent_map, &headers)
|
||||
.process_raw_filter_chain(
|
||||
&request_bytes,
|
||||
&test_pipeline,
|
||||
&agent_map,
|
||||
&headers,
|
||||
"/v1/chat/completions",
|
||||
)
|
||||
.await;
|
||||
|
||||
println!("Pipeline processing result: {:?}", result);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let processed_messages = result.unwrap();
|
||||
// With empty filter chain, should return the original messages unchanged
|
||||
assert_eq!(processed_messages.len(), 1);
|
||||
if let Some(MessageContent::Text(content)) = &processed_messages[0].content {
|
||||
let processed_bytes = result.unwrap();
|
||||
// With empty filter chain, should return the original bytes unchanged
|
||||
let processed_request: ChatCompletionsRequest =
|
||||
serde_json::from_slice(&processed_bytes).expect("failed to deserialize response");
|
||||
assert_eq!(processed_request.messages.len(), 1);
|
||||
if let Some(MessageContent::Text(content)) = &processed_request.messages[0].content {
|
||||
assert_eq!(content, "Hello world!");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
}
|
||||
|
||||
// Test 4: Error Response Creation
|
||||
let err = BrightStaffError::ModelNotFound("gpt-5-secret".to_string());
|
||||
let response = err.into_response();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
// Helper to extract body as JSON
|
||||
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
|
||||
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
|
||||
assert_eq!(body["error"]["code"], "ModelNotFound");
|
||||
assert_eq!(
|
||||
body["error"]["details"]["rejected_model_id"],
|
||||
"gpt-5-secret"
|
||||
);
|
||||
assert!(body["error"]["message"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("gpt-5-secret"));
|
||||
let error_response = ResponseHandler::create_bad_request("Test error");
|
||||
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
|
||||
|
||||
println!("✅ All modular components working correctly!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_error_handling_flow() {
|
||||
let router_service = create_test_orchestrator_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let orchestrator_service = create_test_orchestrator_service();
|
||||
let agent_selector = AgentSelector::new(orchestrator_service);
|
||||
|
||||
// Test listener not found
|
||||
let result = agent_selector.find_listener(Some("nonexistent"), &[]).await;
|
||||
let result = agent_selector.find_listener(Some("nonexistent"), &[]);
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
|
|
@ -165,21 +160,12 @@ mod tests {
|
|||
AgentSelectionError::ListenerNotFound(_)
|
||||
));
|
||||
|
||||
let technical_reason = "Database connection timed out";
|
||||
let err = BrightStaffError::InternalServerError(technical_reason.to_string());
|
||||
|
||||
let response = err.into_response();
|
||||
|
||||
// --- 1. EXTRACT BYTES ---
|
||||
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
|
||||
|
||||
// --- 2. DECLARE body_json HERE ---
|
||||
let body_json: serde_json::Value =
|
||||
serde_json::from_slice(&body_bytes).expect("Failed to parse JSON body");
|
||||
|
||||
// --- 3. USE body_json ---
|
||||
assert_eq!(body_json["error"]["code"], "InternalServerError");
|
||||
assert_eq!(body_json["error"]["details"]["reason"], technical_reason);
|
||||
// Test error response creation
|
||||
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
|
||||
assert_eq!(
|
||||
error_response.status(),
|
||||
hyper::StatusCode::INTERNAL_SERVER_ERROR
|
||||
);
|
||||
|
||||
println!("✅ Error handling working correctly!");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,553 +0,0 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{ModelAlias, SpanAttributes};
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use hermesllm::apis::openai_responses::InputParam;
|
||||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response};
|
||||
use opentelemetry::global;
|
||||
use opentelemetry::trace::get_active_span;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::handlers::utils::{
|
||||
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
||||
};
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
use crate::state::{
|
||||
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
|
||||
};
|
||||
use crate::tracing::{
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
|
||||
};
|
||||
|
||||
use common::errors::BrightStaffError;
|
||||
|
||||
pub async fn llm_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
span_attributes: Arc<Option<SpanAttributes>>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
let request_id: String = match request_headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
{
|
||||
Some(id) => id,
|
||||
None => uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
let custom_attrs =
|
||||
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
|
||||
|
||||
// Create a span with request_id that will be included in all log lines
|
||||
let request_span = info_span!(
|
||||
"llm",
|
||||
component = "llm",
|
||||
request_id = %request_id,
|
||||
http.method = %request.method(),
|
||||
http.path = %request_path,
|
||||
llm.model = tracing::field::Empty,
|
||||
llm.tools = tracing::field::Empty,
|
||||
llm.user_message_preview = tracing::field::Empty,
|
||||
llm.temperature = tracing::field::Empty,
|
||||
);
|
||||
|
||||
// Execute the rest of the handler inside the span
|
||||
llm_chat_inner(
|
||||
request,
|
||||
router_service,
|
||||
full_qualified_llm_provider_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
custom_attrs,
|
||||
state_storage,
|
||||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn llm_chat_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
custom_attrs: HashMap<String, String>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
mut request_headers: hyper::HeaderMap,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Set service name for LLM operations
|
||||
set_service_name(operation_component::LLM);
|
||||
get_active_span(|span| {
|
||||
for (key, value) in &custom_attrs {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||
}
|
||||
});
|
||||
|
||||
// Extract or generate traceparent - this establishes the trace context for all spans
|
||||
let traceparent: String = match request_headers
|
||||
.get(TRACE_PARENT_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
{
|
||||
Some(tp) => tp,
|
||||
None => {
|
||||
use uuid::Uuid;
|
||||
let trace_id = Uuid::new_v4().to_string().replace("-", "");
|
||||
let generated_tp = format!("00-{}-0000000000000000-01", trace_id);
|
||||
warn!(
|
||||
generated_traceparent = %generated_tp,
|
||||
"TRACE_PARENT header missing, generated new traceparent"
|
||||
);
|
||||
generated_tp
|
||||
}
|
||||
};
|
||||
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
body = %String::from_utf8_lossy(&chat_request_bytes),
|
||||
"request body received"
|
||||
);
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
"failed to parse request as ProviderRequestType"
|
||||
);
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// === v1/responses state management: Extract input items early ===
|
||||
let mut original_input_items = Vec::new();
|
||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
|
||||
let is_responses_api_client = matches!(
|
||||
client_api,
|
||||
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
// If model is not specified in the request, resolve from default provider
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let model_from_request = if model_from_request.is_empty() {
|
||||
match llm_providers.read().await.default() {
|
||||
Some(default_provider) => {
|
||||
let default_model = default_provider.name.clone();
|
||||
info!(default_model = %default_model, "no model specified in request, using default provider");
|
||||
client_request.set_model(default_model.clone());
|
||||
default_model
|
||||
}
|
||||
None => {
|
||||
let err_msg = "No model specified in request and no default provider configured";
|
||||
warn!("{}", err_msg);
|
||||
return Ok(BrightStaffError::NoModelSpecified.into_response());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
model_from_request
|
||||
};
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let temperature = client_request.get_temperature();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let alias_resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||
|
||||
// Validate that the requested model exists in configuration
|
||||
// This matches the validation in llm_gateway routing.rs
|
||||
if llm_providers
|
||||
.read()
|
||||
.await
|
||||
.get(&alias_resolved_model)
|
||||
.is_none()
|
||||
{
|
||||
warn!(model = %alias_resolved_model, "model not found in configured providers");
|
||||
return Ok(BrightStaffError::ModelNotFound(alias_resolved_model).into_response());
|
||||
}
|
||||
|
||||
// Handle provider/model slug format (e.g., "openai/gpt-4")
|
||||
// Extract just the model name for upstream (providers don't understand the slug)
|
||||
let model_name_only = if let Some((_, model)) = alias_resolved_model.split_once('/') {
|
||||
model.to_string()
|
||||
} else {
|
||||
alias_resolved_model.clone()
|
||||
};
|
||||
|
||||
// Extract tool names and user message preview for span attributes
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request
|
||||
.get_recent_user_message()
|
||||
.map(|msg| truncate_message(&msg, 50));
|
||||
let span = tracing::Span::current();
|
||||
if let Some(temp) = temperature {
|
||||
span.record(tracing_llm::TEMPERATURE, tracing::field::display(temp));
|
||||
}
|
||||
if let Some(tools) = &tool_names {
|
||||
let formatted_tools = tools
|
||||
.iter()
|
||||
.map(|name| format!("{}(...)", name))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
span.record(tracing_llm::TOOLS, formatted_tools.as_str());
|
||||
}
|
||||
if let Some(preview) = &user_message_preview {
|
||||
span.record(tracing_llm::USER_MESSAGE_PREVIEW, preview.as_str());
|
||||
}
|
||||
|
||||
// Extract messages for signal analysis (clone before moving client_request)
|
||||
let messages_for_signals = Some(client_request.get_messages());
|
||||
|
||||
// Set the model to just the model name (without provider prefix)
|
||||
// This ensures upstream receives "gpt-4" not "openai/gpt-4"
|
||||
client_request.set_model(model_name_only.clone());
|
||||
if client_request.remove_metadata_key("plano_preference_config") {
|
||||
debug!("removed plano_preference_config from metadata");
|
||||
}
|
||||
|
||||
// === v1/responses state management: Determine upstream API and combine input if needed ===
|
||||
// Do this BEFORE routing since routing consumes the request
|
||||
// Only process state if state_storage is configured
|
||||
let mut should_manage_state = false;
|
||||
if is_responses_api_client {
|
||||
if let (
|
||||
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
|
||||
Some(ref state_store),
|
||||
) = (&mut client_request, &state_storage)
|
||||
{
|
||||
// Extract original input once
|
||||
original_input_items = extract_input_items(&responses_req.input);
|
||||
|
||||
// Get the upstream path and check if it's ResponsesAPI
|
||||
let upstream_path = get_upstream_path(
|
||||
&llm_providers,
|
||||
&alias_resolved_model,
|
||||
&request_path,
|
||||
&alias_resolved_model,
|
||||
is_streaming_request,
|
||||
)
|
||||
.await;
|
||||
|
||||
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
|
||||
|
||||
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
|
||||
should_manage_state = !matches!(
|
||||
upstream_api,
|
||||
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
if should_manage_state {
|
||||
// Retrieve and combine conversation history if previous_response_id exists
|
||||
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
|
||||
match retrieve_and_combine_input(
|
||||
state_store.clone(),
|
||||
prev_resp_id,
|
||||
original_input_items, // Pass ownership instead of cloning
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(combined_input) => {
|
||||
// Update both the request and original_input_items
|
||||
responses_req.input = InputParam::Items(combined_input.clone());
|
||||
original_input_items = combined_input;
|
||||
info!(
|
||||
items = original_input_items.len(),
|
||||
"updated request with conversation history"
|
||||
);
|
||||
}
|
||||
Err(StateStorageError::NotFound(_)) => {
|
||||
// Return 409 Conflict when previous_response_id not found
|
||||
warn!(previous_response_id = %prev_resp_id, "previous response_id not found");
|
||||
return Ok(BrightStaffError::ConversationStateNotFound(
|
||||
prev_resp_id.to_string(),
|
||||
)
|
||||
.into_response());
|
||||
}
|
||||
Err(e) => {
|
||||
// Log warning but continue on other storage errors
|
||||
warn!(
|
||||
previous_response_id = %prev_resp_id,
|
||||
error = %e,
|
||||
"failed to retrieve conversation state"
|
||||
);
|
||||
// Restore original_input_items since we passed ownership
|
||||
original_input_items = extract_input_items(&responses_req.input);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug!("upstream supports ResponsesAPI natively");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize request for upstream BEFORE router consumes it
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// Determine routing using the dedicated router_chat module
|
||||
// This gets its own span for latency and error tracking
|
||||
let routing_span = info_span!(
|
||||
"routing",
|
||||
component = "routing",
|
||||
http.method = "POST",
|
||||
http.target = %request_path,
|
||||
model.requested = %model_from_request,
|
||||
model.alias_resolved = %alias_resolved_model,
|
||||
route.selected_model = tracing::field::Empty,
|
||||
routing.determination_ms = tracing::field::Empty,
|
||||
);
|
||||
let routing_result = match async {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request, // Pass the original request - router_chat will convert it
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.instrument(routing_span)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
return Ok(BrightStaffError::ForwardedError {
|
||||
status_code: err.status_code,
|
||||
message: err.message,
|
||||
}
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// Determine final model to use
|
||||
// Router returns "none" as a sentinel value when it doesn't select a specific model
|
||||
let router_selected_model = routing_result.model_name;
|
||||
let resolved_model = if router_selected_model != "none" {
|
||||
// Router selected a specific model via routing preferences
|
||||
router_selected_model
|
||||
} else {
|
||||
// Router returned "none" sentinel, use validated resolved_model from request
|
||||
alias_resolved_model.clone()
|
||||
};
|
||||
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
|
||||
|
||||
let span_name = if model_from_request == resolved_model {
|
||||
format!("POST {} {}", request_path, resolved_model)
|
||||
} else {
|
||||
format!(
|
||||
"POST {} {} -> {}",
|
||||
request_path, model_from_request, resolved_model
|
||||
)
|
||||
};
|
||||
get_active_span(|span| {
|
||||
span.update_name(span_name.clone());
|
||||
});
|
||||
|
||||
debug!(
|
||||
url = %full_qualified_llm_provider_url,
|
||||
provider_hint = %resolved_model,
|
||||
upstream_model = %model_name_only,
|
||||
"Routing to upstream"
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&resolved_model).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// Inject current LLM span's trace context so upstream spans are children of plano(llm)
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
||||
});
|
||||
|
||||
// Capture start time right before sending request to upstream
|
||||
let request_start_time = std::time::Instant::now();
|
||||
let _request_start_system_time = std::time::SystemTime::now();
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(&full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
return Ok(BrightStaffError::InternalServerError(format!(
|
||||
"Failed to send request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
// copy over the headers and status code from the original response
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
let headers = response.headers_mut().unwrap();
|
||||
for (header_name, header_value) in response_headers.iter() {
|
||||
headers.insert(header_name, header_value.clone());
|
||||
}
|
||||
|
||||
// Build LLM span with actual status code using constants
|
||||
let byte_stream = llm_response.bytes_stream();
|
||||
|
||||
// Create base processor for metrics and tracing
|
||||
let base_processor = ObservableStreamProcessor::new(
|
||||
operation_component::LLM,
|
||||
span_name,
|
||||
request_start_time,
|
||||
messages_for_signals,
|
||||
);
|
||||
|
||||
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
|
||||
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
|
||||
let streaming_response = if let (true, false, Some(state_store)) = (
|
||||
should_manage_state,
|
||||
original_input_items.is_empty(),
|
||||
state_storage,
|
||||
) {
|
||||
// Extract Content-Encoding header to handle decompression for state parsing
|
||||
let content_encoding = response_headers
|
||||
.get("content-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Wrap with state management processor to store state after response completes
|
||||
let state_processor = ResponsesStateProcessor::new(
|
||||
base_processor,
|
||||
state_store,
|
||||
original_input_items,
|
||||
alias_resolved_model.clone(),
|
||||
resolved_model.clone(),
|
||||
is_streaming_request,
|
||||
false, // Not OpenAI upstream since should_manage_state is true
|
||||
content_encoding,
|
||||
request_id,
|
||||
);
|
||||
create_streaming_response(byte_stream, state_processor, 16)
|
||||
} else {
|
||||
// Use base processor without state management
|
||||
create_streaming_response(byte_stream, base_processor, 16)
|
||||
};
|
||||
|
||||
match response.body(streaming_response.body) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => Ok(BrightStaffError::InternalServerError(format!(
|
||||
"Failed to create response: {}",
|
||||
err
|
||||
))
|
||||
.into_response()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves model aliases by looking up the requested model in the model_aliases map.
|
||||
/// Returns the target model if an alias is found, otherwise returns the original model.
|
||||
fn resolve_model_alias(
|
||||
model_from_request: &str,
|
||||
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> String {
|
||||
if let Some(aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = aliases.get(model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To {}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
return model_alias.target.clone();
|
||||
}
|
||||
}
|
||||
model_from_request.to_string()
|
||||
}
|
||||
|
||||
/// Calculates the upstream path for the provider based on the model name.
|
||||
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
|
||||
/// then uses target_endpoint_for_provider to calculate the correct upstream path.
|
||||
async fn get_upstream_path(
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
request_path: &str,
|
||||
resolved_model: &str,
|
||||
is_streaming: bool,
|
||||
) -> String {
|
||||
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
|
||||
|
||||
// Calculate the upstream path using the proper API
|
||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
|
||||
.expect("Should have valid API endpoint");
|
||||
|
||||
client_api.target_endpoint_for_provider(
|
||||
&provider_id,
|
||||
request_path,
|
||||
resolved_model,
|
||||
is_streaming,
|
||||
base_url_path_prefix.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper function to get provider info (ProviderId and base_url_path_prefix)
|
||||
async fn get_provider_info(
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
) -> (hermesllm::ProviderId, Option<String>) {
|
||||
let providers_lock = llm_providers.read().await;
|
||||
|
||||
// Try to find by model name or provider name using LlmProviders::get
|
||||
// This handles both "gpt-4" and "openai/gpt-4" formats
|
||||
if let Some(provider) = providers_lock.get(model_name) {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
return (provider_id, prefix);
|
||||
}
|
||||
|
||||
// Fall back to default provider
|
||||
if let Some(provider) = providers_lock.default() {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
(provider_id, prefix)
|
||||
} else {
|
||||
// Last resort: use OpenAI as hardcoded fallback
|
||||
warn!("No default provider found, falling back to OpenAI");
|
||||
(hermesllm::ProviderId::OpenAI, None)
|
||||
}
|
||||
}
|
||||
1019
crates/brightstaff/src/handlers/llm/mod.rs
Normal file
1019
crates/brightstaff/src/handlers/llm/mod.rs
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,15 +1,34 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use common::configuration::TopLevelRoutingPreference;
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hermesllm::ProviderRequestType;
|
||||
use hyper::StatusCode;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::metrics as bs_metrics;
|
||||
use crate::metrics::labels as metric_labels;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::streaming::truncate_message;
|
||||
use crate::tracing::routing;
|
||||
|
||||
/// Classify a request path (already stripped of `/agents` or `/routing` by
|
||||
/// the caller) into the fixed `route` label used on routing metrics.
|
||||
fn route_label_for_path(request_path: &str) -> &'static str {
|
||||
if request_path.starts_with("/agents") {
|
||||
metric_labels::ROUTE_AGENT
|
||||
} else if request_path.starts_with("/routing") {
|
||||
metric_labels::ROUTE_ROUTING
|
||||
} else {
|
||||
metric_labels::ROUTE_LLM
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RoutingResult {
|
||||
/// Primary model to use (first in the ranked list).
|
||||
pub model_name: String,
|
||||
/// Full ranked list — use subsequent entries as fallbacks on 429/5xx.
|
||||
pub models: Vec<String>,
|
||||
pub route_name: Option<String>,
|
||||
}
|
||||
|
||||
pub struct RoutingError {
|
||||
|
|
@ -32,15 +51,12 @@ impl RoutingError {
|
|||
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
|
||||
/// * `Err(RoutingError)` - Contains error details and optional span ID
|
||||
pub async fn router_chat_get_upstream_model(
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
client_request: ProviderRequestType,
|
||||
traceparent: &str,
|
||||
request_path: &str,
|
||||
request_id: &str,
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
) -> Result<RoutingResult, RoutingError> {
|
||||
// Clone metadata for routing before converting (which consumes client_request)
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
||||
let chat_request = match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
|
|
@ -75,17 +91,6 @@ pub async fn router_chat_get_upstream_model(
|
|||
"router request"
|
||||
);
|
||||
|
||||
// Extract usage preferences from metadata
|
||||
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("plano_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
// Prepare log message with latest message from chat request
|
||||
let latest_message_for_log = chat_request
|
||||
.messages
|
||||
|
|
@ -96,19 +101,9 @@ pub async fn router_chat_get_upstream_model(
|
|||
.map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n"))
|
||||
});
|
||||
|
||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
|
||||
let truncated: String = latest_message_for_log
|
||||
.chars()
|
||||
.take(MAX_MESSAGE_LENGTH)
|
||||
.collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
latest_message_for_log
|
||||
};
|
||||
let latest_message_for_log = truncate_message(&latest_message_for_log, 50);
|
||||
|
||||
info!(
|
||||
has_usage_preferences = usage_preferences.is_some(),
|
||||
path = %request_path,
|
||||
latest_message = %latest_message_for_log,
|
||||
"processing router request"
|
||||
|
|
@ -117,39 +112,59 @@ pub async fn router_chat_get_upstream_model(
|
|||
// Capture start time for routing span
|
||||
let routing_start_time = std::time::Instant::now();
|
||||
|
||||
// Attempt to determine route using the router service
|
||||
let routing_result = router_service
|
||||
let routing_result = orchestrator_service
|
||||
.determine_route(
|
||||
&chat_request.messages,
|
||||
traceparent,
|
||||
usage_preferences,
|
||||
inline_routing_preferences,
|
||||
request_id,
|
||||
)
|
||||
.await;
|
||||
|
||||
let determination_ms = routing_start_time.elapsed().as_millis() as i64;
|
||||
let determination_elapsed = routing_start_time.elapsed();
|
||||
let determination_ms = determination_elapsed.as_millis() as i64;
|
||||
let current_span = tracing::Span::current();
|
||||
current_span.record(routing::ROUTE_DETERMINATION_MS, determination_ms);
|
||||
let route_label = route_label_for_path(request_path);
|
||||
|
||||
match routing_result {
|
||||
Ok(route) => match route {
|
||||
Some((_, model_name)) => {
|
||||
Some((route_name, ranked_models)) => {
|
||||
let model_name = ranked_models.first().cloned().unwrap_or_default();
|
||||
current_span.record("route.selected_model", model_name.as_str());
|
||||
Ok(RoutingResult { model_name })
|
||||
bs_metrics::record_router_decision(
|
||||
route_label,
|
||||
&model_name,
|
||||
false,
|
||||
determination_elapsed,
|
||||
);
|
||||
Ok(RoutingResult {
|
||||
model_name,
|
||||
models: ranked_models,
|
||||
route_name: Some(route_name),
|
||||
})
|
||||
}
|
||||
None => {
|
||||
// No route determined, return sentinel value "none"
|
||||
// This signals to llm.rs to use the original validated request model
|
||||
current_span.record("route.selected_model", "none");
|
||||
info!("no route determined, using default model");
|
||||
bs_metrics::record_router_decision(
|
||||
route_label,
|
||||
"none",
|
||||
true,
|
||||
determination_elapsed,
|
||||
);
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name: "none".to_string(),
|
||||
models: vec!["none".to_string()],
|
||||
route_name: None,
|
||||
})
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
current_span.record("route.selected_model", "unknown");
|
||||
bs_metrics::record_router_decision(route_label, "unknown", true, determination_elapsed);
|
||||
Err(RoutingError::internal_error(format!(
|
||||
"Failed to determine route: {}",
|
||||
err
|
||||
|
|
@ -1,13 +1,58 @@
|
|||
pub mod agent_chat_completions;
|
||||
pub mod agent_selector;
|
||||
pub mod agents;
|
||||
pub mod debug;
|
||||
pub mod function_calling;
|
||||
pub mod jsonrpc;
|
||||
pub mod llm;
|
||||
pub mod models;
|
||||
pub mod pipeline_processor;
|
||||
pub mod response_handler;
|
||||
pub mod router_chat;
|
||||
pub mod utils;
|
||||
pub mod response;
|
||||
pub mod routing_service;
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::consts::TRACE_PARENT_HEADER;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty, Full};
|
||||
use hyper::Request;
|
||||
use tracing::warn;
|
||||
|
||||
/// Wrap a chunk into a `BoxBody` for hyper responses.
|
||||
pub fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// An empty HTTP body (used for 404 / OPTIONS responses).
|
||||
pub fn empty() -> BoxBody<Bytes, hyper::Error> {
|
||||
Empty::<Bytes>::new()
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Extract request ID from incoming request headers, or generate a new UUID v4.
|
||||
pub fn extract_request_id<T>(request: &Request<T>) -> String {
|
||||
request
|
||||
.headers()
|
||||
.get(common::consts::REQUEST_ID_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
|
||||
}
|
||||
|
||||
/// Extract or generate a W3C `traceparent` header value.
|
||||
pub fn extract_or_generate_traceparent(headers: &hyper::HeaderMap) -> String {
|
||||
headers
|
||||
.get(TRACE_PARENT_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| {
|
||||
let trace_id = uuid::Uuid::new_v4().to_string().replace("-", "");
|
||||
let tp = format!("00-{}-0000000000000000-01", trace_id);
|
||||
warn!(
|
||||
generated_traceparent = %tp,
|
||||
"TRACE_PARENT header missing, generated new traceparent"
|
||||
);
|
||||
tp
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
use bytes::Bytes;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use hyper::{Response, StatusCode};
|
||||
use serde_json;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::full;
|
||||
|
||||
pub async fn list_models(
|
||||
llm_providers: Arc<tokio::sync::RwLock<LlmProviders>>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
|
|
@ -12,27 +13,15 @@ pub async fn list_models(
|
|||
let models = prov.to_models();
|
||||
|
||||
match serde_json::to_string(&models) {
|
||||
Ok(json) => {
|
||||
let body = Full::new(Bytes::from(json))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
Err(_) => {
|
||||
let body = Full::new(Bytes::from_static(
|
||||
b"{\"error\":\"Failed to serialize models\"}",
|
||||
))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
Ok(json) => Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(full(json))
|
||||
.unwrap(),
|
||||
Err(_) => Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(full("{\"error\":\"Failed to serialize models\"}"))
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use hermesllm::apis::OpenAIApi;
|
|||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use hermesllm::SseEvent;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use http_body_util::StreamBody;
|
||||
use hyper::body::Frame;
|
||||
use hyper::Response;
|
||||
use tokio::sync::mpsc;
|
||||
|
|
@ -12,6 +12,8 @@ use tokio_stream::wrappers::ReceiverStream;
|
|||
use tokio_stream::StreamExt;
|
||||
use tracing::{info, warn, Instrument};
|
||||
|
||||
use super::full;
|
||||
|
||||
/// Service for handling HTTP responses and streaming
|
||||
pub struct ResponseHandler;
|
||||
|
||||
|
|
@ -22,9 +24,40 @@ impl ResponseHandler {
|
|||
|
||||
/// Create a full response body from bytes
|
||||
pub fn create_full_body<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
full(chunk)
|
||||
}
|
||||
|
||||
/// Create a JSON error response with BAD_REQUEST status
|
||||
pub fn create_json_error_response(
|
||||
json: &serde_json::Value,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let body = Self::create_full_body(json.to_string());
|
||||
let mut response = Response::new(body);
|
||||
*response.status_mut() = hyper::StatusCode::BAD_REQUEST;
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
response
|
||||
}
|
||||
|
||||
/// Create a BAD_REQUEST error response with a message
|
||||
pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let json = serde_json::json!({"error": message});
|
||||
Self::create_json_error_response(&json)
|
||||
}
|
||||
|
||||
/// Create an INTERNAL_SERVER_ERROR response with a message
|
||||
pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let json = serde_json::json!({"error": message});
|
||||
let body = Self::create_full_body(json.to_string());
|
||||
let mut response = Response::new(body);
|
||||
*response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR;
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
response
|
||||
}
|
||||
|
||||
/// Create a streaming response from a reqwest response.
|
||||
|
|
@ -112,7 +145,9 @@ impl ResponseHandler {
|
|||
let upstream_api =
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
|
||||
let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).unwrap();
|
||||
let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).map_err(|e| {
|
||||
BrightStaffError::StreamError(format!("Failed to parse SSE stream: {}", e))
|
||||
})?;
|
||||
let mut accumulated_text = String::new();
|
||||
|
||||
for sse_event in sse_iter {
|
||||
|
|
@ -122,7 +157,13 @@ impl ResponseHandler {
|
|||
}
|
||||
|
||||
let transformed_event =
|
||||
SseEvent::try_from((sse_event, &client_api, &upstream_api)).unwrap();
|
||||
match SseEvent::try_from((sse_event, &client_api, &upstream_api)) {
|
||||
Ok(event) => event,
|
||||
Err(e) => {
|
||||
warn!(error = ?e, "failed to transform SSE event, skipping");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Try to get provider response and extract content delta
|
||||
match transformed_event.provider_response() {
|
||||
436
crates/brightstaff/src/handlers/routing_service.rs
Normal file
436
crates/brightstaff/src/handlers/routing_service.rs
Normal file
|
|
@ -0,0 +1,436 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{SpanAttributes, TopLevelRoutingPreference};
|
||||
use common::consts::{MODEL_AFFINITY_HEADER, REQUEST_ID_HEADER};
|
||||
use common::errors::BrightStaffError;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::ProviderRequestType;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::extract_or_generate_traceparent;
|
||||
use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
|
||||
use crate::metrics as bs_metrics;
|
||||
use crate::metrics::labels as metric_labels;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
/// Extracts `routing_preferences` from a JSON body, returning the cleaned body bytes
|
||||
/// and the parsed preferences. The field is removed from the JSON before re-serializing
|
||||
/// so downstream parsers don't see it.
|
||||
pub fn extract_routing_policy(
|
||||
raw_bytes: &[u8],
|
||||
) -> Result<(Bytes, Option<Vec<TopLevelRoutingPreference>>), String> {
|
||||
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
|
||||
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
|
||||
|
||||
let routing_preferences = json_body
|
||||
.as_object_mut()
|
||||
.and_then(|o| o.remove("routing_preferences"))
|
||||
.and_then(
|
||||
|value| match serde_json::from_value::<Vec<TopLevelRoutingPreference>>(value) {
|
||||
Ok(prefs) => {
|
||||
info!(
|
||||
num_routes = prefs.len(),
|
||||
"using inline routing_preferences from request body"
|
||||
);
|
||||
Some(prefs)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse routing_preferences");
|
||||
None
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
|
||||
Ok((bytes, routing_preferences))
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct RoutingDecisionResponse {
|
||||
/// Ranked model list — use first, fall back to next on 429/5xx.
|
||||
models: Vec<String>,
|
||||
route: Option<String>,
|
||||
trace_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
pinned: bool,
|
||||
}
|
||||
|
||||
pub async fn routing_decision(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
request_path: String,
|
||||
span_attributes: &Option<SpanAttributes>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_headers = request.headers().clone();
|
||||
let request_id: String = request_headers
|
||||
.get(REQUEST_ID_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
||||
|
||||
let session_id: Option<String> = request_headers
|
||||
.get(MODEL_AFFINITY_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let tenant_id: Option<String> = orchestrator_service
|
||||
.tenant_header()
|
||||
.and_then(|hdr| request_headers.get(hdr))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref());
|
||||
|
||||
let request_span = info_span!(
|
||||
"routing_decision",
|
||||
component = "routing",
|
||||
request_id = %request_id,
|
||||
http.method = %request.method(),
|
||||
http.path = %request_path,
|
||||
);
|
||||
|
||||
routing_decision_inner(
|
||||
request,
|
||||
orchestrator_service,
|
||||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
custom_attrs,
|
||||
session_id,
|
||||
tenant_id,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn routing_decision_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
request_headers: hyper::HeaderMap,
|
||||
custom_attrs: std::collections::HashMap<String, String>,
|
||||
session_id: Option<String>,
|
||||
tenant_id: Option<String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
opentelemetry::trace::get_active_span(|span| {
|
||||
for (key, value) in &custom_attrs {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||
}
|
||||
});
|
||||
|
||||
let traceparent = extract_or_generate_traceparent(&request_headers);
|
||||
|
||||
// Extract trace_id from traceparent (format: 00-{trace_id}-{span_id}-{flags})
|
||||
let trace_id = traceparent
|
||||
.split('-')
|
||||
.nth(1)
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
if let Some(ref sid) = session_id {
|
||||
if let Some(cached) = orchestrator_service
|
||||
.get_cached_route(sid, tenant_id.as_deref())
|
||||
.await
|
||||
{
|
||||
info!(
|
||||
session_id = %sid,
|
||||
model = %cached.model_name,
|
||||
route = ?cached.route_name,
|
||||
"returning pinned routing decision from cache"
|
||||
);
|
||||
let response = RoutingDecisionResponse {
|
||||
models: vec![cached.model_name],
|
||||
route: cached.route_name,
|
||||
trace_id,
|
||||
session_id: Some(sid.clone()),
|
||||
pinned: true,
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let body = Full::new(Bytes::from(json))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
let raw_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!(
|
||||
body = %String::from_utf8_lossy(&raw_bytes),
|
||||
"routing decision request body received"
|
||||
);
|
||||
|
||||
// Extract routing_preferences from body before parsing as ProviderRequestType
|
||||
let (chat_request_bytes, inline_routing_preferences) = match extract_routing_policy(&raw_bytes)
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request JSON: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
let client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request for routing decision");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
let routing_result = router_chat_get_upstream_model(
|
||||
Arc::clone(&orchestrator_service),
|
||||
client_request,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_preferences,
|
||||
)
|
||||
.await;
|
||||
|
||||
match routing_result {
|
||||
Ok(result) => {
|
||||
if let Some(ref sid) = session_id {
|
||||
orchestrator_service
|
||||
.cache_route(
|
||||
sid.clone(),
|
||||
tenant_id.as_deref(),
|
||||
result.model_name.clone(),
|
||||
result.route_name.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let response = RoutingDecisionResponse {
|
||||
models: result.models,
|
||||
route: result.route_name,
|
||||
trace_id,
|
||||
session_id,
|
||||
pinned: false,
|
||||
};
|
||||
|
||||
// Distinguish "decision served" (a concrete model picked) from
|
||||
// "no_candidates" (the sentinel "none" returned when nothing
|
||||
// matched). The handler still responds 200 in both cases, so RED
|
||||
// metrics alone can't tell them apart.
|
||||
let outcome = if response.models.first().map(|m| m == "none").unwrap_or(true) {
|
||||
metric_labels::ROUTING_SVC_NO_CANDIDATES
|
||||
} else {
|
||||
metric_labels::ROUTING_SVC_DECISION_SERVED
|
||||
};
|
||||
bs_metrics::record_routing_service_outcome(outcome);
|
||||
|
||||
info!(
|
||||
primary_model = %response.models.first().map(|s| s.as_str()).unwrap_or("none"),
|
||||
total_models = response.models.len(),
|
||||
route = ?response.route,
|
||||
"routing decision completed"
|
||||
);
|
||||
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let body = Full::new(Bytes::from(json))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap())
|
||||
}
|
||||
Err(err) => {
|
||||
bs_metrics::record_routing_service_outcome(metric_labels::ROUTING_SVC_POLICY_ERROR);
|
||||
warn!(error = %err.message, "routing decision failed");
|
||||
Ok(BrightStaffError::InternalServerError(err.message).into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::SelectionPreference;
|
||||
|
||||
fn make_chat_body(extra_fields: &str) -> Vec<u8> {
|
||||
let extra = if extra_fields.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(", {}", extra_fields)
|
||||
};
|
||||
format!(
|
||||
r#"{{"model": "gpt-4o-mini", "messages": [{{"role": "user", "content": "hello"}}]{}}}"#,
|
||||
extra
|
||||
)
|
||||
.into_bytes()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_no_policy() {
|
||||
let body = make_chat_body("");
|
||||
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
assert!(prefs.is_none());
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_invalid_json_returns_error() {
|
||||
let body = b"not valid json";
|
||||
let result = extract_routing_policy(body);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Failed to parse JSON"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_routing_preferences() {
|
||||
let policy = r#""routing_preferences": [
|
||||
{
|
||||
"name": "code generation",
|
||||
"description": "generate new code",
|
||||
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
|
||||
"selection_policy": {"prefer": "fastest"}
|
||||
}
|
||||
]"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
let prefs = prefs.expect("should have parsed routing_preferences");
|
||||
assert_eq!(prefs.len(), 1);
|
||||
assert_eq!(prefs[0].name, "code generation");
|
||||
assert_eq!(prefs[0].models, vec!["openai/gpt-4o", "openai/gpt-4o-mini"]);
|
||||
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("routing_preferences").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_preserves_other_fields() {
|
||||
let policy = r#""routing_preferences": [{"name": "test", "description": "test", "models": ["gpt-4o"], "selection_policy": {"prefer": "none"}}], "temperature": 0.5, "max_tokens": 100"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
assert!(prefs.is_some());
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert_eq!(cleaned_json["temperature"], 0.5);
|
||||
assert_eq!(cleaned_json["max_tokens"], 100);
|
||||
assert!(cleaned_json.get("routing_preferences").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_prefer_null_defaults_to_none() {
|
||||
let policy = r#""routing_preferences": [
|
||||
{
|
||||
"name": "coding",
|
||||
"description": "code generation, writing functions, debugging",
|
||||
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
|
||||
"selection_policy": {"prefer": null}
|
||||
}
|
||||
]"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (_cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
let prefs = prefs.expect("should parse routing_preferences when prefer is null");
|
||||
assert_eq!(prefs.len(), 1);
|
||||
assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_selection_policy_missing_defaults_to_none() {
|
||||
let policy = r#""routing_preferences": [
|
||||
{
|
||||
"name": "coding",
|
||||
"description": "code generation, writing functions, debugging",
|
||||
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"]
|
||||
}
|
||||
]"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (_cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
let prefs =
|
||||
prefs.expect("should parse routing_preferences when selection_policy is missing");
|
||||
assert_eq!(prefs.len(), 1);
|
||||
assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_prefer_empty_string_defaults_to_none() {
|
||||
let policy = r#""routing_preferences": [
|
||||
{
|
||||
"name": "coding",
|
||||
"description": "code generation, writing functions, debugging",
|
||||
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
|
||||
"selection_policy": {"prefer": ""}
|
||||
}
|
||||
]"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (_cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
let prefs =
|
||||
prefs.expect("should parse routing_preferences when selection_policy.prefer is empty");
|
||||
assert_eq!(prefs.len(), 1);
|
||||
assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn routing_decision_response_serialization() {
|
||||
let response = RoutingDecisionResponse {
|
||||
models: vec![
|
||||
"openai/gpt-4o-mini".to_string(),
|
||||
"openai/gpt-4o".to_string(),
|
||||
],
|
||||
route: Some("code_generation".to_string()),
|
||||
trace_id: "abc123".to_string(),
|
||||
session_id: Some("sess-abc".to_string()),
|
||||
pinned: true,
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed["models"][0], "openai/gpt-4o-mini");
|
||||
assert_eq!(parsed["models"][1], "openai/gpt-4o");
|
||||
assert_eq!(parsed["route"], "code_generation");
|
||||
assert_eq!(parsed["trace_id"], "abc123");
|
||||
assert_eq!(parsed["session_id"], "sess-abc");
|
||||
assert_eq!(parsed["pinned"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn routing_decision_response_serialization_no_route() {
|
||||
let response = RoutingDecisionResponse {
|
||||
models: vec!["none".to_string()],
|
||||
route: None,
|
||||
trace_id: "abc123".to_string(),
|
||||
session_id: None,
|
||||
pinned: false,
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed["models"][0], "none");
|
||||
assert!(parsed["route"].is_null());
|
||||
assert!(parsed.get("session_id").is_none());
|
||||
assert_eq!(parsed["pinned"], false);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,288 +0,0 @@
|
|||
use bytes::Bytes;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::StreamBody;
|
||||
use hyper::body::Frame;
|
||||
use opentelemetry::trace::TraceContextExt;
|
||||
use opentelemetry::KeyValue;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{info, warn, Instrument};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
pub trait StreamProcessor: Send + 'static {
|
||||
/// Process an incoming chunk of bytes
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String>;
|
||||
|
||||
/// Called when the first bytes are received (for time-to-first-token tracking)
|
||||
fn on_first_bytes(&mut self) {}
|
||||
|
||||
/// Called when streaming completes successfully
|
||||
fn on_complete(&mut self) {}
|
||||
|
||||
/// Called when streaming encounters an error
|
||||
fn on_error(&mut self, _error: &str) {}
|
||||
}
|
||||
|
||||
/// A processor that tracks streaming metrics
|
||||
pub struct ObservableStreamProcessor {
|
||||
service_name: String,
|
||||
operation_name: String,
|
||||
total_bytes: usize,
|
||||
chunk_count: usize,
|
||||
start_time: Instant,
|
||||
time_to_first_token: Option<u128>,
|
||||
messages: Option<Vec<Message>>,
|
||||
}
|
||||
|
||||
impl ObservableStreamProcessor {
|
||||
/// Create a new passthrough processor
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `service_name` - The service name for this span (e.g., "plano(llm)")
|
||||
/// This will be set as the `service.name.override` attribute on the current span,
|
||||
/// allowing the ServiceNameOverrideExporter to route spans to different services.
|
||||
/// * `operation_name` - The current span operation name (e.g., "POST /v1/chat/completions gpt-4")
|
||||
/// Used to append the flag marker when concerning signals are detected.
|
||||
/// * `start_time` - When the request started (for duration calculation)
|
||||
/// * `messages` - Optional conversation messages for signal analysis
|
||||
pub fn new(
|
||||
service_name: impl Into<String>,
|
||||
operation_name: impl Into<String>,
|
||||
start_time: Instant,
|
||||
messages: Option<Vec<Message>>,
|
||||
) -> Self {
|
||||
let service_name = service_name.into();
|
||||
|
||||
// Set the service name override on the current span for OpenTelemetry export
|
||||
// This allows the ServiceNameOverrideExporter to route this span to the correct service
|
||||
set_service_name(&service_name);
|
||||
|
||||
Self {
|
||||
service_name,
|
||||
operation_name: operation_name.into(),
|
||||
total_bytes: 0,
|
||||
chunk_count: 0,
|
||||
start_time,
|
||||
time_to_first_token: None,
|
||||
messages,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamProcessor for ObservableStreamProcessor {
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||
self.total_bytes += chunk.len();
|
||||
self.chunk_count += 1;
|
||||
Ok(Some(chunk))
|
||||
}
|
||||
|
||||
fn on_first_bytes(&mut self) {
|
||||
// Record time to first token (only for streaming)
|
||||
if self.time_to_first_token.is_none() {
|
||||
self.time_to_first_token = Some(self.start_time.elapsed().as_millis());
|
||||
}
|
||||
}
|
||||
|
||||
fn on_complete(&mut self) {
|
||||
// Record time-to-first-token as an OTel span attribute + event (streaming only)
|
||||
if let Some(ttft) = self.time_to_first_token {
|
||||
let span = tracing::Span::current();
|
||||
let otel_context = span.context();
|
||||
let otel_span = otel_context.span();
|
||||
otel_span.set_attribute(KeyValue::new(llm::TIME_TO_FIRST_TOKEN_MS, ttft as i64));
|
||||
otel_span.add_event(
|
||||
llm::TIME_TO_FIRST_TOKEN_MS,
|
||||
vec![KeyValue::new(llm::TIME_TO_FIRST_TOKEN_MS, ttft as i64)],
|
||||
);
|
||||
}
|
||||
|
||||
// Analyze signals if messages are available and record as span attributes
|
||||
if let Some(ref messages) = self.messages {
|
||||
let analyzer: Box<dyn SignalAnalyzer> = Box::new(TextBasedSignalAnalyzer::new());
|
||||
let report = analyzer.analyze(messages);
|
||||
|
||||
// Get the current OTel span to set signal attributes
|
||||
let span = tracing::Span::current();
|
||||
let otel_context = span.context();
|
||||
let otel_span = otel_context.span();
|
||||
|
||||
// Add overall quality
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::QUALITY,
|
||||
format!("{:?}", report.overall_quality),
|
||||
));
|
||||
|
||||
// Add repair/follow-up metrics if concerning
|
||||
if report.follow_up.is_concerning || report.follow_up.repair_count > 0 {
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::REPAIR_COUNT,
|
||||
report.follow_up.repair_count as i64,
|
||||
));
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::REPAIR_RATIO,
|
||||
format!("{:.3}", report.follow_up.repair_ratio),
|
||||
));
|
||||
}
|
||||
|
||||
// Add frustration metrics
|
||||
if report.frustration.has_frustration {
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::FRUSTRATION_COUNT,
|
||||
report.frustration.frustration_count as i64,
|
||||
));
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::FRUSTRATION_SEVERITY,
|
||||
report.frustration.severity as i64,
|
||||
));
|
||||
}
|
||||
|
||||
// Add repetition metrics
|
||||
if report.repetition.has_looping {
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::REPETITION_COUNT,
|
||||
report.repetition.repetition_count as i64,
|
||||
));
|
||||
}
|
||||
|
||||
// Add escalation metrics
|
||||
if report.escalation.escalation_requested {
|
||||
otel_span
|
||||
.set_attribute(KeyValue::new(signal_constants::ESCALATION_REQUESTED, true));
|
||||
}
|
||||
|
||||
// Add positive feedback metrics
|
||||
if report.positive_feedback.has_positive_feedback {
|
||||
otel_span.set_attribute(KeyValue::new(
|
||||
signal_constants::POSITIVE_FEEDBACK_COUNT,
|
||||
report.positive_feedback.positive_count as i64,
|
||||
));
|
||||
}
|
||||
|
||||
// Flag the span name if any concerning signal is detected
|
||||
let should_flag = report.frustration.has_frustration
|
||||
|| report.repetition.has_looping
|
||||
|| report.escalation.escalation_requested
|
||||
|| matches!(
|
||||
report.overall_quality,
|
||||
InteractionQuality::Poor | InteractionQuality::Severe
|
||||
);
|
||||
|
||||
if should_flag {
|
||||
otel_span.update_name(format!("{} {}", self.operation_name, FLAG_MARKER));
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
service = %self.service_name,
|
||||
total_bytes = self.total_bytes,
|
||||
chunk_count = self.chunk_count,
|
||||
duration_ms = self.start_time.elapsed().as_millis(),
|
||||
time_to_first_token_ms = ?self.time_to_first_token,
|
||||
"streaming completed"
|
||||
);
|
||||
}
|
||||
|
||||
fn on_error(&mut self, error_msg: &str) {
|
||||
warn!(
|
||||
service = %self.service_name,
|
||||
error = error_msg,
|
||||
duration_ms = self.start_time.elapsed().as_millis(),
|
||||
"stream error"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of creating a streaming response
|
||||
pub struct StreamingResponse {
|
||||
pub body: BoxBody<Bytes, hyper::Error>,
|
||||
pub processor_handle: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
pub fn create_streaming_response<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut processor: P,
|
||||
buffer_size: usize,
|
||||
) -> StreamingResponse
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);
|
||||
|
||||
// Capture the current span so the spawned task inherits the request context
|
||||
let current_span = tracing::Span::current();
|
||||
|
||||
// Spawn a task to process and forward chunks
|
||||
let processor_handle = tokio::spawn(
|
||||
async move {
|
||||
let mut is_first_chunk = true;
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Error receiving chunk: {:?}", err);
|
||||
warn!(error = %err_msg, "stream error");
|
||||
processor.on_error(&err_msg);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Call on_first_bytes for the first chunk
|
||||
if is_first_chunk {
|
||||
processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Process the chunk
|
||||
match processor.process_chunk(chunk) {
|
||||
Ok(Some(processed_chunk)) => {
|
||||
if tx.send(processed_chunk).await.is_err() {
|
||||
warn!("receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// Skip this chunk
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("processor error: {}", err);
|
||||
processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
processor.on_complete();
|
||||
}
|
||||
.instrument(current_span),
|
||||
);
|
||||
|
||||
// Convert channel receiver to HTTP stream
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
StreamingResponse {
|
||||
body: stream_body,
|
||||
processor_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates a message to the specified maximum length, adding "..." if truncated.
|
||||
pub fn truncate_message(message: &str, max_length: usize) -> String {
|
||||
if message.chars().count() > max_length {
|
||||
let truncated: String = message.chars().take(max_length).collect();
|
||||
format!("{}...", truncated)
|
||||
} else {
|
||||
message.to_string()
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,9 @@
|
|||
pub mod app_state;
|
||||
pub mod handlers;
|
||||
pub mod metrics;
|
||||
pub mod router;
|
||||
pub mod session_cache;
|
||||
pub mod signals;
|
||||
pub mod state;
|
||||
pub mod streaming;
|
||||
pub mod tracing;
|
||||
pub mod utils;
|
||||
|
|
|
|||
|
|
@ -1,293 +1,590 @@
|
|||
use brightstaff::handlers::agent_chat_completions::agent_chat;
|
||||
#[cfg(feature = "jemalloc")]
|
||||
#[global_allocator]
|
||||
static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
|
||||
|
||||
use brightstaff::app_state::AppState;
|
||||
use brightstaff::handlers::agents::orchestrator::agent_chat;
|
||||
use brightstaff::handlers::debug;
|
||||
use brightstaff::handlers::empty;
|
||||
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::router::plano_orchestrator::OrchestratorService;
|
||||
use brightstaff::handlers::routing_service::routing_decision;
|
||||
use brightstaff::metrics as bs_metrics;
|
||||
use brightstaff::metrics::labels as metric_labels;
|
||||
use brightstaff::router::model_metrics::ModelMetricsService;
|
||||
use brightstaff::router::orchestrator::OrchestratorService;
|
||||
use brightstaff::session_cache::init_session_cache;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use brightstaff::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::consts::{
|
||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||
use common::configuration::{
|
||||
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
|
||||
};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::header::HeaderValue;
|
||||
use hyper::server::conn::http1;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Method, Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use opentelemetry::global;
|
||||
use opentelemetry::trace::FutureExt;
|
||||
use opentelemetry::{global, Context};
|
||||
use opentelemetry_http::HeaderExtractor;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::{env, fs};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
pub mod router;
|
||||
|
||||
const BIND_ADDRESS: &str = "0.0.0.0:9091";
|
||||
const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router";
|
||||
const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
|
||||
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
|
||||
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
|
||||
|
||||
// Utility function to extract the context from the incoming request headers
|
||||
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
propagator.extract(&HeaderExtractor(req.headers()))
|
||||
})
|
||||
/// Parse a version string like `v0.4.0`, `v0.3.0`, `0.2.0` into a `(major, minor, patch)` tuple.
|
||||
/// Missing parts default to 0. Non-numeric parts are treated as 0.
|
||||
fn parse_semver(version: &str) -> (u32, u32, u32) {
|
||||
let v = version.trim_start_matches('v');
|
||||
let mut parts = v.splitn(3, '.').map(|p| p.parse::<u32>().unwrap_or(0));
|
||||
let major = parts.next().unwrap_or(0);
|
||||
let minor = parts.next().unwrap_or(0);
|
||||
let patch = parts.next().unwrap_or(0);
|
||||
(major, minor, patch)
|
||||
}
|
||||
|
||||
fn empty() -> BoxBody<Bytes, hyper::Error> {
|
||||
Empty::<Bytes>::new()
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
/// CORS pre-flight response for the models endpoint.
|
||||
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let mut response = Response::new(empty());
|
||||
*response.status_mut() = StatusCode::NO_CONTENT;
|
||||
let h = response.headers_mut();
|
||||
h.insert("Allow", HeaderValue::from_static("GET, OPTIONS"));
|
||||
h.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
||||
h.insert(
|
||||
"Access-Control-Allow-Headers",
|
||||
HeaderValue::from_static("Authorization, Content-Type"),
|
||||
);
|
||||
h.insert(
|
||||
"Access-Control-Allow-Methods",
|
||||
HeaderValue::from_static("GET, POST, OPTIONS"),
|
||||
);
|
||||
h.insert("Content-Type", HeaderValue::from_static("application/json"));
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
|
||||
// ---------------------------------------------------------------------------
|
||||
// Configuration loading
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// loading plano_config.yaml file (before tracing init so we can read tracing config)
|
||||
let plano_config_path = env::var("PLANO_CONFIG_PATH_RENDERED")
|
||||
/// Load and parse the YAML configuration file.
|
||||
///
|
||||
/// The path is read from `PLANO_CONFIG_PATH_RENDERED` (env) or falls back to
|
||||
/// `./plano_config_rendered.yaml`.
|
||||
fn load_config() -> Result<Configuration, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let path = env::var("PLANO_CONFIG_PATH_RENDERED")
|
||||
.unwrap_or_else(|_| "./plano_config_rendered.yaml".to_string());
|
||||
eprintln!("loading plano_config.yaml from {}", plano_config_path);
|
||||
eprintln!("loading plano_config.yaml from {}", path);
|
||||
|
||||
let config_contents =
|
||||
fs::read_to_string(&plano_config_path).expect("Failed to read plano_config.yaml");
|
||||
let contents = fs::read_to_string(&path).map_err(|e| format!("failed to read {path}: {e}"))?;
|
||||
|
||||
let config: Configuration =
|
||||
serde_yaml::from_str(&config_contents).expect("Failed to parse plano_config.yaml");
|
||||
serde_yaml::from_str(&contents).map_err(|e| format!("failed to parse {path}: {e}"))?;
|
||||
|
||||
// Initialize tracing using config.yaml tracing section
|
||||
let _tracer_provider = init_tracer(config.tracing.as_ref());
|
||||
info!(path = %plano_config_path, "loaded plano_config.yaml");
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
let plano_config = Arc::new(config);
|
||||
// ---------------------------------------------------------------------------
|
||||
// Application state initialization
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// combine agents and filters into a single list of agents
|
||||
let all_agents: Vec<Agent> = plano_config
|
||||
/// Build the shared [`AppState`] from a parsed [`Configuration`].
|
||||
async fn init_app_state(
|
||||
config: &Configuration,
|
||||
) -> Result<AppState, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
// Combine agents and filters into a single list
|
||||
let all_agents: Vec<Agent> = config
|
||||
.agents
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.chain(plano_config.filters.as_deref().unwrap_or_default())
|
||||
.chain(config.filters.as_deref().unwrap_or_default())
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Create expanded provider list for /v1/models endpoint
|
||||
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
|
||||
.expect("Failed to create LlmProviders");
|
||||
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
let global_agent_map: HashMap<String, Agent> = all_agents
|
||||
.iter()
|
||||
.map(|a| (a.id.clone(), a.clone()))
|
||||
.collect();
|
||||
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let routing_model_name: String = plano_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.model.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string());
|
||||
let llm_providers = LlmProviders::try_from(config.model_providers.clone())
|
||||
.map_err(|e| format!("failed to create LlmProviders: {e}"))?;
|
||||
|
||||
let routing_llm_provider = plano_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.model_provider.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
let model_listener_count = config
|
||||
.listeners
|
||||
.iter()
|
||||
.filter(|l| l.listener_type == ListenerType::Model)
|
||||
.count();
|
||||
if model_listener_count > 1 {
|
||||
return Err(format!(
|
||||
"only one model listener is allowed, found {}",
|
||||
model_listener_count
|
||||
)
|
||||
.into());
|
||||
}
|
||||
let model_listener = config
|
||||
.listeners
|
||||
.iter()
|
||||
.find(|l| l.listener_type == ListenerType::Model);
|
||||
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
|
||||
filter_ids.map(|ids| {
|
||||
let agents = ids
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
global_agent_map
|
||||
.get(id)
|
||||
.map(|a: &Agent| (id.clone(), a.clone()))
|
||||
})
|
||||
.collect();
|
||||
ResolvedFilterChain {
|
||||
filter_ids: ids,
|
||||
agents,
|
||||
}
|
||||
})
|
||||
};
|
||||
let filter_pipeline = Arc::new(FilterPipeline {
|
||||
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
|
||||
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
|
||||
});
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
plano_config.model_providers.clone(),
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
));
|
||||
let overrides = config.overrides.clone().unwrap_or_default();
|
||||
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
|
||||
));
|
||||
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
|
||||
let session_cache = init_session_cache(config).await?;
|
||||
|
||||
let model_aliases = Arc::new(plano_config.model_aliases.clone());
|
||||
let span_attributes = Arc::new(
|
||||
plano_config
|
||||
.tracing
|
||||
.as_ref()
|
||||
.and_then(|tracing| tracing.span_attributes.clone()),
|
||||
);
|
||||
// Validate that top-level routing_preferences requires v0.4.0+.
|
||||
let config_version = parse_semver(&config.version);
|
||||
let is_v040_plus = config_version >= (0, 4, 0);
|
||||
|
||||
// Initialize trace collector and start background flusher
|
||||
// Tracing is enabled if the tracing config is present in plano_config.yaml
|
||||
// Pass Some(true/false) to override, or None to use env var OTEL_TRACING_ENABLED
|
||||
// OpenTelemetry automatic instrumentation is configured in utils/tracing.rs
|
||||
if !is_v040_plus && config.routing_preferences.is_some() {
|
||||
return Err(
|
||||
"top-level routing_preferences requires version v0.4.0 or above. \
|
||||
Update the version field or remove routing_preferences."
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Initialize conversation state storage for v1/responses
|
||||
// Configurable via plano_config.yaml state_storage section
|
||||
// If not configured, state management is disabled
|
||||
// Environment variables are substituted by envsubst before config is read
|
||||
let state_storage: Option<Arc<dyn StateStorage>> =
|
||||
if let Some(storage_config) = &plano_config.state_storage {
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
info!(
|
||||
storage_type = "memory",
|
||||
"initialized conversation state storage"
|
||||
);
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
.connection_string
|
||||
.as_ref()
|
||||
.expect("connection_string is required for postgres state_storage");
|
||||
|
||||
debug!(connection_string = %connection_string, "postgres connection");
|
||||
info!(
|
||||
storage_type = "postgres",
|
||||
"initializing conversation state storage"
|
||||
);
|
||||
Arc::new(
|
||||
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||
.await
|
||||
.expect("Failed to initialize Postgres state storage"),
|
||||
// Validate that all models referenced in top-level routing_preferences exist in model_providers.
|
||||
// The CLI renders model_providers with `name` = "openai/gpt-4o" and `model` = "gpt-4o",
|
||||
// so we accept a match against either field.
|
||||
if let Some(ref route_prefs) = config.routing_preferences {
|
||||
let provider_model_names: std::collections::HashSet<&str> = config
|
||||
.model_providers
|
||||
.iter()
|
||||
.flat_map(|p| std::iter::once(p.name.as_str()).chain(p.model.as_deref()))
|
||||
.collect();
|
||||
for pref in route_prefs {
|
||||
for model in &pref.models {
|
||||
if !provider_model_names.contains(model.as_str()) {
|
||||
return Err(format!(
|
||||
"routing_preferences route '{}' references model '{}' \
|
||||
which is not declared in model_providers",
|
||||
pref.name, model
|
||||
)
|
||||
}
|
||||
};
|
||||
Some(storage)
|
||||
} else {
|
||||
info!("no state_storage configured, conversation state management disabled");
|
||||
None
|
||||
};
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::clone(&orchestrator_service);
|
||||
let model_aliases: Arc<
|
||||
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
|
||||
> = Arc::clone(&model_aliases);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = combined_agents_filters_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let orchestrator_service = Arc::clone(&orchestrator_service);
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
||||
async move {
|
||||
let path = req.uri().path();
|
||||
// Check if path starts with /agents
|
||||
if path.starts_with("/agents") {
|
||||
// Check if it matches one of the agent API paths
|
||||
let stripped_path = path.strip_prefix("/agents").unwrap();
|
||||
if matches!(
|
||||
stripped_path,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path);
|
||||
return agent_chat(
|
||||
req,
|
||||
orchestrator_service,
|
||||
fully_qualified_url,
|
||||
agents_list,
|
||||
listeners,
|
||||
span_attributes,
|
||||
llm_providers,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
match (req.method(), path) {
|
||||
(
|
||||
&Method::POST,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
|
||||
) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, path);
|
||||
llm_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
span_attributes,
|
||||
state_storage,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, "/v1/chat/completions");
|
||||
function_calling_chat_handler(req, fully_qualified_url)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
|
||||
Ok(list_models(llm_providers).await)
|
||||
}
|
||||
// hack for now to get openw-web-ui to work
|
||||
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
|
||||
let mut response = Response::new(empty());
|
||||
*response.status_mut() = StatusCode::NO_CONTENT;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Allow", "GET, OPTIONS".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Authorization, Content-Type".parse().unwrap(),
|
||||
);
|
||||
response.headers_mut().insert(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET, POST, OPTIONS".parse().unwrap(),
|
||||
);
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
_ => {
|
||||
debug!(method = %req.method(), path = %req.uri().path(), "no route found");
|
||||
let mut not_found = Response::new(empty());
|
||||
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
||||
Ok(not_found)
|
||||
}
|
||||
.into());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
debug!(peer = ?peer_addr, "accepted connection");
|
||||
if let Err(err) = http1::Builder::new()
|
||||
// .serve_connection(io, service_fn(chat_completion))
|
||||
.serve_connection(io, service)
|
||||
.await
|
||||
{
|
||||
warn!(error = ?err, "error serving connection");
|
||||
// Validate and initialize ModelMetricsService if model_metrics_sources is configured.
|
||||
let metrics_service: Option<Arc<ModelMetricsService>> = if let Some(ref sources) =
|
||||
config.model_metrics_sources
|
||||
{
|
||||
use common::configuration::MetricsSource;
|
||||
let cost_count = sources
|
||||
.iter()
|
||||
.filter(|s| matches!(s, MetricsSource::Cost(_)))
|
||||
.count();
|
||||
let latency_count = sources
|
||||
.iter()
|
||||
.filter(|s| matches!(s, MetricsSource::Latency(_)))
|
||||
.count();
|
||||
if cost_count > 1 {
|
||||
return Err("model_metrics_sources: only one cost metrics source is allowed".into());
|
||||
}
|
||||
if latency_count > 1 {
|
||||
return Err("model_metrics_sources: only one latency metrics source is allowed".into());
|
||||
}
|
||||
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
|
||||
Some(Arc::new(svc))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Validate that selection_policy.prefer is compatible with the configured metric sources.
|
||||
if let Some(ref prefs) = config.routing_preferences {
|
||||
use common::configuration::{MetricsSource, SelectionPreference};
|
||||
|
||||
let has_cost_source = config
|
||||
.model_metrics_sources
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.any(|s| matches!(s, MetricsSource::Cost(_)));
|
||||
let has_latency_source = config
|
||||
.model_metrics_sources
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.any(|s| matches!(s, MetricsSource::Latency(_)));
|
||||
|
||||
for pref in prefs {
|
||||
if pref.selection_policy.prefer == SelectionPreference::Cheapest && !has_cost_source {
|
||||
return Err(format!(
|
||||
"routing_preferences route '{}' uses prefer: cheapest but no cost metrics source is configured — \
|
||||
add a cost metrics source to model_metrics_sources",
|
||||
pref.name
|
||||
)
|
||||
.into());
|
||||
}
|
||||
});
|
||||
if pref.selection_policy.prefer == SelectionPreference::Fastest && !has_latency_source {
|
||||
return Err(format!(
|
||||
"routing_preferences route '{}' uses prefer: fastest but no latency metrics source is configured — \
|
||||
add a latency metrics source to model_metrics_sources",
|
||||
pref.name
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Warn about models in routing_preferences that have no matching pricing/latency data.
|
||||
if let (Some(ref prefs), Some(ref svc)) = (&config.routing_preferences, &metrics_service) {
|
||||
let cost_data = svc.cost_snapshot().await;
|
||||
let latency_data = svc.latency_snapshot().await;
|
||||
for pref in prefs {
|
||||
use common::configuration::SelectionPreference;
|
||||
for model in &pref.models {
|
||||
let missing = match pref.selection_policy.prefer {
|
||||
SelectionPreference::Cheapest => !cost_data.contains_key(model.as_str()),
|
||||
SelectionPreference::Fastest => !latency_data.contains_key(model.as_str()),
|
||||
_ => false,
|
||||
};
|
||||
if missing {
|
||||
warn!(
|
||||
model = %model,
|
||||
route = %pref.name,
|
||||
"model has no metric data — will be ranked last"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let session_tenant_header = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.as_ref())
|
||||
.and_then(|c| c.tenant_header.clone());
|
||||
|
||||
// Resolve model name: prefer llm_routing_model override, then agent_orchestration_model, then default.
|
||||
let orchestrator_model_name: String = overrides
|
||||
.llm_routing_model
|
||||
.as_deref()
|
||||
.or(overrides.agent_orchestration_model.as_deref())
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let orchestrator_llm_provider: String = config
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
|
||||
|
||||
let orchestrator_max_tokens = overrides
|
||||
.orchestrator_model_context_length
|
||||
.unwrap_or(brightstaff::router::orchestrator_model_v1::MAX_TOKEN_LEN);
|
||||
|
||||
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
orchestrator_model_name,
|
||||
orchestrator_llm_provider,
|
||||
config.routing_preferences.clone(),
|
||||
metrics_service,
|
||||
session_ttl_seconds,
|
||||
session_cache,
|
||||
session_tenant_header,
|
||||
orchestrator_max_tokens,
|
||||
));
|
||||
|
||||
let state_storage = init_state_storage(config).await?;
|
||||
|
||||
let span_attributes = config
|
||||
.tracing
|
||||
.as_ref()
|
||||
.and_then(|tracing| tracing.span_attributes.clone());
|
||||
|
||||
let signals_enabled = !overrides.disable_signals.unwrap_or(false);
|
||||
|
||||
Ok(AppState {
|
||||
orchestrator_service,
|
||||
model_aliases: config.model_aliases.clone(),
|
||||
llm_providers: Arc::new(RwLock::new(llm_providers)),
|
||||
agents_list: Some(all_agents),
|
||||
listeners: config.listeners.clone(),
|
||||
state_storage,
|
||||
llm_provider_url,
|
||||
span_attributes,
|
||||
http_client: reqwest::Client::new(),
|
||||
filter_pipeline,
|
||||
signals_enabled,
|
||||
})
|
||||
}
|
||||
|
||||
/// Initialize the conversation state storage backend (if configured).
|
||||
async fn init_state_storage(
|
||||
config: &Configuration,
|
||||
) -> Result<Option<Arc<dyn StateStorage>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let Some(storage_config) = &config.state_storage else {
|
||||
info!("no state_storage configured, conversation state management disabled");
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
info!(
|
||||
storage_type = "memory",
|
||||
"initialized conversation state storage"
|
||||
);
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
.connection_string
|
||||
.as_ref()
|
||||
.ok_or("connection_string is required for postgres state_storage")?;
|
||||
|
||||
debug!(connection_string = %connection_string, "postgres connection");
|
||||
info!(
|
||||
storage_type = "postgres",
|
||||
"initializing conversation state storage"
|
||||
);
|
||||
|
||||
Arc::new(
|
||||
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||
.await
|
||||
.map_err(|e| format!("failed to initialize Postgres state storage: {e}"))?,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Some(storage))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Request routing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Normalized method label — limited set so we never emit a free-form string.
|
||||
fn method_label(method: &Method) -> &'static str {
|
||||
match *method {
|
||||
Method::GET => "GET",
|
||||
Method::POST => "POST",
|
||||
Method::PUT => "PUT",
|
||||
Method::DELETE => "DELETE",
|
||||
Method::PATCH => "PATCH",
|
||||
Method::HEAD => "HEAD",
|
||||
Method::OPTIONS => "OPTIONS",
|
||||
_ => "OTHER",
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the fixed `handler` metric label from the request's path+method.
|
||||
/// Returning `None` for fall-through means `route()` will hand the request to
|
||||
/// the catch-all 404 branch.
|
||||
fn handler_label_for(method: &Method, path: &str) -> &'static str {
|
||||
if let Some(stripped) = path.strip_prefix("/agents") {
|
||||
if matches!(
|
||||
stripped,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return metric_labels::HANDLER_AGENT_CHAT;
|
||||
}
|
||||
}
|
||||
if let Some(stripped) = path.strip_prefix("/routing") {
|
||||
if matches!(
|
||||
stripped,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return metric_labels::HANDLER_ROUTING_DECISION;
|
||||
}
|
||||
}
|
||||
match (method, path) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
metric_labels::HANDLER_LLM_CHAT
|
||||
}
|
||||
(&Method::POST, "/function_calling") => metric_labels::HANDLER_FUNCTION_CALLING,
|
||||
(&Method::GET, "/v1/models" | "/agents/v1/models") => metric_labels::HANDLER_LIST_MODELS,
|
||||
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
|
||||
metric_labels::HANDLER_CORS_PREFLIGHT
|
||||
}
|
||||
_ => metric_labels::HANDLER_NOT_FOUND,
|
||||
}
|
||||
}
|
||||
|
||||
/// Route an incoming HTTP request to the appropriate handler.
|
||||
async fn route(
|
||||
req: Request<Incoming>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let handler = handler_label_for(req.method(), req.uri().path());
|
||||
let method = method_label(req.method());
|
||||
let started = std::time::Instant::now();
|
||||
let _in_flight = bs_metrics::InFlightGuard::new(handler);
|
||||
|
||||
let result = dispatch(req, state).await;
|
||||
|
||||
let status = match &result {
|
||||
Ok(resp) => resp.status().as_u16(),
|
||||
// hyper::Error here means the body couldn't be produced; conventionally 500.
|
||||
Err(_) => 500,
|
||||
};
|
||||
bs_metrics::record_http(handler, method, status, started);
|
||||
result
|
||||
}
|
||||
|
||||
/// Inner dispatcher split out so `route()` can wrap it with metrics without
|
||||
/// duplicating the match tree.
|
||||
async fn dispatch(
|
||||
req: Request<Incoming>,
|
||||
state: Arc<AppState>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let parent_cx = global::get_text_map_propagator(|p| p.extract(&HeaderExtractor(req.headers())));
|
||||
let path = req.uri().path().to_string();
|
||||
|
||||
// --- Agent routes (/agents/...) ---
|
||||
if let Some(stripped) = path.strip_prefix("/agents") {
|
||||
if matches!(
|
||||
stripped,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return agent_chat(req, Arc::clone(&state))
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Routing decision routes (/routing/...) ---
|
||||
if let Some(stripped) = path.strip_prefix("/routing") {
|
||||
let stripped = stripped.to_string();
|
||||
if matches!(
|
||||
stripped.as_str(),
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
return routing_decision(
|
||||
req,
|
||||
Arc::clone(&state.orchestrator_service),
|
||||
stripped,
|
||||
&state.span_attributes,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Standard routes ---
|
||||
match (req.method(), path.as_str()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
llm_chat(req, Arc::clone(&state))
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let url = format!("{}/v1/chat/completions", state.llm_provider_url);
|
||||
function_calling_chat_handler(req, url)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
|
||||
Ok(list_models(Arc::clone(&state.llm_providers)).await)
|
||||
}
|
||||
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(),
|
||||
(&Method::GET, "/debug/memstats") => debug::memstats().await,
|
||||
_ => {
|
||||
debug!(method = %req.method(), path = %path, "no route found");
|
||||
let mut not_found = Response::new(empty());
|
||||
*not_found.status_mut() = StatusCode::NOT_FOUND;
|
||||
Ok(not_found)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Server loop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Accept connections and spawn a task per connection.
|
||||
///
|
||||
/// Listens for `SIGINT` / `ctrl-c` and shuts down gracefully, allowing
|
||||
/// in-flight connections to finish.
|
||||
async fn run_server(state: Arc<AppState>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
|
||||
let listener = TcpListener::bind(&bind_address).await?;
|
||||
info!(address = %bind_address, "server listening");
|
||||
|
||||
let shutdown = tokio::signal::ctrl_c();
|
||||
tokio::pin!(shutdown);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = listener.accept() => {
|
||||
let (stream, _) = result?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
let state = Arc::clone(&state);
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
debug!(peer = ?peer_addr, "accepted connection");
|
||||
|
||||
let service = service_fn(move |req| {
|
||||
let state = Arc::clone(&state);
|
||||
async move { route(req, state).await }
|
||||
});
|
||||
|
||||
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
|
||||
warn!(error = ?err, "error serving connection");
|
||||
}
|
||||
});
|
||||
}
|
||||
_ = &mut shutdown => {
|
||||
info!("received shutdown signal, stopping server");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Entry point
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let config = load_config()?;
|
||||
let _tracer_provider = init_tracer(config.tracing.as_ref());
|
||||
bs_metrics::init();
|
||||
info!("loaded plano_config.yaml");
|
||||
let state = Arc::new(init_app_state(&config).await?);
|
||||
run_server(state).await
|
||||
}
|
||||
|
|
|
|||
38
crates/brightstaff/src/metrics/labels.rs
Normal file
38
crates/brightstaff/src/metrics/labels.rs
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
//! Fixed label-value constants so callers never emit free-form strings
|
||||
//! (which would blow up cardinality).
|
||||
|
||||
// Handler enum — derived from the path+method match in `route()`.
|
||||
pub const HANDLER_AGENT_CHAT: &str = "agent_chat";
|
||||
pub const HANDLER_ROUTING_DECISION: &str = "routing_decision";
|
||||
pub const HANDLER_LLM_CHAT: &str = "llm_chat";
|
||||
pub const HANDLER_FUNCTION_CALLING: &str = "function_calling";
|
||||
pub const HANDLER_LIST_MODELS: &str = "list_models";
|
||||
pub const HANDLER_CORS_PREFLIGHT: &str = "cors_preflight";
|
||||
pub const HANDLER_NOT_FOUND: &str = "not_found";
|
||||
|
||||
// Router "route" class — which brightstaff endpoint prompted the decision.
|
||||
pub const ROUTE_AGENT: &str = "agent";
|
||||
pub const ROUTE_ROUTING: &str = "routing";
|
||||
pub const ROUTE_LLM: &str = "llm";
|
||||
|
||||
// Token kind for brightstaff_llm_tokens_total.
|
||||
pub const TOKEN_KIND_PROMPT: &str = "prompt";
|
||||
pub const TOKEN_KIND_COMPLETION: &str = "completion";
|
||||
|
||||
// LLM error_class values (match docstring in metrics/mod.rs).
|
||||
pub const LLM_ERR_NONE: &str = "none";
|
||||
pub const LLM_ERR_TIMEOUT: &str = "timeout";
|
||||
pub const LLM_ERR_CONNECT: &str = "connect";
|
||||
pub const LLM_ERR_PARSE: &str = "parse";
|
||||
pub const LLM_ERR_OTHER: &str = "other";
|
||||
pub const LLM_ERR_STREAM: &str = "stream";
|
||||
|
||||
// Routing service outcome values.
|
||||
pub const ROUTING_SVC_DECISION_SERVED: &str = "decision_served";
|
||||
pub const ROUTING_SVC_NO_CANDIDATES: &str = "no_candidates";
|
||||
pub const ROUTING_SVC_POLICY_ERROR: &str = "policy_error";
|
||||
|
||||
// Session cache outcome values.
|
||||
pub const SESSION_CACHE_HIT: &str = "hit";
|
||||
pub const SESSION_CACHE_MISS: &str = "miss";
|
||||
pub const SESSION_CACHE_STORE: &str = "store";
|
||||
377
crates/brightstaff/src/metrics/mod.rs
Normal file
377
crates/brightstaff/src/metrics/mod.rs
Normal file
|
|
@ -0,0 +1,377 @@
|
|||
//! Prometheus metrics for brightstaff.
|
||||
//!
|
||||
//! Installs the `metrics` global recorder backed by
|
||||
//! `metrics-exporter-prometheus` and exposes a `/metrics` HTTP endpoint on a
|
||||
//! dedicated admin port (default `0.0.0.0:9092`, overridable via
|
||||
//! `METRICS_BIND_ADDRESS`).
|
||||
//!
|
||||
//! Emitted metric families (see `describe_all` for full list):
|
||||
//! - HTTP RED: `brightstaff_http_requests_total`,
|
||||
//! `brightstaff_http_request_duration_seconds`,
|
||||
//! `brightstaff_http_in_flight_requests`.
|
||||
//! - LLM upstream: `brightstaff_llm_upstream_requests_total`,
|
||||
//! `brightstaff_llm_upstream_duration_seconds`,
|
||||
//! `brightstaff_llm_time_to_first_token_seconds`,
|
||||
//! `brightstaff_llm_tokens_total`,
|
||||
//! `brightstaff_llm_tokens_usage_missing_total`.
|
||||
//! - Routing: `brightstaff_router_decisions_total`,
|
||||
//! `brightstaff_router_decision_duration_seconds`,
|
||||
//! `brightstaff_routing_service_requests_total`,
|
||||
//! `brightstaff_session_cache_events_total`.
|
||||
//! - Process: via `metrics-process`.
|
||||
//! - Build: `brightstaff_build_info`.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub mod labels;
|
||||
|
||||
/// Guard flag so tests don't re-install the global recorder.
|
||||
static INIT: OnceLock<()> = OnceLock::new();
|
||||
|
||||
const DEFAULT_METRICS_BIND: &str = "0.0.0.0:9092";
|
||||
|
||||
/// HTTP request duration buckets (seconds). Capped at 60s.
|
||||
const HTTP_BUCKETS: &[f64] = &[
|
||||
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0,
|
||||
];
|
||||
|
||||
/// LLM upstream / TTFT buckets (seconds). Capped at 120s because provider
|
||||
/// completions routinely run that long.
|
||||
const LLM_BUCKETS: &[f64] = &[0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, 120.0];
|
||||
|
||||
/// Router decision buckets (seconds). The orchestrator call itself is usually
|
||||
/// sub-second but bucketed generously in case of upstream slowness.
|
||||
const ROUTER_BUCKETS: &[f64] = &[
|
||||
0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0,
|
||||
];
|
||||
|
||||
/// Install the global recorder and spawn the `/metrics` HTTP listener.
|
||||
///
|
||||
/// Safe to call more than once; subsequent calls are no-ops so tests that
|
||||
/// construct their own recorder still work.
|
||||
pub fn init() {
|
||||
if INIT.get().is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let bind: SocketAddr = std::env::var("METRICS_BIND_ADDRESS")
|
||||
.unwrap_or_else(|_| DEFAULT_METRICS_BIND.to_string())
|
||||
.parse()
|
||||
.unwrap_or_else(|err| {
|
||||
warn!(error = %err, default = DEFAULT_METRICS_BIND, "invalid METRICS_BIND_ADDRESS, falling back to default");
|
||||
DEFAULT_METRICS_BIND.parse().expect("default bind parses")
|
||||
});
|
||||
|
||||
let builder = PrometheusBuilder::new()
|
||||
.with_http_listener(bind)
|
||||
.set_buckets_for_metric(
|
||||
Matcher::Full("brightstaff_http_request_duration_seconds".to_string()),
|
||||
HTTP_BUCKETS,
|
||||
)
|
||||
.and_then(|b| {
|
||||
b.set_buckets_for_metric(Matcher::Prefix("brightstaff_llm_".to_string()), LLM_BUCKETS)
|
||||
})
|
||||
.and_then(|b| {
|
||||
b.set_buckets_for_metric(
|
||||
Matcher::Full("brightstaff_router_decision_duration_seconds".to_string()),
|
||||
ROUTER_BUCKETS,
|
||||
)
|
||||
});
|
||||
|
||||
let builder = match builder {
|
||||
Ok(b) => b,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to configure metrics buckets, using defaults");
|
||||
PrometheusBuilder::new().with_http_listener(bind)
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = builder.install() {
|
||||
warn!(error = %err, "failed to install Prometheus recorder; metrics disabled");
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = INIT.set(());
|
||||
|
||||
describe_all();
|
||||
emit_build_info();
|
||||
|
||||
// Register process-level collector (RSS, CPU, FDs).
|
||||
let collector = metrics_process::Collector::default();
|
||||
collector.describe();
|
||||
// Prime once at startup; subsequent scrapes refresh via the exporter's
|
||||
// per-scrape render, so we additionally refresh on a short interval to
|
||||
// keep gauges moving between scrapes without requiring client pull.
|
||||
collector.collect();
|
||||
tokio::spawn(async move {
|
||||
let mut tick = tokio::time::interval(Duration::from_secs(10));
|
||||
tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tick.tick().await;
|
||||
collector.collect();
|
||||
}
|
||||
});
|
||||
|
||||
info!(address = %bind, "metrics listener started");
|
||||
}
|
||||
|
||||
fn describe_all() {
|
||||
describe_counter!(
|
||||
"brightstaff_http_requests_total",
|
||||
"Total HTTP requests served by brightstaff, by handler and status class."
|
||||
);
|
||||
describe_histogram!(
|
||||
"brightstaff_http_request_duration_seconds",
|
||||
"Wall-clock duration of HTTP requests served by brightstaff, by handler."
|
||||
);
|
||||
describe_gauge!(
|
||||
"brightstaff_http_in_flight_requests",
|
||||
"Number of HTTP requests currently being served by brightstaff, by handler."
|
||||
);
|
||||
|
||||
describe_counter!(
|
||||
"brightstaff_llm_upstream_requests_total",
|
||||
"LLM upstream request outcomes, by provider, model, status class and error class."
|
||||
);
|
||||
describe_histogram!(
|
||||
"brightstaff_llm_upstream_duration_seconds",
|
||||
"Wall-clock duration of LLM upstream calls (stream close for streaming), by provider and model."
|
||||
);
|
||||
describe_histogram!(
|
||||
"brightstaff_llm_time_to_first_token_seconds",
|
||||
"Time from request start to first streamed byte, by provider and model (streaming only)."
|
||||
);
|
||||
describe_counter!(
|
||||
"brightstaff_llm_tokens_total",
|
||||
"Tokens reported in the provider `usage` field, by provider, model and kind (prompt/completion)."
|
||||
);
|
||||
describe_counter!(
|
||||
"brightstaff_llm_tokens_usage_missing_total",
|
||||
"LLM responses that completed without a usable `usage` block (so token counts are unknown)."
|
||||
);
|
||||
|
||||
describe_counter!(
|
||||
"brightstaff_router_decisions_total",
|
||||
"Routing decisions made by the orchestrator, by route, selected model, and whether a fallback was used."
|
||||
);
|
||||
describe_histogram!(
|
||||
"brightstaff_router_decision_duration_seconds",
|
||||
"Time spent in the orchestrator deciding a route, by route."
|
||||
);
|
||||
describe_counter!(
|
||||
"brightstaff_routing_service_requests_total",
|
||||
"Outcomes of /routing/* decision requests: decision_served, no_candidates, policy_error."
|
||||
);
|
||||
describe_counter!(
|
||||
"brightstaff_session_cache_events_total",
|
||||
"Session affinity cache lookups and stores, by outcome."
|
||||
);
|
||||
|
||||
describe_gauge!(
|
||||
"brightstaff_build_info",
|
||||
"Build metadata. Always 1; labels carry version and git SHA."
|
||||
);
|
||||
}
|
||||
|
||||
fn emit_build_info() {
|
||||
let version = env!("CARGO_PKG_VERSION");
|
||||
let git_sha = option_env!("GIT_SHA").unwrap_or("unknown");
|
||||
gauge!(
|
||||
"brightstaff_build_info",
|
||||
"version" => version.to_string(),
|
||||
"git_sha" => git_sha.to_string(),
|
||||
)
|
||||
.set(1.0);
|
||||
}
|
||||
|
||||
/// Split a provider-qualified model id like `"openai/gpt-4o"` into
|
||||
/// `(provider, model)`. Returns `("unknown", raw)` when there is no `/`.
|
||||
pub fn split_provider_model(full: &str) -> (&str, &str) {
|
||||
match full.split_once('/') {
|
||||
Some((p, m)) => (p, m),
|
||||
None => ("unknown", full),
|
||||
}
|
||||
}
|
||||
|
||||
/// Bucket an HTTP status code into `"2xx"` / `"4xx"` / `"5xx"` / `"1xx"` / `"3xx"`.
|
||||
pub fn status_class(status: u16) -> &'static str {
|
||||
match status {
|
||||
100..=199 => "1xx",
|
||||
200..=299 => "2xx",
|
||||
300..=399 => "3xx",
|
||||
400..=499 => "4xx",
|
||||
500..=599 => "5xx",
|
||||
_ => "other",
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HTTP RED helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// RAII guard that increments the in-flight gauge on construction and
|
||||
/// decrements on drop. Pair with [`HttpTimer`] in the `route()` wrapper so the
|
||||
/// gauge drops even on error paths.
|
||||
pub struct InFlightGuard {
|
||||
handler: &'static str,
|
||||
}
|
||||
|
||||
impl InFlightGuard {
|
||||
pub fn new(handler: &'static str) -> Self {
|
||||
gauge!(
|
||||
"brightstaff_http_in_flight_requests",
|
||||
"handler" => handler,
|
||||
)
|
||||
.increment(1.0);
|
||||
Self { handler }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for InFlightGuard {
|
||||
fn drop(&mut self) {
|
||||
gauge!(
|
||||
"brightstaff_http_in_flight_requests",
|
||||
"handler" => self.handler,
|
||||
)
|
||||
.decrement(1.0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Record the HTTP request counter + duration histogram.
|
||||
pub fn record_http(handler: &'static str, method: &'static str, status: u16, started: Instant) {
|
||||
let class = status_class(status);
|
||||
counter!(
|
||||
"brightstaff_http_requests_total",
|
||||
"handler" => handler,
|
||||
"method" => method,
|
||||
"status_class" => class,
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"brightstaff_http_request_duration_seconds",
|
||||
"handler" => handler,
|
||||
)
|
||||
.record(started.elapsed().as_secs_f64());
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// LLM upstream helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Classify an outcome of an LLM upstream call for the `error_class` label.
|
||||
pub fn llm_error_class_from_reqwest(err: &reqwest::Error) -> &'static str {
|
||||
if err.is_timeout() {
|
||||
"timeout"
|
||||
} else if err.is_connect() {
|
||||
"connect"
|
||||
} else if err.is_decode() {
|
||||
"parse"
|
||||
} else {
|
||||
"other"
|
||||
}
|
||||
}
|
||||
|
||||
/// Record the outcome of an LLM upstream call. `status` is the HTTP status
|
||||
/// the upstream returned (0 if the call never produced one, e.g. send failure).
|
||||
/// `error_class` is `"none"` on success, or a discriminated error label.
|
||||
pub fn record_llm_upstream(
|
||||
provider: &str,
|
||||
model: &str,
|
||||
status: u16,
|
||||
error_class: &str,
|
||||
duration: Duration,
|
||||
) {
|
||||
let class = if status == 0 {
|
||||
"error"
|
||||
} else {
|
||||
status_class(status)
|
||||
};
|
||||
counter!(
|
||||
"brightstaff_llm_upstream_requests_total",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
"status_class" => class,
|
||||
"error_class" => error_class.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"brightstaff_llm_upstream_duration_seconds",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
)
|
||||
.record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
pub fn record_llm_ttft(provider: &str, model: &str, ttft: Duration) {
|
||||
histogram!(
|
||||
"brightstaff_llm_time_to_first_token_seconds",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
)
|
||||
.record(ttft.as_secs_f64());
|
||||
}
|
||||
|
||||
pub fn record_llm_tokens(provider: &str, model: &str, kind: &'static str, count: u64) {
|
||||
counter!(
|
||||
"brightstaff_llm_tokens_total",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
"kind" => kind,
|
||||
)
|
||||
.increment(count);
|
||||
}
|
||||
|
||||
pub fn record_llm_tokens_usage_missing(provider: &str, model: &str) {
|
||||
counter!(
|
||||
"brightstaff_llm_tokens_usage_missing_total",
|
||||
"provider" => provider.to_string(),
|
||||
"model" => model.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Router helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub fn record_router_decision(
|
||||
route: &'static str,
|
||||
selected_model: &str,
|
||||
fallback: bool,
|
||||
duration: Duration,
|
||||
) {
|
||||
counter!(
|
||||
"brightstaff_router_decisions_total",
|
||||
"route" => route,
|
||||
"selected_model" => selected_model.to_string(),
|
||||
"fallback" => if fallback { "true" } else { "false" },
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"brightstaff_router_decision_duration_seconds",
|
||||
"route" => route,
|
||||
)
|
||||
.record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
pub fn record_routing_service_outcome(outcome: &'static str) {
|
||||
counter!(
|
||||
"brightstaff_routing_service_requests_total",
|
||||
"outcome" => outcome,
|
||||
)
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
pub fn record_session_cache_event(outcome: &'static str) {
|
||||
counter!(
|
||||
"brightstaff_session_cache_events_total",
|
||||
"outcome" => outcome,
|
||||
)
|
||||
.increment(1);
|
||||
}
|
||||
173
crates/brightstaff/src/router/http.rs
Normal file
173
crates/brightstaff/src/router/http.rs
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
use hermesllm::apis::openai::ChatCompletionsResponse;
|
||||
use hyper::header;
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
|
||||
/// Max bytes of raw upstream body we include in a log message or error text
|
||||
/// when the body is not a recognizable error envelope. Keeps logs from being
|
||||
/// flooded by huge HTML error pages.
|
||||
const RAW_BODY_LOG_LIMIT: usize = 512;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HttpError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
Request(#[from] reqwest::Error),
|
||||
|
||||
#[error("Failed to parse JSON response: {0}")]
|
||||
Json(serde_json::Error, String),
|
||||
|
||||
#[error("Upstream returned {status}: {message}")]
|
||||
Upstream { status: u16, message: String },
|
||||
}
|
||||
|
||||
/// Shape of an OpenAI-style error response body, e.g.
|
||||
/// `{"error": {"message": "...", "type": "...", "param": "...", "code": ...}}`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpstreamErrorEnvelope {
|
||||
error: UpstreamErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpstreamErrorBody {
|
||||
message: String,
|
||||
#[serde(default, rename = "type")]
|
||||
err_type: Option<String>,
|
||||
#[serde(default)]
|
||||
param: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract a human-readable error message from an upstream response body.
|
||||
/// Tries to parse an OpenAI-style `{"error": {"message": ...}}` envelope; if
|
||||
/// that fails, falls back to the first `RAW_BODY_LOG_LIMIT` bytes of the raw
|
||||
/// body (UTF-8 safe).
|
||||
fn extract_upstream_error_message(body: &str) -> String {
|
||||
if let Ok(env) = serde_json::from_str::<UpstreamErrorEnvelope>(body) {
|
||||
let mut msg = env.error.message;
|
||||
if let Some(param) = env.error.param {
|
||||
msg.push_str(&format!(" (param={param})"));
|
||||
}
|
||||
if let Some(err_type) = env.error.err_type {
|
||||
msg.push_str(&format!(" [type={err_type}]"));
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
truncate_for_log(body).to_string()
|
||||
}
|
||||
|
||||
fn truncate_for_log(s: &str) -> &str {
|
||||
if s.len() <= RAW_BODY_LOG_LIMIT {
|
||||
return s;
|
||||
}
|
||||
let mut end = RAW_BODY_LOG_LIMIT;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&s[..end]
|
||||
}
|
||||
|
||||
/// Sends a POST request to the given URL and extracts the text content
|
||||
/// from the first choice of the `ChatCompletionsResponse`.
|
||||
///
|
||||
/// Returns `Some((content, elapsed))` on success, `None` if the response
|
||||
/// had no choices or the first choice had no content. Returns
|
||||
/// `HttpError::Upstream` for any non-2xx status, carrying a message
|
||||
/// extracted from the OpenAI-style error envelope (or a truncated raw body
|
||||
/// if the body is not in that shape).
|
||||
pub async fn post_and_extract_content(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
headers: header::HeaderMap,
|
||||
body: String,
|
||||
) -> Result<Option<(String, std::time::Duration)>, HttpError> {
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
let res = client.post(url).headers(headers).body(body).send().await?;
|
||||
let status = res.status();
|
||||
|
||||
let body = res.text().await?;
|
||||
let elapsed = start_time.elapsed();
|
||||
|
||||
if !status.is_success() {
|
||||
let message = extract_upstream_error_message(&body);
|
||||
warn!(
|
||||
status = status.as_u16(),
|
||||
message = %message,
|
||||
body_size = body.len(),
|
||||
"upstream returned error response"
|
||||
);
|
||||
return Err(HttpError::Upstream {
|
||||
status: status.as_u16(),
|
||||
message,
|
||||
});
|
||||
}
|
||||
|
||||
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
|
||||
warn!(
|
||||
error = %err,
|
||||
body = %truncate_for_log(&body),
|
||||
"failed to parse json response",
|
||||
);
|
||||
HttpError::Json(err, format!("Failed to parse JSON: {}", body))
|
||||
})?;
|
||||
|
||||
if response.choices.is_empty() {
|
||||
warn!(body = %truncate_for_log(&body), "no choices in response");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|c| (c.clone(), elapsed)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn extracts_message_from_openai_style_error_envelope() {
|
||||
let body = r#"{"error":{"code":400,"message":"This model's maximum context length is 32768 tokens. However, you requested 0 output tokens and your prompt contains at least 32769 input tokens, for a total of at least 32769 tokens.","param":"input_tokens","type":"BadRequestError"}}"#;
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert!(
|
||||
msg.starts_with("This model's maximum context length is 32768 tokens."),
|
||||
"unexpected message: {msg}"
|
||||
);
|
||||
assert!(msg.contains("(param=input_tokens)"));
|
||||
assert!(msg.contains("[type=BadRequestError]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_message_without_optional_fields() {
|
||||
let body = r#"{"error":{"message":"something broke"}}"#;
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert_eq!(msg, "something broke");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_raw_body_when_not_error_envelope() {
|
||||
let body = "<html><body>502 Bad Gateway</body></html>";
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert_eq!(msg, body);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_non_envelope_bodies_in_logs() {
|
||||
let body = "x".repeat(RAW_BODY_LOG_LIMIT * 3);
|
||||
let msg = extract_upstream_error_message(&body);
|
||||
assert_eq!(msg.len(), RAW_BODY_LOG_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_for_log_respects_utf8_boundaries() {
|
||||
// 2-byte characters; picking a length that would split mid-char.
|
||||
let body = "é".repeat(RAW_BODY_LOG_LIMIT);
|
||||
let out = truncate_for_log(&body);
|
||||
// Should be a valid &str (implicit — would panic if we returned
|
||||
// a non-boundary slice) and at most RAW_BODY_LOG_LIMIT bytes.
|
||||
assert!(out.len() <= RAW_BODY_LOG_LIMIT);
|
||||
assert!(out.chars().all(|c| c == 'é'));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,187 +0,0 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use common::{
|
||||
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::router_model_v1::{self};
|
||||
|
||||
use super::router_model::RouterModel;
|
||||
|
||||
pub struct RouterService {
|
||||
router_url: String,
|
||||
client: reqwest::Client,
|
||||
router_model: Arc<dyn RouterModel>,
|
||||
#[allow(dead_code)]
|
||||
routing_provider_name: String,
|
||||
llm_usage_defined: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
RequestError(#[from] reqwest::Error),
|
||||
|
||||
#[error("Failed to parse JSON: {0}, JSON: {1}")]
|
||||
JsonError(serde_json::Error, String),
|
||||
|
||||
#[error("Router model error: {0}")]
|
||||
RouterModelError(#[from] super::router_model::RoutingModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingError>;
|
||||
|
||||
impl RouterService {
|
||||
pub fn new(
|
||||
providers: Vec<LlmProvider>,
|
||||
router_url: String,
|
||||
routing_model_name: String,
|
||||
routing_provider_name: String,
|
||||
) -> Self {
|
||||
let providers_with_usage = providers
|
||||
.iter()
|
||||
.filter(|provider| provider.routing_preferences.is_some())
|
||||
.cloned()
|
||||
.collect::<Vec<LlmProvider>>();
|
||||
|
||||
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
|
||||
.iter()
|
||||
.filter_map(|provider| {
|
||||
provider
|
||||
.routing_preferences
|
||||
.as_ref()
|
||||
.map(|prefs| (provider.name.clone(), prefs.clone()))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
|
||||
llm_routes,
|
||||
routing_model_name,
|
||||
router_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
RouterService {
|
||||
router_url,
|
||||
client: reqwest::Client::new(),
|
||||
router_model,
|
||||
routing_provider_name,
|
||||
llm_usage_defined: !providers_with_usage.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_route(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
traceparent: &str,
|
||||
usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
request_id: &str,
|
||||
) -> Result<Option<(String, String)>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2)
|
||||
&& !self.llm_usage_defined
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let router_request = self
|
||||
.router_model
|
||||
.generate_request(messages, &usage_preferences);
|
||||
|
||||
debug!(
|
||||
model = %self.router_model.get_model_name(),
|
||||
endpoint = %self.router_url,
|
||||
"sending request to arch-router"
|
||||
);
|
||||
|
||||
debug!(
|
||||
body = %serde_json::to_string(&router_request).unwrap(),
|
||||
"arch router request"
|
||||
);
|
||||
|
||||
let mut llm_route_request_headers = header::HeaderMap::new();
|
||||
llm_route_request_headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
header::HeaderValue::from_str(&self.routing_provider_name).unwrap(),
|
||||
);
|
||||
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static(TRACE_PARENT_HEADER),
|
||||
header::HeaderValue::from_str(traceparent).unwrap(),
|
||||
);
|
||||
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static(REQUEST_ID_HEADER),
|
||||
header::HeaderValue::from_str(request_id).unwrap(),
|
||||
);
|
||||
|
||||
llm_route_request_headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_static("arch-router"),
|
||||
);
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let res = self
|
||||
.client
|
||||
.post(&self.router_url)
|
||||
.headers(llm_route_request_headers)
|
||||
.body(serde_json::to_string(&router_request).unwrap())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body = res.text().await?;
|
||||
let router_response_time = start_time.elapsed();
|
||||
|
||||
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
body = %serde_json::to_string(&body).unwrap(),
|
||||
"failed to parse json response"
|
||||
);
|
||||
return Err(RoutingError::JsonError(
|
||||
err,
|
||||
format!("Failed to parse JSON: {}", body),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if chat_completion_response.choices.is_empty() {
|
||||
warn!(body = %body, "no choices in router response");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(content) = &chat_completion_response.choices[0].message.content {
|
||||
let parsed_response = self
|
||||
.router_model
|
||||
.parse_response(content, &usage_preferences)?;
|
||||
info!(
|
||||
content = %content.replace("\n", "\\n"),
|
||||
selected_model = ?parsed_response,
|
||||
response_time_ms = router_response_time.as_millis(),
|
||||
"arch-router determined route"
|
||||
);
|
||||
|
||||
if let Some(ref parsed_response) = parsed_response {
|
||||
return Ok(Some(parsed_response.clone()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod llm_router;
|
||||
pub(crate) mod http;
|
||||
pub mod model_metrics;
|
||||
pub mod orchestrator;
|
||||
pub mod orchestrator_model;
|
||||
pub mod orchestrator_model_v1;
|
||||
pub mod plano_orchestrator;
|
||||
pub mod router_model;
|
||||
pub mod router_model_v1;
|
||||
#[cfg(test)]
|
||||
mod stress_tests;
|
||||
|
|
|
|||
388
crates/brightstaff/src/router/model_metrics.rs
Normal file
388
crates/brightstaff/src/router/model_metrics.rs
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use common::configuration::{
|
||||
CostProvider, LatencyProvider, MetricsSource, SelectionPolicy, SelectionPreference,
|
||||
};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
const DO_PRICING_URL: &str = "https://api.digitalocean.com/v2/gen-ai/models/catalog";
|
||||
|
||||
pub struct ModelMetricsService {
|
||||
cost: Arc<RwLock<HashMap<String, f64>>>,
|
||||
latency: Arc<RwLock<HashMap<String, f64>>>,
|
||||
}
|
||||
|
||||
impl ModelMetricsService {
|
||||
pub async fn new(sources: &[MetricsSource], client: reqwest::Client) -> Self {
|
||||
let cost_data = Arc::new(RwLock::new(HashMap::new()));
|
||||
let latency_data = Arc::new(RwLock::new(HashMap::new()));
|
||||
|
||||
for source in sources {
|
||||
match source {
|
||||
MetricsSource::Cost(cfg) => match cfg.provider {
|
||||
CostProvider::Digitalocean => {
|
||||
let aliases = cfg.model_aliases.clone().unwrap_or_default();
|
||||
let data = fetch_do_pricing(&client, &aliases).await;
|
||||
info!(models = data.len(), "fetched digitalocean pricing");
|
||||
*cost_data.write().await = data;
|
||||
|
||||
if let Some(interval_secs) = cfg.refresh_interval {
|
||||
let cost_clone = Arc::clone(&cost_data);
|
||||
let client_clone = client.clone();
|
||||
let interval = Duration::from_secs(interval_secs);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
let data = fetch_do_pricing(&client_clone, &aliases).await;
|
||||
info!(models = data.len(), "refreshed digitalocean pricing");
|
||||
*cost_clone.write().await = data;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
MetricsSource::Latency(cfg) => match cfg.provider {
|
||||
LatencyProvider::Prometheus => {
|
||||
let data = fetch_prometheus_metrics(&cfg.url, &cfg.query, &client).await;
|
||||
info!(models = data.len(), url = %cfg.url, "fetched latency metrics");
|
||||
*latency_data.write().await = data;
|
||||
|
||||
if let Some(interval_secs) = cfg.refresh_interval {
|
||||
let latency_clone = Arc::clone(&latency_data);
|
||||
let client_clone = client.clone();
|
||||
let url = cfg.url.clone();
|
||||
let query = cfg.query.clone();
|
||||
let interval = Duration::from_secs(interval_secs);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
let data =
|
||||
fetch_prometheus_metrics(&url, &query, &client_clone).await;
|
||||
info!(models = data.len(), url = %url, "refreshed latency metrics");
|
||||
*latency_clone.write().await = data;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
ModelMetricsService {
|
||||
cost: cost_data,
|
||||
latency: latency_data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Rank `models` by `policy`, returning them in preference order.
|
||||
/// Models with no metric data are appended at the end in their original order.
|
||||
pub async fn rank_models(&self, models: &[String], policy: &SelectionPolicy) -> Vec<String> {
|
||||
let cost_data = self.cost.read().await;
|
||||
let latency_data = self.latency.read().await;
|
||||
debug!(
|
||||
input_models = ?models,
|
||||
cost_data = ?cost_data.iter().collect::<Vec<_>>(),
|
||||
latency_data = ?latency_data.iter().collect::<Vec<_>>(),
|
||||
prefer = ?policy.prefer,
|
||||
"rank_models called"
|
||||
);
|
||||
|
||||
match policy.prefer {
|
||||
SelectionPreference::Cheapest => {
|
||||
for m in models {
|
||||
if !cost_data.contains_key(m.as_str()) {
|
||||
warn!(model = %m, "no cost data for model — ranking last (prefer: cheapest)");
|
||||
}
|
||||
}
|
||||
rank_by_ascending_metric(models, &cost_data)
|
||||
}
|
||||
SelectionPreference::Fastest => {
|
||||
for m in models {
|
||||
if !latency_data.contains_key(m.as_str()) {
|
||||
warn!(model = %m, "no latency data for model — ranking last (prefer: fastest)");
|
||||
}
|
||||
}
|
||||
rank_by_ascending_metric(models, &latency_data)
|
||||
}
|
||||
SelectionPreference::None => models.to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a snapshot of the current cost data. Used at startup to warn about unmatched models.
|
||||
pub async fn cost_snapshot(&self) -> HashMap<String, f64> {
|
||||
self.cost.read().await.clone()
|
||||
}
|
||||
|
||||
/// Returns a snapshot of the current latency data. Used at startup to warn about unmatched models.
|
||||
pub async fn latency_snapshot(&self) -> HashMap<String, f64> {
|
||||
self.latency.read().await.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn rank_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> Vec<String> {
|
||||
let mut with_data: Vec<(&String, f64)> = models
|
||||
.iter()
|
||||
.filter_map(|m| {
|
||||
let v = *data.get(m.as_str())?;
|
||||
if v.is_nan() {
|
||||
None
|
||||
} else {
|
||||
Some((m, v))
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
with_data.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
let without_data: Vec<&String> = models
|
||||
.iter()
|
||||
.filter(|m| data.get(m.as_str()).is_none_or(|v| v.is_nan()))
|
||||
.collect();
|
||||
|
||||
with_data
|
||||
.iter()
|
||||
.map(|(m, _)| (*m).clone())
|
||||
.chain(without_data.iter().map(|m| (*m).clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct DoModelList {
|
||||
data: Vec<DoModel>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct DoModel {
|
||||
model_id: String,
|
||||
pricing: Option<DoPricing>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct DoPricing {
|
||||
input_price_per_million: Option<f64>,
|
||||
output_price_per_million: Option<f64>,
|
||||
}
|
||||
|
||||
async fn fetch_do_pricing(
|
||||
client: &reqwest::Client,
|
||||
aliases: &HashMap<String, String>,
|
||||
) -> HashMap<String, f64> {
|
||||
match client.get(DO_PRICING_URL).send().await {
|
||||
Ok(resp) => match resp.json::<DoModelList>().await {
|
||||
Ok(list) => list
|
||||
.data
|
||||
.into_iter()
|
||||
.filter_map(|m| {
|
||||
let pricing = m.pricing?;
|
||||
let raw_key = m.model_id.clone();
|
||||
let key = aliases.get(&raw_key).cloned().unwrap_or(raw_key);
|
||||
let cost = pricing.input_price_per_million.unwrap_or(0.0)
|
||||
+ pricing.output_price_per_million.unwrap_or(0.0);
|
||||
Some((key, cost))
|
||||
})
|
||||
.collect(),
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = DO_PRICING_URL, "failed to parse digitalocean pricing response");
|
||||
HashMap::new()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = DO_PRICING_URL, "failed to fetch digitalocean pricing");
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct PrometheusResponse {
|
||||
data: PrometheusData,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct PrometheusData {
|
||||
result: Vec<PrometheusResult>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct PrometheusResult {
|
||||
metric: HashMap<String, String>,
|
||||
value: (f64, String), // (timestamp, value_str)
|
||||
}
|
||||
|
||||
async fn fetch_prometheus_metrics(
|
||||
url: &str,
|
||||
query: &str,
|
||||
client: &reqwest::Client,
|
||||
) -> HashMap<String, f64> {
|
||||
let query_url = format!("{}/api/v1/query", url.trim_end_matches('/'));
|
||||
match client
|
||||
.get(&query_url)
|
||||
.query(&[("query", query)])
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => match resp.json::<PrometheusResponse>().await {
|
||||
Ok(prom) => prom
|
||||
.data
|
||||
.result
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
let model_name = r.metric.get("model_name")?.clone();
|
||||
let value: f64 = r.value.1.parse().ok()?;
|
||||
Some((model_name, value))
|
||||
})
|
||||
.collect(),
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = %query_url, "failed to parse prometheus response");
|
||||
HashMap::new()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = %query_url, "failed to fetch prometheus metrics");
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::SelectionPreference;
|
||||
|
||||
fn make_policy(prefer: SelectionPreference) -> SelectionPolicy {
|
||||
SelectionPolicy { prefer }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rank_by_ascending_metric_picks_lowest_first() {
|
||||
let models = vec!["a".to_string(), "b".to_string(), "c".to_string()];
|
||||
let mut data = HashMap::new();
|
||||
data.insert("a".to_string(), 0.01);
|
||||
data.insert("b".to_string(), 0.005);
|
||||
data.insert("c".to_string(), 0.02);
|
||||
assert_eq!(
|
||||
rank_by_ascending_metric(&models, &data),
|
||||
vec!["b", "a", "c"]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rank_by_ascending_metric_no_data_preserves_order() {
|
||||
let models = vec!["x".to_string(), "y".to_string()];
|
||||
let data = HashMap::new();
|
||||
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["x", "y"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rank_by_ascending_metric_partial_data() {
|
||||
let models = vec!["a".to_string(), "b".to_string()];
|
||||
let mut data = HashMap::new();
|
||||
data.insert("b".to_string(), 100.0);
|
||||
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["b", "a"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rank_models_cheapest() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new({
|
||||
let mut m = HashMap::new();
|
||||
m.insert("gpt-4o".to_string(), 0.005);
|
||||
m.insert("gpt-4o-mini".to_string(), 0.0001);
|
||||
m
|
||||
})),
|
||||
latency: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
|
||||
let result = service
|
||||
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
|
||||
.await;
|
||||
assert_eq!(result, vec!["gpt-4o-mini", "gpt-4o"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rank_models_fastest() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new(HashMap::new())),
|
||||
latency: Arc::new(RwLock::new({
|
||||
let mut m = HashMap::new();
|
||||
m.insert("gpt-4o".to_string(), 200.0);
|
||||
m.insert("claude-sonnet".to_string(), 120.0);
|
||||
m
|
||||
})),
|
||||
};
|
||||
let models = vec!["gpt-4o".to_string(), "claude-sonnet".to_string()];
|
||||
let result = service
|
||||
.rank_models(&models, &make_policy(SelectionPreference::Fastest))
|
||||
.await;
|
||||
assert_eq!(result, vec!["claude-sonnet", "gpt-4o"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rank_models_fallback_no_metrics() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new(HashMap::new())),
|
||||
latency: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
let models = vec!["model-a".to_string(), "model-b".to_string()];
|
||||
let result = service
|
||||
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
|
||||
.await;
|
||||
assert_eq!(result, vec!["model-a", "model-b"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rank_models_partial_data_appended_last() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new({
|
||||
let mut m = HashMap::new();
|
||||
m.insert("gpt-4o".to_string(), 0.005);
|
||||
m
|
||||
})),
|
||||
latency: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
let models = vec!["gpt-4o-mini".to_string(), "gpt-4o".to_string()];
|
||||
let result = service
|
||||
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
|
||||
.await;
|
||||
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rank_models_none_preserves_order() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new({
|
||||
let mut m = HashMap::new();
|
||||
m.insert("gpt-4o-mini".to_string(), 0.0001);
|
||||
m.insert("gpt-4o".to_string(), 0.005);
|
||||
m
|
||||
})),
|
||||
latency: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
|
||||
let result = service
|
||||
.rank_models(&models, &make_policy(SelectionPreference::None))
|
||||
.await;
|
||||
// none → original order, despite gpt-4o-mini being cheaper
|
||||
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rank_by_ascending_metric_nan_treated_as_missing() {
|
||||
let models = vec![
|
||||
"a".to_string(),
|
||||
"b".to_string(),
|
||||
"c".to_string(),
|
||||
"d".to_string(),
|
||||
];
|
||||
let mut data = HashMap::new();
|
||||
data.insert("a".to_string(), f64::NAN);
|
||||
data.insert("b".to_string(), 0.5);
|
||||
data.insert("c".to_string(), 0.1);
|
||||
// "d" has no entry at all
|
||||
let result = rank_by_ascending_metric(&models, &data);
|
||||
// c (0.1) < b (0.5), then NaN "a" and missing "d" appended in original order
|
||||
assert_eq!(result, vec!["c", "b", "a", "d"]);
|
||||
}
|
||||
}
|
||||
433
crates/brightstaff/src/router/orchestrator.rs
Normal file
433
crates/brightstaff/src/router/orchestrator.rs
Normal file
|
|
@ -0,0 +1,433 @@
|
|||
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use common::{
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference, TopLevelRoutingPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hyper::header;
|
||||
use opentelemetry::global;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::http::{self, post_and_extract_content};
|
||||
use super::model_metrics::ModelMetricsService;
|
||||
use super::orchestrator_model::OrchestratorModel;
|
||||
|
||||
use crate::metrics as bs_metrics;
|
||||
use crate::metrics::labels as metric_labels;
|
||||
use crate::router::orchestrator_model_v1;
|
||||
use crate::session_cache::SessionCache;
|
||||
|
||||
pub use crate::session_cache::CachedRoute;
|
||||
|
||||
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
|
||||
|
||||
pub struct OrchestratorService {
|
||||
orchestrator_url: String,
|
||||
client: reqwest::Client,
|
||||
orchestrator_model: Arc<dyn OrchestratorModel>,
|
||||
orchestrator_provider_name: String,
|
||||
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_cache: Option<Arc<dyn SessionCache>>,
|
||||
session_ttl: Duration,
|
||||
tenant_header: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OrchestrationError {
|
||||
#[error(transparent)]
|
||||
Http(#[from] http::HttpError),
|
||||
|
||||
#[error("Orchestrator model error: {0}")]
|
||||
OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OrchestrationError>;
|
||||
|
||||
impl OrchestratorService {
|
||||
pub fn new(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
orchestrator_provider_name: String,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
HashMap::new(),
|
||||
orchestration_model_name,
|
||||
max_token_length,
|
||||
));
|
||||
|
||||
OrchestratorService {
|
||||
orchestrator_url,
|
||||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
orchestrator_provider_name,
|
||||
top_level_preferences: HashMap::new(),
|
||||
metrics_service: None,
|
||||
session_cache: None,
|
||||
session_ttl: Duration::from_secs(DEFAULT_SESSION_TTL_SECONDS),
|
||||
tenant_header: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn with_routing(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
orchestrator_provider_name: String,
|
||||
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_ttl_seconds: Option<u64>,
|
||||
session_cache: Arc<dyn SessionCache>,
|
||||
tenant_header: Option<String>,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
|
||||
.map_or_else(HashMap::new, |prefs| {
|
||||
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
|
||||
});
|
||||
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
HashMap::new(),
|
||||
orchestration_model_name,
|
||||
max_token_length,
|
||||
));
|
||||
|
||||
let session_ttl =
|
||||
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
|
||||
|
||||
OrchestratorService {
|
||||
orchestrator_url,
|
||||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
orchestrator_provider_name,
|
||||
top_level_preferences,
|
||||
metrics_service,
|
||||
session_cache: Some(session_cache),
|
||||
session_ttl,
|
||||
tenant_header,
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Session cache methods ----
|
||||
|
||||
#[must_use]
|
||||
pub fn tenant_header(&self) -> Option<&str> {
|
||||
self.tenant_header.as_deref()
|
||||
}
|
||||
|
||||
fn session_key<'a>(tenant_id: Option<&str>, session_id: &'a str) -> Cow<'a, str> {
|
||||
match tenant_id {
|
||||
Some(t) => Cow::Owned(format!("{t}:{session_id}")),
|
||||
None => Cow::Borrowed(session_id),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_cached_route(
|
||||
&self,
|
||||
session_id: &str,
|
||||
tenant_id: Option<&str>,
|
||||
) -> Option<CachedRoute> {
|
||||
let cache = self.session_cache.as_ref()?;
|
||||
let result = cache.get(&Self::session_key(tenant_id, session_id)).await;
|
||||
bs_metrics::record_session_cache_event(if result.is_some() {
|
||||
metric_labels::SESSION_CACHE_HIT
|
||||
} else {
|
||||
metric_labels::SESSION_CACHE_MISS
|
||||
});
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn cache_route(
|
||||
&self,
|
||||
session_id: String,
|
||||
tenant_id: Option<&str>,
|
||||
model_name: String,
|
||||
route_name: Option<String>,
|
||||
) {
|
||||
if let Some(ref cache) = self.session_cache {
|
||||
cache
|
||||
.put(
|
||||
&Self::session_key(tenant_id, &session_id),
|
||||
CachedRoute {
|
||||
model_name,
|
||||
route_name,
|
||||
},
|
||||
self.session_ttl,
|
||||
)
|
||||
.await;
|
||||
bs_metrics::record_session_cache_event(metric_labels::SESSION_CACHE_STORE);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- LLM routing ----
|
||||
|
||||
pub async fn determine_route(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
request_id: &str,
|
||||
) -> Result<Option<(String, Vec<String>)>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let inline_top_map: Option<HashMap<String, TopLevelRoutingPreference>> =
|
||||
inline_routing_preferences
|
||||
.map(|prefs| prefs.into_iter().map(|p| (p.name.clone(), p)).collect());
|
||||
|
||||
if inline_top_map.is_none() && self.top_level_preferences.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let effective_source = inline_top_map
|
||||
.as_ref()
|
||||
.unwrap_or(&self.top_level_preferences);
|
||||
|
||||
let effective_prefs: Vec<AgentUsagePreference> = effective_source
|
||||
.values()
|
||||
.map(|p| AgentUsagePreference {
|
||||
model: p.models.first().cloned().unwrap_or_default(),
|
||||
orchestration_preferences: vec![OrchestrationPreference {
|
||||
name: p.name.clone(),
|
||||
description: p.description.clone(),
|
||||
}],
|
||||
})
|
||||
.collect();
|
||||
|
||||
let orchestration_result = self
|
||||
.determine_orchestration(
|
||||
messages,
|
||||
Some(effective_prefs),
|
||||
Some(request_id.to_string()),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let result = if let Some(ref routes) = orchestration_result {
|
||||
if routes.len() > 1 {
|
||||
let all_routes: Vec<&str> = routes.iter().map(|(name, _)| name.as_str()).collect();
|
||||
info!(
|
||||
routes = ?all_routes,
|
||||
using = %all_routes.first().unwrap_or(&"none"),
|
||||
"plano-orchestrator detected multiple intents, using first"
|
||||
);
|
||||
}
|
||||
|
||||
if let Some((route_name, _)) = routes.first() {
|
||||
let top_pref = inline_top_map
|
||||
.as_ref()
|
||||
.and_then(|m| m.get(route_name))
|
||||
.or_else(|| self.top_level_preferences.get(route_name));
|
||||
|
||||
if let Some(pref) = top_pref {
|
||||
let ranked = match &self.metrics_service {
|
||||
Some(svc) => svc.rank_models(&pref.models, &pref.selection_policy).await,
|
||||
None => pref.models.clone(),
|
||||
};
|
||||
Some((route_name.clone(), ranked))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
info!(
|
||||
selected_model = ?result,
|
||||
"plano-orchestrator determined route"
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ---- Agent orchestration (existing) ----
|
||||
|
||||
pub async fn determine_orchestration(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: Option<Vec<AgentUsagePreference>>,
|
||||
request_id: Option<String>,
|
||||
) -> Result<Option<Vec<(String, String)>>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if usage_preferences
|
||||
.as_ref()
|
||||
.is_none_or(|prefs| prefs.is_empty())
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let orchestrator_request = self
|
||||
.orchestrator_model
|
||||
.generate_request(messages, &usage_preferences);
|
||||
|
||||
debug!(
|
||||
model = %self.orchestrator_model.get_model_name(),
|
||||
endpoint = %self.orchestrator_url,
|
||||
"sending request to plano-orchestrator"
|
||||
);
|
||||
|
||||
let body = serde_json::to_string(&orchestrator_request)
|
||||
.map_err(super::orchestrator_model::OrchestratorModelError::from)?;
|
||||
debug!(body = %body, "plano-orchestrator request");
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
header::HeaderValue::from_str(&self.orchestrator_provider_name)
|
||||
.unwrap_or_else(|_| header::HeaderValue::from_static("plano-orchestrator")),
|
||||
);
|
||||
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut headers));
|
||||
});
|
||||
|
||||
if let Some(ref request_id) = request_id {
|
||||
if let Ok(val) = header::HeaderValue::from_str(request_id) {
|
||||
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
|
||||
}
|
||||
}
|
||||
|
||||
headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_static("plano-orchestrator"),
|
||||
);
|
||||
|
||||
let Some((content, elapsed)) =
|
||||
post_and_extract_content(&self.client, &self.orchestrator_url, headers, body).await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let parsed = self
|
||||
.orchestrator_model
|
||||
.parse_response(&content, &usage_preferences)?;
|
||||
|
||||
info!(
|
||||
content = %content.replace("\n", "\\n"),
|
||||
selected_routes = ?parsed,
|
||||
response_time_ms = elapsed.as_millis(),
|
||||
"plano-orchestrator determined routes"
|
||||
);
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::session_cache::memory::MemorySessionCache;
|
||||
|
||||
fn make_orchestrator_service(ttl_seconds: u64, max_entries: usize) -> OrchestratorService {
|
||||
let session_cache = Arc::new(MemorySessionCache::new(max_entries));
|
||||
OrchestratorService::with_routing(
|
||||
"http://localhost:12001/v1/chat/completions".to_string(),
|
||||
"Plano-Orchestrator".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
None,
|
||||
None,
|
||||
Some(ttl_seconds),
|
||||
session_cache,
|
||||
None,
|
||||
orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_miss_returns_none() {
|
||||
let svc = make_orchestrator_service(600, 100);
|
||||
assert!(svc
|
||||
.get_cached_route("unknown-session", None)
|
||||
.await
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_hit_returns_cached_route() {
|
||||
let svc = make_orchestrator_service(600, 100);
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
None,
|
||||
"gpt-4o".to_string(),
|
||||
Some("code".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cached = svc.get_cached_route("s1", None).await.unwrap();
|
||||
assert_eq!(cached.model_name, "gpt-4o");
|
||||
assert_eq!(cached.route_name, Some("code".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_expired_entry_returns_none() {
|
||||
let svc = make_orchestrator_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_expired_entries_not_returned() {
|
||||
let svc = make_orchestrator_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), None, "claude".to_string(), None)
|
||||
.await;
|
||||
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
assert!(svc.get_cached_route("s2", None).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_evicts_oldest_when_full() {
|
||||
let svc = make_orchestrator_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route("s3".to_string(), None, "model-c".to_string(), None)
|
||||
.await;
|
||||
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
assert!(svc.get_cached_route("s2", None).await.is_some());
|
||||
assert!(svc.get_cached_route("s3", None).await.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_update_existing_session_does_not_evict() {
|
||||
let svc = make_orchestrator_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
None,
|
||||
"model-a-updated".to_string(),
|
||||
Some("route".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let s1 = svc.get_cached_route("s1", None).await.unwrap();
|
||||
assert_eq!(s1.model_name, "model-a-updated");
|
||||
assert!(svc.get_cached_route("s2", None).await.is_some());
|
||||
}
|
||||
}
|
||||
|
|
@ -11,8 +11,7 @@ pub enum OrchestratorModelError {
|
|||
pub type Result<T> = std::result::Result<T, OrchestratorModelError>;
|
||||
|
||||
/// OrchestratorModel trait for handling orchestration requests.
|
||||
/// Unlike RouterModel which returns a single route, OrchestratorModel
|
||||
/// can return multiple routes as the model output format is:
|
||||
/// Returns multiple routes as the model output format is:
|
||||
/// {"route": ["route_name_1", "route_name_2", ...]}
|
||||
pub trait OrchestratorModel: Send + Sync {
|
||||
fn generate_request(
|
||||
|
|
|
|||
|
|
@ -8,7 +8,19 @@ use tracing::{debug, warn};
|
|||
|
||||
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
|
||||
|
||||
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the orchestration model
|
||||
pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model
|
||||
|
||||
/// Hard cap on the number of recent messages considered when building the
|
||||
/// routing prompt. Bounds prompt growth for long-running conversations and
|
||||
/// acts as an outer guardrail before the token-budget loop runs. The most
|
||||
/// recent `MAX_ROUTING_TURNS` filtered messages are kept; older turns are
|
||||
/// dropped entirely.
|
||||
pub const MAX_ROUTING_TURNS: usize = 16;
|
||||
|
||||
/// Unicode ellipsis used to mark where content was trimmed out of a long
|
||||
/// message. Helps signal to the downstream router model that the message was
|
||||
/// truncated.
|
||||
const TRIM_MARKER: &str = "…";
|
||||
|
||||
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
|
||||
struct SpacedJsonFormatter;
|
||||
|
|
@ -176,47 +188,82 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
messages: &[Message],
|
||||
usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>,
|
||||
) -> ChatCompletionsRequest {
|
||||
// remove system prompt, tool calls, tool call response and messages without content
|
||||
// if content is empty its likely a tool call
|
||||
// when role == tool its tool call response
|
||||
let messages_vec = messages
|
||||
// Remove system/developer/tool messages and messages without extractable
|
||||
// text (tool calls have no text content we can classify against).
|
||||
let filtered: Vec<&Message> = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System
|
||||
&& m.role != Role::Developer
|
||||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
.collect();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
||||
// Outer guardrail: only consider the last `MAX_ROUTING_TURNS` filtered
|
||||
// messages when building the routing prompt. Keeps prompt growth
|
||||
// predictable for long conversations regardless of per-message size.
|
||||
let start = filtered.len().saturating_sub(MAX_ROUTING_TURNS);
|
||||
let messages_vec: &[&Message] = &filtered[start..];
|
||||
|
||||
// Ensure the conversation does not exceed the configured token budget.
|
||||
// We use `len() / TOKEN_LENGTH_DIVISOR` as a cheap token estimate to
|
||||
// avoid running a real tokenizer on the hot path.
|
||||
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
let mut selected_messages_list_reversed: Vec<Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
let message_text = message.content.extract_text();
|
||||
let message_token_count = message_text.len() / TOKEN_LENGTH_DIVISOR;
|
||||
if token_count + message_token_count > self.max_token_length {
|
||||
let remaining_tokens = self.max_token_length.saturating_sub(token_count);
|
||||
debug!(
|
||||
token_count = token_count,
|
||||
attempted_total_tokens = token_count + message_token_count,
|
||||
max_tokens = self.max_token_length,
|
||||
remaining_tokens,
|
||||
selected = selected_messsage_count,
|
||||
total = messages_vec.len(),
|
||||
"token count exceeds max, truncating conversation"
|
||||
);
|
||||
if message.role == Role::User {
|
||||
// If message that exceeds max token length is from user, we need to keep it
|
||||
selected_messages_list_reversed.push(message);
|
||||
// If the overflow message is from the user we need to keep
|
||||
// some of it so the orchestrator still sees the latest user
|
||||
// intent. Use a middle-trim (head + ellipsis + tail): users
|
||||
// often frame the task at the start AND put the actual ask
|
||||
// at the end of a long pasted block, so preserving both is
|
||||
// better than a head-only cut. The ellipsis also signals to
|
||||
// the router model that content was dropped.
|
||||
if message.role == Role::User && remaining_tokens > 0 {
|
||||
let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR);
|
||||
let truncated = trim_middle_utf8(&message_text, max_bytes);
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: Role::User,
|
||||
content: Some(MessageContent::Text(truncated)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
// If we are here, it means that the message is within the max token length
|
||||
selected_messages_list_reversed.push(message);
|
||||
token_count += message_token_count;
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: message.role.clone(),
|
||||
content: Some(MessageContent::Text(message_text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
if selected_messages_list_reversed.is_empty() {
|
||||
debug!("no messages selected, using last message");
|
||||
if let Some(last_message) = messages_vec.last() {
|
||||
selected_messages_list_reversed.push(last_message);
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: last_message.role.clone(),
|
||||
content: Some(MessageContent::Text(last_message.content.extract_text())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -236,22 +283,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
}
|
||||
|
||||
// Reverse the selected messages to maintain the conversation order
|
||||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| Message {
|
||||
role: message.role.clone(),
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
let selected_conversation_list: Vec<Message> =
|
||||
selected_messages_list_reversed.into_iter().rev().collect();
|
||||
|
||||
// Generate the orchestrator request message based on the usage preferences.
|
||||
// If preferences are passed in request then we use them;
|
||||
|
|
@ -404,6 +437,45 @@ fn fix_json_response(body: &str) -> String {
|
|||
body.replace("'", "\"").replace("\\n", "")
|
||||
}
|
||||
|
||||
/// Truncate `s` so the result is at most `max_bytes` bytes long, keeping
|
||||
/// roughly 60% from the start and 40% from the end, with a Unicode ellipsis
|
||||
/// separating the two. All splits respect UTF-8 character boundaries. When
|
||||
/// `max_bytes` is too small to fit the marker at all, falls back to a
|
||||
/// head-only truncation.
|
||||
fn trim_middle_utf8(s: &str, max_bytes: usize) -> String {
|
||||
if s.len() <= max_bytes {
|
||||
return s.to_string();
|
||||
}
|
||||
if max_bytes <= TRIM_MARKER.len() {
|
||||
// Not enough room even for the marker — just keep the start.
|
||||
let mut end = max_bytes;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
return s[..end].to_string();
|
||||
}
|
||||
|
||||
let available = max_bytes - TRIM_MARKER.len();
|
||||
// Bias toward the start (60%) where task framing typically lives, while
|
||||
// still preserving ~40% of the tail where the user's actual ask often
|
||||
// appears after a long paste.
|
||||
let mut start_len = available * 3 / 5;
|
||||
while start_len > 0 && !s.is_char_boundary(start_len) {
|
||||
start_len -= 1;
|
||||
}
|
||||
let end_len = available - start_len;
|
||||
let mut end_start = s.len().saturating_sub(end_len);
|
||||
while end_start < s.len() && !s.is_char_boundary(end_start) {
|
||||
end_start += 1;
|
||||
}
|
||||
|
||||
let mut out = String::with_capacity(start_len + TRIM_MARKER.len() + (s.len() - end_start));
|
||||
out.push_str(&s[..start_len]);
|
||||
out.push_str(TRIM_MARKER);
|
||||
out.push_str(&s[end_start..]);
|
||||
out
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn OrchestratorModel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "OrchestratorModel")
|
||||
|
|
@ -776,6 +848,10 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
#[test]
|
||||
fn test_conversation_trim_upto_user_message() {
|
||||
// With max_token_length=230, the older user message "given the image
|
||||
// In style of Andy Warhol" overflows the remaining budget and gets
|
||||
// middle-trimmed (head + ellipsis + tail) until it fits. Newer turns
|
||||
// are kept in full.
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant that selects the most suitable routes based on user intent.
|
||||
You are provided with a list of available routes enclosed within <routes></routes> XML tags:
|
||||
|
|
@ -788,7 +864,7 @@ You are also given the conversation context enclosed within <conversation></conv
|
|||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol"
|
||||
"content": "given…rhol"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
|
@ -861,6 +937,190 @@ If no routes are needed, return an empty list for `route`.
|
|||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_huge_single_user_message_is_middle_trimmed() {
|
||||
// Regression test for the case where a single, extremely large user
|
||||
// message was being passed to the orchestrator verbatim and blowing
|
||||
// past the upstream model's context window. The trimmer must now
|
||||
// middle-trim (head + ellipsis + tail) the oversized message so the
|
||||
// resulting request stays within the configured budget, and the
|
||||
// trim marker must be present so the router model knows content
|
||||
// was dropped.
|
||||
let orchestrations_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let max_token_length = 2048;
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
"test-model".to_string(),
|
||||
max_token_length,
|
||||
);
|
||||
|
||||
// ~500KB of content — same scale as the real payload that triggered
|
||||
// the production upstream 400.
|
||||
let head = "HEAD_MARKER_START ";
|
||||
let tail = " TAIL_MARKER_END";
|
||||
let filler = "A".repeat(500_000);
|
||||
let huge_user_content = format!("{head}{filler}{tail}");
|
||||
|
||||
let conversation = vec![Message {
|
||||
role: Role::User,
|
||||
content: Some(MessageContent::Text(huge_user_content.clone())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
// Prompt must stay bounded. Generous ceiling = budget-in-bytes +
|
||||
// scaffolding + slack. Real result should be well under this.
|
||||
let byte_ceiling = max_token_length * TOKEN_LENGTH_DIVISOR
|
||||
+ ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len()
|
||||
+ 1024;
|
||||
assert!(
|
||||
prompt.len() < byte_ceiling,
|
||||
"prompt length {} exceeded ceiling {} — truncation did not apply",
|
||||
prompt.len(),
|
||||
byte_ceiling,
|
||||
);
|
||||
|
||||
// Not all 500k filler chars survive.
|
||||
let a_count = prompt.chars().filter(|c| *c == 'A').count();
|
||||
assert!(
|
||||
a_count < filler.len(),
|
||||
"expected user message to be truncated; all {} 'A's survived",
|
||||
a_count
|
||||
);
|
||||
assert!(
|
||||
a_count > 0,
|
||||
"expected some of the user message to survive truncation"
|
||||
);
|
||||
|
||||
// Head and tail of the message must both be preserved (that's the
|
||||
// whole point of middle-trim over head-only).
|
||||
assert!(
|
||||
prompt.contains(head),
|
||||
"head marker missing — head was not preserved"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains(tail),
|
||||
"tail marker missing — tail was not preserved"
|
||||
);
|
||||
|
||||
// Trim marker must be present so the router model can see that
|
||||
// content was omitted.
|
||||
assert!(
|
||||
prompt.contains(TRIM_MARKER),
|
||||
"ellipsis trim marker missing from truncated prompt"
|
||||
);
|
||||
|
||||
// Routing prompt scaffolding remains intact.
|
||||
assert!(prompt.contains("<conversation>"));
|
||||
assert!(prompt.contains("<routes>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_turn_cap_limits_routing_history() {
|
||||
// The outer turn-cap guardrail should keep only the last
|
||||
// `MAX_ROUTING_TURNS` filtered messages regardless of how long the
|
||||
// conversation is. We build a conversation with alternating
|
||||
// user/assistant turns tagged with their index and verify that only
|
||||
// the tail of the conversation makes it into the prompt.
|
||||
let orchestrations_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), usize::MAX);
|
||||
|
||||
let mut conversation: Vec<Message> = Vec::new();
|
||||
let total_turns = MAX_ROUTING_TURNS * 2; // well past the cap
|
||||
for i in 0..total_turns {
|
||||
let role = if i % 2 == 0 {
|
||||
Role::User
|
||||
} else {
|
||||
Role::Assistant
|
||||
};
|
||||
conversation.push(Message {
|
||||
role,
|
||||
content: Some(MessageContent::Text(format!("turn-{i:03}"))),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
// The last MAX_ROUTING_TURNS messages (indexes total-cap..total)
|
||||
// must all appear.
|
||||
for i in (total_turns - MAX_ROUTING_TURNS)..total_turns {
|
||||
let tag = format!("turn-{i:03}");
|
||||
assert!(
|
||||
prompt.contains(&tag),
|
||||
"expected recent turn tag {tag} to be present"
|
||||
);
|
||||
}
|
||||
|
||||
// And earlier turns (indexes 0..total-cap) must all be dropped.
|
||||
for i in 0..(total_turns - MAX_ROUTING_TURNS) {
|
||||
let tag = format!("turn-{i:03}");
|
||||
assert!(
|
||||
!prompt.contains(&tag),
|
||||
"old turn tag {tag} leaked past turn cap into the prompt"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trim_middle_utf8_helper() {
|
||||
// No-op when already small enough.
|
||||
assert_eq!(trim_middle_utf8("hello", 100), "hello");
|
||||
assert_eq!(trim_middle_utf8("hello", 5), "hello");
|
||||
|
||||
// 60/40 split with ellipsis when too long.
|
||||
let long = "a".repeat(20);
|
||||
let out = trim_middle_utf8(&long, 10);
|
||||
assert!(out.len() <= 10);
|
||||
assert!(out.contains(TRIM_MARKER));
|
||||
// Exactly one ellipsis, rest are 'a's.
|
||||
assert_eq!(out.matches(TRIM_MARKER).count(), 1);
|
||||
assert!(out.chars().filter(|c| *c == 'a').count() > 0);
|
||||
|
||||
// When max_bytes is smaller than the marker, falls back to
|
||||
// head-only truncation (no marker).
|
||||
let out = trim_middle_utf8("abcdefgh", 2);
|
||||
assert_eq!(out, "ab");
|
||||
|
||||
// UTF-8 boundary safety: 2-byte chars.
|
||||
let s = "é".repeat(50); // 100 bytes
|
||||
let out = trim_middle_utf8(&s, 25);
|
||||
assert!(out.len() <= 25);
|
||||
// Must still be valid UTF-8 that only contains 'é' and the marker.
|
||||
let ok = out.chars().all(|c| c == 'é' || c == '…');
|
||||
assert!(ok, "unexpected char in trimmed output: {out:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_text_input() {
|
||||
let expected_prompt = r#"
|
||||
|
|
|
|||
|
|
@ -1,168 +0,0 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use common::{
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, PLANO_ORCHESTRATOR_MODEL_NAME, REQUEST_ID_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
|
||||
use hyper::header;
|
||||
use opentelemetry::global;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::orchestrator_model_v1::{self};
|
||||
|
||||
use super::orchestrator_model::OrchestratorModel;
|
||||
|
||||
pub struct OrchestratorService {
|
||||
orchestrator_url: String,
|
||||
client: reqwest::Client,
|
||||
orchestrator_model: Arc<dyn OrchestratorModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OrchestrationError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
RequestError(#[from] reqwest::Error),
|
||||
|
||||
#[error("Failed to parse JSON: {0}, JSON: {1}")]
|
||||
JsonError(serde_json::Error, String),
|
||||
|
||||
#[error("Orchestrator model error: {0}")]
|
||||
OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OrchestrationError>;
|
||||
|
||||
impl OrchestratorService {
|
||||
pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
|
||||
// Empty agent orchestrations - will be provided via usage_preferences in requests
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model_name.clone(),
|
||||
orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
OrchestratorService {
|
||||
orchestrator_url,
|
||||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_orchestration(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: Option<Vec<AgentUsagePreference>>,
|
||||
request_id: Option<String>,
|
||||
) -> Result<Option<Vec<(String, String)>>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Require usage_preferences to be provided
|
||||
if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let orchestrator_request = self
|
||||
.orchestrator_model
|
||||
.generate_request(messages, &usage_preferences);
|
||||
|
||||
debug!(
|
||||
model = %self.orchestrator_model.get_model_name(),
|
||||
endpoint = %self.orchestrator_url,
|
||||
"sending request to arch-orchestrator"
|
||||
);
|
||||
|
||||
debug!(
|
||||
body = %serde_json::to_string(&orchestrator_request).unwrap(),
|
||||
"arch orchestrator request"
|
||||
);
|
||||
|
||||
let mut orchestration_request_headers = header::HeaderMap::new();
|
||||
orchestration_request_headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(),
|
||||
);
|
||||
|
||||
// Inject OpenTelemetry trace context from current span
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut orchestration_request_headers));
|
||||
});
|
||||
|
||||
if let Some(request_id) = request_id {
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static(REQUEST_ID_HEADER),
|
||||
header::HeaderValue::from_str(&request_id).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME),
|
||||
);
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let res = self
|
||||
.client
|
||||
.post(&self.orchestrator_url)
|
||||
.headers(orchestration_request_headers)
|
||||
.body(serde_json::to_string(&orchestrator_request).unwrap())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body = res.text().await?;
|
||||
let orchestrator_response_time = start_time.elapsed();
|
||||
|
||||
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
error = %err,
|
||||
body = %serde_json::to_string(&body).unwrap(),
|
||||
"failed to parse json response"
|
||||
);
|
||||
return Err(OrchestrationError::JsonError(
|
||||
err,
|
||||
format!("Failed to parse JSON: {}", body),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if chat_completion_response.choices.is_empty() {
|
||||
warn!(body = %body, "no choices in orchestrator response");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(content) = &chat_completion_response.choices[0].message.content {
|
||||
let parsed_response = self
|
||||
.orchestrator_model
|
||||
.parse_response(content, &usage_preferences)?;
|
||||
info!(
|
||||
content = %content.replace("\n", "\\n"),
|
||||
selected_routes = ?parsed_response,
|
||||
response_time_ms = orchestrator_response_time.as_millis(),
|
||||
"arch-orchestrator determined routes"
|
||||
);
|
||||
|
||||
if let Some(ref parsed_response) = parsed_response {
|
||||
return Ok(Some(parsed_response.clone()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingModelError {
|
||||
#[error("Failed to parse JSON: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
|
||||
pub trait RouterModel: Send + Sync {
|
||||
fn generate_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> ChatCompletionsRequest;
|
||||
fn parse_response(
|
||||
&self,
|
||||
content: &str,
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> Result<Option<(String, String)>>;
|
||||
fn get_model_name(&self) -> String;
|
||||
}
|
||||
|
|
@ -1,841 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use common::configuration::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::router_model::{RouterModel, RoutingModelError};
|
||||
|
||||
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the routing model
|
||||
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
{routes}
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
{conversation}
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
pub struct RouterModelV1 {
|
||||
llm_route_json_str: String,
|
||||
llm_route_to_model_map: HashMap<String, String>,
|
||||
routing_model: String,
|
||||
max_token_length: usize,
|
||||
}
|
||||
impl RouterModelV1 {
|
||||
pub fn new(
|
||||
llm_routes: HashMap<String, Vec<RoutingPreference>>,
|
||||
routing_model: String,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let llm_route_values: Vec<RoutingPreference> =
|
||||
llm_routes.values().flatten().cloned().collect();
|
||||
let llm_route_json_str =
|
||||
serde_json::to_string(&llm_route_values).unwrap_or_else(|_| "[]".to_string());
|
||||
let llm_route_to_model_map: HashMap<String, String> = llm_routes
|
||||
.iter()
|
||||
.flat_map(|(model, prefs)| prefs.iter().map(|pref| (pref.name.clone(), model.clone())))
|
||||
.collect();
|
||||
|
||||
RouterModelV1 {
|
||||
routing_model,
|
||||
max_token_length,
|
||||
llm_route_json_str,
|
||||
llm_route_to_model_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct LlmRouterResponse {
|
||||
pub route: Option<String>,
|
||||
}
|
||||
|
||||
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
|
||||
|
||||
impl RouterModel for RouterModelV1 {
|
||||
fn generate_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences_from_request: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> ChatCompletionsRequest {
|
||||
// remove system prompt, tool calls, tool call response and messages without content
|
||||
// if content is empty its likely a tool call
|
||||
// when role == tool its tool call response
|
||||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System
|
||||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
||||
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
token_count = token_count,
|
||||
max_tokens = self.max_token_length,
|
||||
selected = selected_messsage_count,
|
||||
total = messages_vec.len(),
|
||||
"token count exceeds max, truncating conversation"
|
||||
);
|
||||
if message.role == Role::User {
|
||||
// If message that exceeds max token length is from user, we need to keep it
|
||||
selected_messages_list_reversed.push(message);
|
||||
}
|
||||
break;
|
||||
}
|
||||
// If we are here, it means that the message is within the max token length
|
||||
selected_messages_list_reversed.push(message);
|
||||
}
|
||||
|
||||
if selected_messages_list_reversed.is_empty() {
|
||||
debug!("no messages selected, using last message");
|
||||
if let Some(last_message) = messages_vec.last() {
|
||||
selected_messages_list_reversed.push(last_message);
|
||||
}
|
||||
}
|
||||
|
||||
// ensure that first and last selected message is from user
|
||||
if let Some(first_message) = selected_messages_list_reversed.first() {
|
||||
if first_message.role != Role::User {
|
||||
warn!("last message is not from user, may lead to incorrect routing");
|
||||
}
|
||||
}
|
||||
if let Some(last_message) = selected_messages_list_reversed.last() {
|
||||
if last_message.role != Role::User {
|
||||
warn!("first message is not from user, may lead to incorrect routing");
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse the selected messages to maintain the conversation order
|
||||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| {
|
||||
Message {
|
||||
role: message.role.clone(),
|
||||
// we can unwrap here because we have already filtered out messages without content
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
|
||||
// Generate the router request message based on the usage preferences.
|
||||
// If preferences are passed in request then we use them otherwise we use the default routing model preferences.
|
||||
let router_message = match convert_to_router_preferences(usage_preferences_from_request) {
|
||||
Some(prefs) => generate_router_message(&prefs, &selected_conversation_list),
|
||||
None => generate_router_message(&self.llm_route_json_str, &selected_conversation_list),
|
||||
};
|
||||
|
||||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: Some(MessageContent::Text(router_message)),
|
||||
role: Role::User,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
temperature: Some(0.01),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(
|
||||
&self,
|
||||
content: &str,
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> Result<Option<(String, String)>> {
|
||||
if content.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
let router_resp_fixed = fix_json_response(content);
|
||||
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
|
||||
|
||||
let selected_route = router_response.route.unwrap_or_default().to_string();
|
||||
|
||||
if selected_route.is_empty() || selected_route == "other" {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(usage_preferences) = usage_preferences {
|
||||
// If usage preferences are defined, we need to find the model that matches the selected route
|
||||
let model_name: Option<String> = usage_preferences
|
||||
.iter()
|
||||
.map(|pref| {
|
||||
pref.routing_preferences
|
||||
.iter()
|
||||
.find(|routing_pref| routing_pref.name == selected_route)
|
||||
.map(|_| pref.model.clone())
|
||||
})
|
||||
.find_map(|model| model);
|
||||
|
||||
if let Some(model_name) = model_name {
|
||||
return Ok(Some((selected_route, model_name)));
|
||||
} else {
|
||||
warn!(
|
||||
route = %selected_route,
|
||||
preferences = ?usage_preferences,
|
||||
"no matching model found for route"
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
// If no usage preferences are passed in request then use the default routing model preferences
|
||||
if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() {
|
||||
return Ok(Some((selected_route, model)));
|
||||
}
|
||||
|
||||
warn!(
|
||||
route = %selected_route,
|
||||
preferences = ?self.llm_route_to_model_map,
|
||||
"no model found for route"
|
||||
);
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn get_model_name(&self) -> String {
|
||||
self.routing_model.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_router_message(prefs: &str, selected_conversation_list: &Vec<Message>) -> String {
|
||||
ARCH_ROUTER_V1_SYSTEM_PROMPT
|
||||
.replace("{routes}", prefs)
|
||||
.replace(
|
||||
"{conversation}",
|
||||
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn convert_to_router_preferences(
|
||||
prefs_from_request: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> Option<String> {
|
||||
if let Some(usage_preferences) = prefs_from_request {
|
||||
let routing_preferences = usage_preferences
|
||||
.iter()
|
||||
.flat_map(|pref| {
|
||||
pref.routing_preferences
|
||||
.iter()
|
||||
.map(|routing_pref| RoutingPreference {
|
||||
name: routing_pref.name.clone(),
|
||||
description: routing_pref.description.clone(),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<RoutingPreference>>();
|
||||
|
||||
return Some(serde_json::to_string(&routing_preferences).unwrap_or_default());
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn fix_json_response(body: &str) -> String {
|
||||
let mut updated_body = body.to_string();
|
||||
|
||||
updated_body = updated_body.replace("'", "\"");
|
||||
|
||||
if updated_body.contains("\\n") {
|
||||
updated_body = updated_body.replace("\\n", "");
|
||||
}
|
||||
|
||||
if updated_body.starts_with("```json") {
|
||||
updated_body = updated_body
|
||||
.strip_prefix("```json")
|
||||
.unwrap_or(&updated_body)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
if updated_body.ends_with("```") {
|
||||
updated_body = updated_body
|
||||
.strip_suffix("```")
|
||||
.unwrap_or(&updated_body)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
updated_body
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn RouterModel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "RouterModel")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_format() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_format_usage_preferences() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let usage_preferences = Some(vec![ModelUsagePreference {
|
||||
model: "claude/claude-3-7-sonnet".to_string(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: "code-generation".to_string(),
|
||||
description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(),
|
||||
}],
|
||||
}]);
|
||||
let req = router.generate_request(&conversation, &usage_preferences);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_exceed_token_count() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, 235);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_exceed_token_count_large_single_message() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson and this is a very long message that exceeds the max token length of the routing model, so it should be truncated and only the last user message should be included in the conversation for routing."}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, 200);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson and this is a very long message that exceeds the max token length of the routing model, so it should be truncated and only the last user message should be included in the conversation for routing."
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_trim_upto_user_message() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"given the image In style of Andy Warhol"},{"role":"assistant","content":"ok here is the image"},{"role":"user","content":"pls give me another image about Bart and Lisa"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, 230);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "ok here is the image"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "pls give me another image about Bart and Lisa"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_text_input() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hi"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image.png"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skip_tool_call() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"What's the weather like in Tokyo?"},{"role":"assistant","content":"The current weather in Tokyo is 22°C and sunny."},{"role":"user","content":"What about in New York?"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Tokyo?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "toolcall-abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{ \"location\": \"Tokyo\" }"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "toolcall-abc123",
|
||||
"content": "{ \"temperature\": \"22°C\", \"condition\": \"Sunny\" }"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The current weather in Tokyo is 22°C and sunny."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What about in New York?"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
// expects conversation to look like this
|
||||
|
||||
// [
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "What's the weather like in Tokyo?"
|
||||
// },
|
||||
// {
|
||||
// "role": "assistant",
|
||||
// "content": "The current weather in Tokyo is 22°C and sunny."
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "What about in New York?"
|
||||
// }
|
||||
// ]
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
|
||||
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
|
||||
|
||||
// Case 1: Valid JSON with non-empty route
|
||||
let input = r#"{"route": "Image generation"}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("Image generation".to_string(), "gpt-4o".to_string()))
|
||||
);
|
||||
|
||||
// Case 2: Valid JSON with empty route
|
||||
let input = r#"{"route": ""}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 3: Valid JSON with null route
|
||||
let input = r#"{"route": null}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 4: JSON missing route field
|
||||
let input = r#"{}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 4.1: empty string
|
||||
let input = r#""#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 5: Malformed JSON
|
||||
let input = r#"{"route": "route1""#; // missing closing }
|
||||
let result = router.parse_response(input, &None);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Case 6: Single quotes and \n in JSON
|
||||
let input = "{'route': 'Image generation'}\\n";
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("Image generation".to_string(), "gpt-4o".to_string()))
|
||||
);
|
||||
|
||||
// Case 7: Code block marker
|
||||
let input = "```json\n{\"route\": \"Image generation\"}\n```";
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("Image generation".to_string(), "gpt-4o".to_string()))
|
||||
);
|
||||
}
|
||||
}
|
||||
260
crates/brightstaff/src/router/stress_tests.rs
Normal file
260
crates/brightstaff/src/router/stress_tests.rs
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::session_cache::memory::MemorySessionCache;
|
||||
use common::configuration::{SelectionPolicy, SelectionPreference, TopLevelRoutingPreference};
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn make_messages(n: usize) -> Vec<Message> {
|
||||
(0..n)
|
||||
.map(|i| Message {
|
||||
role: if i % 2 == 0 {
|
||||
Role::User
|
||||
} else {
|
||||
Role::Assistant
|
||||
},
|
||||
content: Some(MessageContent::Text(format!(
|
||||
"This is message number {i} with some padding text to make it realistic."
|
||||
))),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn make_routing_prefs() -> Vec<TopLevelRoutingPreference> {
|
||||
vec![
|
||||
TopLevelRoutingPreference {
|
||||
name: "code_generation".to_string(),
|
||||
description: "Code generation and debugging tasks".to_string(),
|
||||
models: vec![
|
||||
"openai/gpt-4o".to_string(),
|
||||
"openai/gpt-4o-mini".to_string(),
|
||||
],
|
||||
selection_policy: SelectionPolicy {
|
||||
prefer: SelectionPreference::None,
|
||||
},
|
||||
},
|
||||
TopLevelRoutingPreference {
|
||||
name: "summarization".to_string(),
|
||||
description: "Summarizing documents and text".to_string(),
|
||||
models: vec![
|
||||
"anthropic/claude-3-sonnet".to_string(),
|
||||
"openai/gpt-4o-mini".to_string(),
|
||||
],
|
||||
selection_policy: SelectionPolicy {
|
||||
prefer: SelectionPreference::None,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
/// Stress test: exercise the full routing code path N times using a mock
|
||||
/// HTTP server and measure jemalloc allocated bytes before/after.
|
||||
///
|
||||
/// This catches:
|
||||
/// - Memory leaks in generate_request / parse_response
|
||||
/// - Leaks in reqwest connection handling
|
||||
/// - String accumulation in the orchestrator model
|
||||
/// - Fragmentation (jemalloc allocated vs resident)
|
||||
#[tokio::test]
|
||||
async fn stress_test_routing_determine_route() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
let router_url = format!("{}/v1/chat/completions", server.url());
|
||||
|
||||
let mock_response = serde_json::json!({
|
||||
"id": "chatcmpl-mock",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "plano-orchestrator",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"route\": \"code_generation\"}"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110}
|
||||
});
|
||||
|
||||
let _mock = server
|
||||
.mock("POST", "/v1/chat/completions")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(mock_response.to_string())
|
||||
.expect_at_least(1)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let prefs = make_routing_prefs();
|
||||
let session_cache = Arc::new(MemorySessionCache::new(1000));
|
||||
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
|
||||
router_url,
|
||||
"Plano-Orchestrator".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
Some(prefs.clone()),
|
||||
None,
|
||||
None,
|
||||
session_cache,
|
||||
None,
|
||||
2048,
|
||||
));
|
||||
|
||||
// Warm up: a few requests to stabilize allocator state
|
||||
for _ in 0..10 {
|
||||
let msgs = make_messages(5);
|
||||
let _ = orchestrator_service
|
||||
.determine_route(&msgs, None, "warmup")
|
||||
.await;
|
||||
}
|
||||
|
||||
// Snapshot memory after warmup
|
||||
let baseline = get_allocated();
|
||||
|
||||
let num_iterations = 2000;
|
||||
|
||||
for i in 0..num_iterations {
|
||||
let msgs = make_messages(5 + (i % 10));
|
||||
let inline = if i % 3 == 0 {
|
||||
Some(make_routing_prefs())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let _ = orchestrator_service
|
||||
.determine_route(&msgs, inline, &format!("req-{i}"))
|
||||
.await;
|
||||
}
|
||||
|
||||
let after = get_allocated();
|
||||
|
||||
let growth = after.saturating_sub(baseline);
|
||||
let growth_mb = growth as f64 / (1024.0 * 1024.0);
|
||||
let per_request = growth.checked_div(num_iterations).unwrap_or(0);
|
||||
|
||||
eprintln!("=== Routing Stress Test Results ===");
|
||||
eprintln!(" Iterations: {num_iterations}");
|
||||
eprintln!(" Baseline alloc: {} bytes", baseline);
|
||||
eprintln!(" Final alloc: {} bytes", after);
|
||||
eprintln!(" Growth: {} bytes ({growth_mb:.2} MB)", growth);
|
||||
eprintln!(" Per-request: {} bytes", per_request);
|
||||
|
||||
// Allow up to 256 bytes per request of retained growth (connection pool, etc.)
|
||||
// A true leak would show thousands of bytes per request.
|
||||
assert!(
|
||||
per_request < 256,
|
||||
"Possible memory leak: {per_request} bytes/request retained after {num_iterations} iterations"
|
||||
);
|
||||
}
|
||||
|
||||
/// Stress test with high concurrency: many parallel determine_route calls.
|
||||
#[tokio::test]
|
||||
async fn stress_test_routing_concurrent() {
|
||||
let mut server = mockito::Server::new_async().await;
|
||||
let router_url = format!("{}/v1/chat/completions", server.url());
|
||||
|
||||
let mock_response = serde_json::json!({
|
||||
"id": "chatcmpl-mock",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "plano-orchestrator",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"route\": \"summarization\"}"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 10, "total_tokens": 110}
|
||||
});
|
||||
|
||||
let _mock = server
|
||||
.mock("POST", "/v1/chat/completions")
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(mock_response.to_string())
|
||||
.expect_at_least(1)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let prefs = make_routing_prefs();
|
||||
let session_cache = Arc::new(MemorySessionCache::new(1000));
|
||||
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
|
||||
router_url,
|
||||
"Plano-Orchestrator".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
Some(prefs),
|
||||
None,
|
||||
None,
|
||||
session_cache,
|
||||
None,
|
||||
2048,
|
||||
));
|
||||
|
||||
// Warm up
|
||||
for _ in 0..20 {
|
||||
let msgs = make_messages(3);
|
||||
let _ = orchestrator_service
|
||||
.determine_route(&msgs, None, "warmup")
|
||||
.await;
|
||||
}
|
||||
|
||||
let baseline = get_allocated();
|
||||
|
||||
let concurrency = 50;
|
||||
let requests_per_task = 100;
|
||||
let total = concurrency * requests_per_task;
|
||||
|
||||
let mut handles = vec![];
|
||||
for t in 0..concurrency {
|
||||
let svc = Arc::clone(&orchestrator_service);
|
||||
let handle = tokio::spawn(async move {
|
||||
for r in 0..requests_per_task {
|
||||
let msgs = make_messages(3 + (r % 8));
|
||||
let _ = svc
|
||||
.determine_route(&msgs, None, &format!("req-{t}-{r}"))
|
||||
.await;
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for h in handles {
|
||||
h.await.unwrap();
|
||||
}
|
||||
|
||||
let after = get_allocated();
|
||||
let growth = after.saturating_sub(baseline);
|
||||
let per_request = growth / total;
|
||||
|
||||
eprintln!("=== Concurrent Routing Stress Test Results ===");
|
||||
eprintln!(" Tasks: {concurrency} x {requests_per_task} = {total}");
|
||||
eprintln!(" Baseline: {} bytes", baseline);
|
||||
eprintln!(" Final: {} bytes", after);
|
||||
eprintln!(
|
||||
" Growth: {} bytes ({:.2} MB)",
|
||||
growth,
|
||||
growth as f64 / 1_048_576.0
|
||||
);
|
||||
eprintln!(" Per-request: {} bytes", per_request);
|
||||
|
||||
assert!(
|
||||
per_request < 512,
|
||||
"Possible memory leak under concurrency: {per_request} bytes/request retained after {total} requests"
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "jemalloc")]
|
||||
fn get_allocated() -> usize {
|
||||
tikv_jemalloc_ctl::epoch::advance().unwrap();
|
||||
tikv_jemalloc_ctl::stats::allocated::read().unwrap_or(0)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "jemalloc"))]
|
||||
fn get_allocated() -> usize {
|
||||
0
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue